[MLIR][NVGPU] Improve and Cleanup verifier of TMA OPs #70923


Merged — 3 commits merged into llvm:main on Nov 8, 2023

Conversation

grypp (Member) commented on Nov 1, 2023

This PR improves, cleans up, and unifies the verifiers of the TmaCreateDescriptor and TmaAsyncLoad ops.

The verifier now checks the following, which it did not before:

  • address space
  • rank match between descriptor and memref
  • element type match between descriptor and memref
  • shape match between descriptor and memref
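
To illustrate the new checks, here is a minimal sketch of IR that the unified verifier now rejects (function and value names are illustrative, not from the PR; the diagnostic text is the one added by this change):

!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
func.func @tma_load_elt_mismatch(%desc: !desc, %buffer: memref<32x32xf16,3>, %mbarrier: !mbarrier) {
  %c0 = arith.constant 0 : index
  // The descriptor carries f32 elements but the destination memref carries f16.
  // expected-error @+1 {{the element type of tensor map descriptor and memref must be same}}
  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer : !desc, !mbarrier -> memref<32x32xf16,3>
  return
}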

llvmbot (Member) commented on Nov 1, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-nvgpu

@llvm/pr-subscribers-mlir-gpu

Author: Guray Ozen (grypp)

Changes

This PR improves, cleans up, and unifies the verifiers of the TmaCreateDescriptor and TmaAsyncLoad ops.

The verifier now checks the following, which it did not before:

  • address space
  • rank match between descriptor and memref
  • element type match between descriptor and memref
  • shape match between descriptor and memref

Full diff: https://github.com/llvm/llvm-project/pull/70923.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h (+2)
  • (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+63-19)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+1-1)
  • (modified) mlir/test/Dialect/NVGPU/invalid.mlir (+43)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index e6bba7e6082964b..2888fed27795751 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -25,6 +25,8 @@ constexpr int kWarpSize = 32;
 
 /// M size of wgmma.mma_async instruction
 constexpr int kWgmmaSizeM = 64;
+/// Maximum tensor dimension that TMA supports
+constexpr int kMaxTMATensorDimension = 5;
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index f5b02fe1b515591..cec9a87f40929d2 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -335,34 +335,78 @@ LogicalResult LdMatrixOp::verify() {
 // NVGPU_TmaAsyncLoadOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult TmaAsyncLoadOp::verify() {
-  // Destination memref
-  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());
+std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
+    Operation *op, nvgpu::TensorMapDescriptorType descType,
+    std::optional<MemRefType> memrefType = std::nullopt) {
+  MemRefType descMemref = descType.getTensor();
+  // Limitation
+  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
+    return op->emitError() << "Interleave options are not supported yet.";
+
+  // Address space check for shared memory check
+  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
+    return op->emitError() << "the tensor map descriptor has incorrect address "
+                              "space, it must be shared memory address space.";
+  }
+  // Support only static shape for the time being
+  if (!descMemref.hasStaticShape())
+    return op->emitError() << "the tensor map descriptor must be static shaped";
+
+  // No verification if memref type is not provided
+  if (!memrefType.has_value())
+    return std::nullopt;
+
+  MemRefType dstMemref = memrefType.value();
+
+  // Check element type
+  if (descMemref.getElementType() != dstMemref.getElementType()) {
+    return op->emitError() << "the element type of tensor map descriptor and "
+                              "memref must be same";
+  }
+
   if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
-    return emitError()
-           << "The operation stores data to shared memory, but "
-              "the destination memref does not have a memory space of "
-           << NVGPUDialect::kSharedMemoryAddressSpace;
+    return op->emitError() << "the destination memref has incorrect address "
+                              "space, it must be shared memory address space.";
   }
-  if (getCoordinates().size() > 5) {
-    return emitError() << "Maximum 5 coordinates are supported.";
+  if (!dstMemref.hasStaticShape())
+    return op->emitError() << "the destination memref must be static shaped";
+
+  if (dstMemref.getRank() != descMemref.getRank()) {
+    return op->emitError() << "the shape of tensor map descriptor and "
+                              "memref must have same rank";
   }
-  if (getCoordinates().size() != size_t(dstMemref.getRank())) {
-    return emitError() << "Destination memref rank is "
-                       << size_t(dstMemref.getRank()) << " but there are  "
-                       << getCoordinates().size()
-                       << " coordinates. They must match.";
+  if (!descMemref.getShape().equals(dstMemref.getShape())) {
+    return op->emitError() << "memref and tensor map shapes mismatch "
+                           << descMemref << " != " << dstMemref;
   }
+
+  return std::nullopt;
+}
+
+LogicalResult TmaAsyncLoadOp::verify() {
+  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
+      *this, getTensorMapDescriptor().getType(), getDst().getType());
+  if (error.has_value())
+    return error.value();
+
+  if (getCoordinates().size() > kMaxTMATensorDimension) {
+    return emitError() << "Maximum " << kMaxTMATensorDimension
+                       << " coordinates are supported.";
+  }
+
   return success();
 }
 
 LogicalResult TmaCreateDescriptorOp::verify() {
-  if (getBoxDimensions().size() > 5) {
-    return emitError() << "Maximum 5 dimensional box is supported.";
+  if (getBoxDimensions().size() > kMaxTMATensorDimension) {
+    return emitError() << "Maximum " << kMaxTMATensorDimension
+                       << " coordinates are supported.";
   }
-  nvgpu::TensorMapDescriptorType desc = getTensorMap().getType();
-  if (desc.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
-    return emitError() << "Interleave options are not supported yet.";
+
+  std::optional<InFlightDiagnostic> error =
+      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
+  if (error.has_value())
+    return error.value();
 
   return success();
 }
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 123a661193c4901..0b2ef67bf0634bc 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -640,7 +640,7 @@ func.func @mbarrier_txcount_pred() {
 !tensorMap1d = !nvgpu.tensormap.descriptor<tensor = memref<128xf32,3>,         swizzle=none,        l2promo = none,        oob = nan,  interleave = none>
 !tensorMap2d = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>,       swizzle=swizzle_32b, l2promo = none,        oob = zero, interleave = none>
 !tensorMap3d = !nvgpu.tensormap.descriptor<tensor = memref<2x32x32xf32,3>,     swizzle=swizzle_64b, l2promo = l2promo_64b, oob = zero, interleave = none>
-!tensorMap4d = !nvgpu.tensormap.descriptor<tensor = memref<2x2x32x32xf32,3>,   swizzle=swizzle_128b,l2promo = l2promo_128b,oob = zero, interleave = interleave_16b>
+!tensorMap4d = !nvgpu.tensormap.descriptor<tensor = memref<2x2x32x32xf32,3>,   swizzle=swizzle_128b,l2promo = l2promo_128b,oob = zero, interleave = none>
 !tensorMap5d = !nvgpu.tensormap.descriptor<tensor = memref<2x2x2x32x32xf32,3>, swizzle=none,        l2promo = none,        oob = zero, interleave = none>
 !mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
 func.func @async_tma_load(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d, 
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index 41b29fa74b125d4..7b402483db68474 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -265,3 +265,46 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc: !tR
   %0 = nvgpu.warpgroup.mma %descA, %descB, %acc: !tDescA, !tDescB, !tResult -> !tResult
   return
 }
+
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_load_1(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
+  %c0 = arith.constant 0 : index
+  // Pass fine
+  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x32xf32,3>
+  // expected-error @+1 {{Maximum 5 coordinates are supported.}}
+  nvgpu.tma.async.load %desc[%c0, %c0, %c0, %c0, %c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x32xf32,3>
+  return
+}
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_load_2(%desc: !desc,  %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{the tensor map descriptor has incorrect address space, it must be shared memory address space.}}
+  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<32x32xf32,3>
+  return
+}
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_load_3(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{the destination memref has incorrect address space, it must be shared memory address space}}
+  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer3 : !desc, !mbarrier -> memref<32x32xf32>
+  return
+}
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_load_4(%desc: !desc,  %buffer1: memref<128xf32,3>, %buffer2: memref<32x32xf32,3>, %buffer3: memref<32x32xf32>, %mbarrier: !mbarrier) {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{the shape of tensor map descriptor and memref must have same rank}}
+  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer1 : !desc, !mbarrier -> memref<128xf32,3>
+  return
+}
\ No newline at end of file
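
The new cases above exercise the coordinate limit, both address-space checks, and the rank check. A hypothetical companion test for the shape check, sketched in the same style (the function and buffer names are illustrative, not part of the PR):

// -----

!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
func.func @tma_load_shape_mismatch(%desc: !desc, %buffer: memref<16x32xf32,3>, %mbarrier: !mbarrier) {
  %c0 = arith.constant 0 : index
  // Same rank and element type, but the 32x32 descriptor shape differs from 16x32.
  // expected-error @+1 {{memref and tensor map shapes mismatch}}
  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer : !desc, !mbarrier -> memref<16x32xf32,3>
  return
}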

Inline review comment (Collaborator) on the new check in mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp:

    if (!dstMemref.hasStaticShape())
      return op->emitError() << "the destination memref must be static shaped";

I would move this check close to its descMemref counterpart.

qcolombet (Collaborator) left a review:

LGTM

@grypp grypp merged commit 6eb97f0 into llvm:main Nov 8, 2023
grypp added a commit that referenced this pull request Nov 10, 2023
#70923 improved the verifier. The verifier caught that the tensor map type in the TMA descriptor in this test isn't correct. The program was working correctly anyway, since the offset is calculated correctly.

This work fixes the test.
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
llvm#70923 improved the verifier. The verifier caught that the tensor map type in the TMA descriptor in this test isn't correct. The program was working correctly anyway, since the offset is calculated correctly.

This work fixes the test.
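
For context, a hedged sketch of the kind of mismatch such a follow-up resolves: a descriptor whose declared tensor type disagrees with the memref it is used with. The types below are illustrative, not the actual test:

// Before: descriptor tensor type disagrees with the memref actually loaded.
!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x32xf32,3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
// After: descriptor tensor type matches the destination memref<32x32xf32,3>.
!descFixed = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle = none, l2promo = none, oob = zero, interleave = none>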