|
14 | 14 | #include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
15 | 15 | #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
|
16 | 16 | #include "mlir/Dialect/MemRef/IR/MemRef.h"
|
| 17 | +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
17 | 18 | #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
|
18 | 19 | #include "mlir/Dialect/SCF/IR/SCF.h"
|
19 | 20 | #include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
@@ -237,25 +238,17 @@ DiagnosedSilenceableFailure checkGpuLimits(TransformOpInterface transformOp,
|
237 | 238 | std::optional<int64_t> blockDimZ) {
|
238 | 239 |
|
239 | 240 | // TODO: pass a configuration object to set the limits properly.
|
240 |
| - static constexpr int maxTotalBlockdim = 1024; |
241 |
| - static constexpr int maxBlockdimx = 1024; |
242 |
| - static constexpr int maxBlockdimy = 1024; |
243 |
| - static constexpr int maxBlockdimz = 64; |
244 |
| - static constexpr int maxTotalGriddim = 2147483647; |
245 |
| - static constexpr int maxGriddimx = 2147483647; |
246 |
| - static constexpr int maxGriddimy = 65535; |
247 |
| - static constexpr int maxGriddimz = 65535; |
248 | 241 |
|
249 | 242 | if ((blockDimX.value_or(1) * blockDimY.value_or(1) * blockDimZ.value_or(1)) >
|
250 |
| - maxTotalBlockdim || |
| 243 | + kMaxTotalBlockdim || |
251 | 244 | (gridDimX.value_or(1) * gridDimY.value_or(1) * gridDimZ.value_or(1)) >
|
252 |
| - maxTotalGriddim || |
253 |
| - blockDimX.value_or(1) > maxBlockdimx || |
254 |
| - blockDimY.value_or(1) > maxBlockdimy || |
255 |
| - blockDimZ.value_or(1) > maxBlockdimz || |
256 |
| - gridDimY.value_or(1) > maxGriddimy || |
257 |
| - gridDimZ.value_or(1) > maxGriddimz || |
258 |
| - gridDimX.value_or(1) > maxGriddimx) { |
| 245 | + kMaxTotalGriddim || |
| 246 | + blockDimX.value_or(1) > kMaxBlockdimx || |
| 247 | + blockDimY.value_or(1) > kMaxBlockdimy || |
| 248 | + blockDimZ.value_or(1) > kMaxBlockdimz || |
| 249 | + gridDimY.value_or(1) > kMaxGriddimy || |
| 250 | + gridDimZ.value_or(1) > kMaxGriddimz || |
| 251 | + gridDimX.value_or(1) > kMaxGriddimx) { |
259 | 252 | return transformOp.emitSilenceableError()
|
260 | 253 | << "Trying to launch a GPU kernel with grid_dims = ("
|
261 | 254 | << gridDimX.value_or(1) << ", " << gridDimY.value_or(1) << ", "
|
|
0 commit comments