[MLIR][NVGPU] Improve and Cleanup verifier of TMA OPs #70923
Conversation
This PR improves and cleans up the verifiers of the TmaCreateDescriptor and TmaAsyncLoad ops and unifies them. The PR verifies the following properties that were not checked before:
- address space
- rank match between descriptor and memref
- element type match between descriptor and memref
- shape match between descriptor and memref
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu
Author: Guray Ozen (grypp)
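For illustration, a load that passes all of these checks looks like the following. This is a minimal sketch in the style of the tests added by this PR; the type aliases and function name are illustrative only.

```mlir
!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
func.func @valid_tma_load(%desc: !desc, %buffer: memref<32x32xf32,3>, %mbarrier: !mbarrier) {
  %c0 = arith.constant 0 : index
  // Descriptor and destination agree on address space (3, i.e. shared
  // memory), rank (2), element type (f32), and shape (32x32), so the
  // verifier accepts the op.
  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer : !desc, !mbarrier -> memref<32x32xf32,3>
  return
}
```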
Full diff: https://github.com/llvm/llvm-project/pull/70923.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index e6bba7e6082964b..2888fed27795751 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -25,6 +25,8 @@ constexpr int kWarpSize = 32;
/// M size of wgmma.mma_async instruction
constexpr int kWgmmaSizeM = 64;
+/// Maximum tensor dimension that TMA supports
+constexpr int kMaxTMATensorDimension = 5;
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index f5b02fe1b515591..cec9a87f40929d2 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -335,34 +335,78 @@ LogicalResult LdMatrixOp::verify() {
// NVGPU_TmaAsyncLoadOp
//===----------------------------------------------------------------------===//
-LogicalResult TmaAsyncLoadOp::verify() {
- // Destination memref
- auto dstMemref = llvm::cast<MemRefType>(getDst().getType());
+std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
+ Operation *op, nvgpu::TensorMapDescriptorType descType,
+ std::optional<MemRefType> memrefType = std::nullopt) {
+ MemRefType descMemref = descType.getTensor();
+ // Limitation
+ if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
+ return op->emitError() << "Interleave options are not supported yet.";
+
+  // Check that the address space is shared memory
+ if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
+ return op->emitError() << "the tensor map descriptor has incorrect address "
+ "space, it must be shared memory address space.";
+ }
+ // Support only static shape for the time being
+ if (!descMemref.hasStaticShape())
+ return op->emitError() << "the tensor map descriptor must be static shaped";
+
+ // No verification if memref type is not provided
+ if (!memrefType.has_value())
+ return std::nullopt;
+
+ MemRefType dstMemref = memrefType.value();
+
+ // Check element type
+ if (descMemref.getElementType() != dstMemref.getElementType()) {
+ return op->emitError() << "the element type of tensor map descriptor and "
+ "memref must be same";
+ }
+
if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
- return emitError()
- << "The operation stores data to shared memory, but "
- "the destination memref does not have a memory space of "
- << NVGPUDialect::kSharedMemoryAddressSpace;
+ return op->emitError() << "the destination memref has incorrect address "
+ "space, it must be shared memory address space.";
}
- if (getCoordinates().size() > 5) {
- return emitError() << "Maximum 5 coordinates are supported.";
+ if (!dstMemref.hasStaticShape())
+ return op->emitError() << "the destination memref must be static shaped";
+
+ if (dstMemref.getRank() != descMemref.getRank()) {
+ return op->emitError() << "the shape of tensor map descriptor and "
+ "memref must have same rank";
}
- if (getCoordinates().size() != size_t(dstMemref.getRank())) {
- return emitError() << "Destination memref rank is "
- << size_t(dstMemref.getRank()) << " but there are "
- << getCoordinates().size()
- << " coordinates. They must match.";
+ if (!descMemref.getShape().equals(dstMemref.getShape())) {
+ return op->emitError() << "memref and tensor map shapes mismatch "
+ << descMemref << " != " << dstMemref;
}
+
+ return std::nullopt;
+}
+
+LogicalResult TmaAsyncLoadOp::verify() {
+ std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
+ *this, getTensorMapDescriptor().getType(), getDst().getType());
+ if (error.has_value())
+ return error.value();
+
+ if (getCoordinates().size() > kMaxTMATensorDimension) {
+ return emitError() << "Maximum " << kMaxTMATensorDimension
+ << " coordinates are supported.";
+ }
+
return success();
}
LogicalResult TmaCreateDescriptorOp::verify() {
- if (getBoxDimensions().size() > 5) {
- return emitError() << "Maximum 5 dimensional box is supported.";
+ if (getBoxDimensions().size() > kMaxTMATensorDimension) {
+ return emitError() << "Maximum " << kMaxTMATensorDimension
+ << " coordinates are supported.";
}
- nvgpu::TensorMapDescriptorType desc = getTensorMap().getType();
- if (desc.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
- return emitError() << "Interleave options are not supported yet.";
+
+ std::optional<InFlightDiagnostic> error =
+ verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
+ if (error.has_value())
+ return error.value();
return success();
}
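As an illustration of the unified helper above, the new element-type check would reject IR like the following. This is a hypothetical snippet written in the style of the tests below; it is not part of this diff.

```mlir
!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
func.func @tma_load_elt_mismatch(%desc: !desc, %buffer: memref<32x32xf16,3>, %mbarrier: !mbarrier) {
  %c0 = arith.constant 0 : index
  // The descriptor holds f32 but the destination memref holds f16, so the
  // verifier reports: "the element type of tensor map descriptor and
  // memref must be same".
  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer : !desc, !mbarrier -> memref<32x32xf16,3>
  return
}
```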
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 123a661193c4901..0b2ef67bf0634bc 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -640,7 +640,7 @@ func.func @mbarrier_txcount_pred() {
!tensorMap1d = !nvgpu.tensormap.descriptor<tensor = memref<128xf32,3>, swizzle=none, l2promo = none, oob = nan, interleave = none>
!tensorMap2d = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
!tensorMap3d = !nvgpu.tensormap.descriptor<tensor = memref<2x32x32xf32,3>, swizzle=swizzle_64b, l2promo = l2promo_64b, oob = zero, interleave = none>
-!tensorMap4d = !nvgpu.tensormap.descriptor<tensor = memref<2x2x32x32xf32,3>, swizzle=swizzle_128b,l2promo = l2promo_128b,oob = zero, interleave = interleave_16b>
+!tensorMap4d = !nvgpu.tensormap.descriptor<tensor = memref<2x2x32x32xf32,3>, swizzle=swizzle_128b,l2promo = l2promo_128b,oob = zero, interleave = none>
!tensorMap5d = !nvgpu.tensormap.descriptor<tensor = memref<2x2x2x32x32xf32,3>, swizzle=none, l2promo = none, oob = zero, interleave = none>
!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
func.func @async_tma_load(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index 41b29fa74b125d4..7b402483db68474 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -265,3 +265,46 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc: !tR
%0 = nvgpu.warpgroup.mma %descA, %descB, %acc: !tDescA, !tDescB, !tResult -> !tResult
return
}
+
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_load_1(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
+ %c0 = arith.constant 0 : index
+ // Pass fine
+ nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x32xf32,3>
+ // expected-error @+1 {{Maximum 5 coordinates are supported.}}
+ nvgpu.tma.async.load %desc[%c0, %c0, %c0, %c0, %c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x32xf32,3>
+ return
+}
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_load_2(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
+ %c0 = arith.constant 0 : index
+ // expected-error @+1 {{the tensor map descriptor has incorrect address space, it must be shared memory address space.}}
+ nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x32xf32,3>
+ return
+}
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_load_3(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
+ %c0 = arith.constant 0 : index
+ // expected-error @+1 {{the destination memref has incorrect address space, it must be shared memory address space}}
+ nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer3 : !desc, !mbarrier -> memref<32x32xf32>
+ return
+}
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_load_4(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
+ %c0 = arith.constant 0 : index
+ // expected-error @+1 {{the shape of tensor map descriptor and memref must have same rank}}
+ nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer1 : !desc, !mbarrier -> memref<128xf32,3>
+ return
+}
\ No newline at end of file
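The added negative tests cover the coordinate-count, address-space, and rank checks; the exact shape comparison is not exercised. A hypothetical test for it, not part of this diff, could look like:

```mlir
// -----

!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
func.func @tma_load_shape_mismatch(%desc: !desc, %buffer: memref<16x32xf32,3>, %mbarrier: !mbarrier) {
  %c0 = arith.constant 0 : index
  // Same rank (2) and element type (f32), but the 32x32 descriptor shape
  // differs from the 16x32 destination shape.
  // expected-error @+1 {{memref and tensor map shapes mismatch}}
  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer : !desc, !mbarrier -> memref<16x32xf32,3>
  return
}
```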
if (!dstMemref.hasStaticShape())
  return op->emitError() << "the destination memref must be static shaped";
I would move this check close to its descMemref counterpart.
LGTM
#70923 improved the verifier. The verifier caught that the tensor map type in the TMA descriptor in this test isn't correct. The program was working correctly anyway, since the offset is calculated correctly. This change fixes the test.