Commit 6eb97f0

[MLIR][NVGPU] Improve and Cleanup verifier of TMA OPs (#70923)
This PR improves and cleans up the verifiers of the TmaCreateDescriptor and TmaAsyncLoad ops and unifies them. The unified verifier now checks the following, which were not verified before:
- address space
- rank match between descriptor and memref
- element type match between descriptor and memref
- shape match between descriptor and memref
1 parent 96b5e09 commit 6eb97f0

File tree: 4 files changed, +118 −31 lines

mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h

Lines changed: 2 additions & 0 deletions

```diff
@@ -25,6 +25,8 @@ constexpr int kWarpSize = 32;
 
 /// M size of wgmma.mma_async instruction
 constexpr int kWgmmaSizeM = 64;
+/// Maximum tensor dimension that TMA supports
+constexpr int kMaxTMATensorDimension = 5;
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
```
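The new constant replaces the magic number 5 that the verifiers below previously hard-coded. A minimal standalone sketch of the bound it encodes; the helper `isValidTmaRank` is illustrative only and not part of the patch:

```cpp
#include <cstddef>

// Mirror of the new dialect constant (copied here so the sketch is standalone).
constexpr int kMaxTMATensorDimension = 5;

// Illustrative helper, not part of the patch: a TMA descriptor, box, or
// coordinate list may use at most kMaxTMATensorDimension dimensions.
inline bool isValidTmaRank(std::size_t rank) {
  return rank >= 1 && rank <= static_cast<std::size_t>(kMaxTMATensorDimension);
}
```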

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 72 additions & 30 deletions

```diff
@@ -335,34 +335,83 @@ LogicalResult LdMatrixOp::verify() {
 // NVGPU_TmaAsyncLoadOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult TmaAsyncLoadOp::verify() {
-  // Destination memref
-  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());
+std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
+    Operation *op, nvgpu::TensorMapDescriptorType descType,
+    std::optional<MemRefType> memrefType = std::nullopt) {
+  MemRefType descMemref = descType.getTensor();
+  // Limitation
+  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
+    return op->emitError() << "Interleave options are not supported yet.";
+
+  // Address space check for shared memory check
+  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
+    return op->emitError() << "the tensor map descriptor has incorrect address "
+                              "space, it must be shared memory address space.";
+  }
+  // Support only static shape for the time being
+  if (!descMemref.hasStaticShape())
+    return op->emitError() << "the tensor map descriptor must be static shaped";
+
+  // No verification if memref type is not provided
+  if (!memrefType.has_value())
+    return std::nullopt;
+
+  MemRefType dstMemref = memrefType.value();
+
+  // Check element type
+  if (descMemref.getElementType() != dstMemref.getElementType()) {
+    return op->emitError() << "the element type of tensor map descriptor and "
+                              "memref must be same";
+  }
+
   if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
-    return emitError()
-           << "The operation stores data to shared memory, but "
-              "the destination memref does not have a memory space of "
-           << NVGPUDialect::kSharedMemoryAddressSpace;
+    return op->emitError() << "the destination memref has incorrect address "
+                              "space, it must be shared memory address space.";
   }
-  if (getCoordinates().size() > 5) {
-    return emitError() << "Maximum 5 coordinates are supported.";
+  if (!dstMemref.hasStaticShape())
+    return op->emitError() << "the destination memref must be static shaped";
+
+  if (dstMemref.getRank() != descMemref.getRank()) {
+    return op->emitError() << "the shape of tensor map descriptor and "
+                              "memref must have same rank";
   }
-  if (getCoordinates().size() != size_t(dstMemref.getRank())) {
-    return emitError() << "Destination memref rank is "
-                       << size_t(dstMemref.getRank()) << " but there are "
-                       << getCoordinates().size()
-                       << " coordinates. They must match.";
+  if (!descMemref.getShape().equals(dstMemref.getShape())) {
+    return op->emitError() << "memref and tensor map shapes mismatch "
+                           << descMemref << " != " << dstMemref;
   }
+
+  return std::nullopt;
+}
+
+LogicalResult TmaAsyncLoadOp::verify() {
+  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
+      *this, getTensorMapDescriptor().getType(), getDst().getType());
+  if (error.has_value())
+    return error.value();
+
+  if (getCoordinates().size() > kMaxTMATensorDimension) {
+    return emitError() << "Maximum " << kMaxTMATensorDimension
+                       << " coordinates are supported.";
+  }
+  if (getCoordinates().size() !=
+      getTensorMapDescriptor().getType().getTensor().getRank()) {
+    return emitError() << "number of coordinates do not match with the rank of "
+                          "tensor descriptor map.";
+  }
+
   return success();
 }
 
 LogicalResult TmaCreateDescriptorOp::verify() {
-  if (getBoxDimensions().size() > 5) {
-    return emitError() << "Maximum 5 dimensional box is supported.";
+  if (getBoxDimensions().size() > kMaxTMATensorDimension) {
+    return emitError() << "Maximum " << kMaxTMATensorDimension
+                       << " coordinates are supported.";
   }
-  nvgpu::TensorMapDescriptorType desc = getTensorMap().getType();
-  if (desc.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
-    return emitError() << "Interleave options are not supported yet.";
+
+  std::optional<InFlightDiagnostic> error =
+      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
+  if (error.has_value())
+    return error.value();
 
   return success();
 }
@@ -372,17 +421,10 @@ LogicalResult TmaCreateDescriptorOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult WarpgroupGenerateDescriptorOp::verify() {
-  MemRefType memrefType = getTensor().getType();
-  MemRefType tensorMapType = getTensorMap().getType().getTensor();
-
-  if (memrefType != tensorMapType)
-    return emitError() << "memref and tensor map type mismatch";
-
-  if (!memrefType.hasStaticShape() || !tensorMapType.hasStaticShape())
-    return emitError() << "supports only static shapes";
-
-  if (memrefType.getRank() != 2)
-    return emitError() << "supports only 2d memref is supported for now";
+  std::optional<InFlightDiagnostic> error =
+      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
+  if (error.has_value())
+    return error.value();
 
   if (getTensorMap().getType().getSwizzle() !=
       TensorMapSwizzleKind::SWIZZLE_128B) {
```
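The unification boils down to one pattern: the shared descriptor/memref checks live in a free function returning `std::optional<InFlightDiagnostic>`, and each op verifier forwards that diagnostic if present before running its own op-specific checks. Below is a compilable standalone sketch of the same control flow; `Diagnostic` and `LogicalResult` here are simplified stand-ins, not the real MLIR classes:

```cpp
#include <iostream>
#include <optional>
#include <string>

// Stand-ins for MLIR's InFlightDiagnostic and LogicalResult (illustrative only).
using Diagnostic = std::string;
struct LogicalResult { bool failed; };

// Shared descriptor/memref checks live in one helper that returns a
// diagnostic on failure and std::nullopt on success.
std::optional<Diagnostic> verifyCommon(int descRank, int memrefRank) {
  if (descRank != memrefRank)
    return Diagnostic("rank mismatch between descriptor and memref");
  return std::nullopt;
}

// Each op verifier forwards the shared diagnostic first, then runs only its
// op-specific checks (here: coordinate count), mirroring the structure above.
LogicalResult verifyLoadOp(int descRank, int memrefRank, int numCoords) {
  if (std::optional<Diagnostic> error = verifyCommon(descRank, memrefRank)) {
    std::cerr << *error << "\n";
    return {true};
  }
  if (numCoords != descRank) {
    std::cerr << "number of coordinates do not match descriptor rank\n";
    return {true};
  }
  return {false};
}

int main() {
  // A 2-D descriptor loaded with two coordinates verifies cleanly;
  // three coordinates against a 2-D descriptor trips the op-specific check.
  LogicalResult ok = verifyLoadOp(2, 2, 2);
  LogicalResult bad = verifyLoadOp(2, 2, 3);
  std::cout << (ok.failed ? "failed" : "ok") << ", "
            << (bad.failed ? "failed" : "ok") << "\n";
  return 0;
}
```

Returning `std::optional` rather than `LogicalResult` is what lets ops without a memref operand (TmaCreateDescriptorOp, WarpgroupGenerateDescriptorOp) reuse only the descriptor half of the checks, since the `memrefType` parameter defaults to `std::nullopt`.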

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 1 addition & 1 deletion

The 4-D descriptor in this test switches from interleave_16b to interleave = none, since the unified verifier now rejects interleave options.

```diff
@@ -640,7 +640,7 @@ func.func @mbarrier_txcount_pred() {
 !tensorMap1d = !nvgpu.tensormap.descriptor<tensor = memref<128xf32,3>, swizzle=none, l2promo = none, oob = nan, interleave = none>
 !tensorMap2d = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
 !tensorMap3d = !nvgpu.tensormap.descriptor<tensor = memref<2x32x32xf32,3>, swizzle=swizzle_64b, l2promo = l2promo_64b, oob = zero, interleave = none>
-!tensorMap4d = !nvgpu.tensormap.descriptor<tensor = memref<2x2x32x32xf32,3>, swizzle=swizzle_128b,l2promo = l2promo_128b,oob = zero, interleave = interleave_16b>
+!tensorMap4d = !nvgpu.tensormap.descriptor<tensor = memref<2x2x32x32xf32,3>, swizzle=swizzle_128b,l2promo = l2promo_128b,oob = zero, interleave = none>
 !tensorMap5d = !nvgpu.tensormap.descriptor<tensor = memref<2x2x2x32x32xf32,3>, swizzle=none, l2promo = none, oob = zero, interleave = none>
 !mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
 func.func @async_tma_load(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
```

mlir/test/Dialect/NVGPU/invalid.mlir

Lines changed: 43 additions & 0 deletions

Four new negative tests cover the new diagnostics, one per failure mode: too many coordinates, descriptor address space, destination address space, and rank mismatch. As usual for invalid.mlir, they are exercised via mlir-opt's -verify-diagnostics mode.

```diff
@@ -265,3 +265,46 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc: !tR
   %0 = nvgpu.warpgroup.mma %descA, %descB, %acc: !tDescA, !tDescB, !tResult -> !tResult
   return
 }
+
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_load_1(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
+  %c0 = arith.constant 0 : index
+  // Pass fine
+  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x32xf32,3>
+  // expected-error @+1 {{Maximum 5 coordinates are supported.}}
+  nvgpu.tma.async.load %desc[%c0, %c0, %c0, %c0, %c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x32xf32,3>
+  return
+}
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_load_2(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{the tensor map descriptor has incorrect address space, it must be shared memory address space.}}
+  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x32xf32,3>
+  return
+}
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_load_3(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{the destination memref has incorrect address space, it must be shared memory address space}}
+  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer3 : !desc, !mbarrier -> memref<32x32xf32>
+  return
+}
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_load_4(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{the shape of tensor map descriptor and memref must have same rank}}
+  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer1 : !desc, !mbarrier -> memref<128xf32,3>
+  return
+}
```
