[mlir][nvgpu] Allow TmaCreateDescriptorOp to take static box dimensions and add a folder for it. #135497

Open

wants to merge 4 commits into main

Conversation

linuxlonelyeagle
Member

This PR makes TmaCreateDescriptorOp accept static box dimensions, which should make it simpler to use and avoid creating arith.constant ops in unnecessary cases. In addition, a folder is introduced; it folds dynamic constant box dimensions into static ones via -canonicalize.
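
For illustration, here is a minimal sketch of the change in surface syntax, loosely based on the tests touched by this PR (the function and descriptor type below are made up for the example):

```mlir
!tensorMap2d = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf32, 3>,
    swizzle = none, l2promo = none, oob = zero, interleave = none>

func.func @static_box(%unranked: memref<*xf32>) {
  // Previously the box dimensions had to be SSA index values:
  %c64 = arith.constant 64 : index
  %c128 = arith.constant 128 : index
  %d0 = nvgpu.tma.create.descriptor %unranked box [%c64, %c128]
      : memref<*xf32> -> !tensorMap2d
  // With this PR they can also be written statically, so no
  // arith.constant ops are needed:
  %d1 = nvgpu.tma.create.descriptor %unranked box [64, 128]
      : memref<*xf32> -> !tensorMap2d
  return
}
```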

@llvmbot
Member

llvmbot commented Apr 12, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-nvgpu

Author: lonely eagle (linuxlonelyeagle)

Changes

This PR makes TmaCreateDescriptorOp accept static box dimensions, which should make it simpler to use and avoid creating arith.constant ops in unnecessary cases. In addition, a folder is introduced; it folds dynamic constant box dimensions into static ones via -canonicalize.


Full diff: https://github.com/llvm/llvm-project/pull/135497.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td (+4-2)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+9-1)
  • (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+47)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+2-1)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+26)
  • (modified) mlir/test/Dialect/NVGPU/canonicalization.mlir (+16-1)
  • (modified) mlir/test/Dialect/NVGPU/tmaload-transform.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
index 73d86283a5940..3f1f655c041f2 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
@@ -546,12 +546,14 @@ def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
   }];
 
   let arguments = (ins AnyUnrankedMemRef:$tensor,
-                       Variadic<Index>:$boxDimensions);
+                       Variadic<Index>:$boxDimensions,
+                       DenseI64ArrayAttr:$static_boxDimensions);
   let results = (outs NVGPU_TensorMapDescriptor:$tensorMap);
   let assemblyFormat = [{
-         $tensor `box` `[` $boxDimensions `]` attr-dict `:` type($tensor) `->` type($tensorMap)
+    $tensor `box` custom<DynamicIndexList>($boxDimensions, $static_boxDimensions) attr-dict `:` type($tensor) `->` type($tensorMap)
   }];
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 def NVGPU_WarpgroupGenerateDescriptorOp : NVGPU_Op<"warpgroup.generate.descriptor", []> {
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 69fa62c8196e4..a5e8efb745179 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1183,9 +1183,17 @@ struct NVGPUTmaCreateDescriptorOpLowering
 
     Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
                                                  makeI64Const(b, 5));
-    for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
+    unsigned idx = 0;
+    ValueRange dynamicDim = adaptor.getBoxDimensions();
+    for (auto [index, shape] :
+         llvm::enumerate(adaptor.getStaticBoxDimensions())) {
       Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
                                         boxArrayPtr, makeI64Const(b, index));
+      Value value;
+      if (ShapedType::isDynamic(shape))
+        value = dynamicDim[idx++];
+      else
+        value = makeI64Const(b, shape);
       b.create<LLVM::StoreOp>(value, gep);
     }
 
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index abbdb6a0f53ec..b09c51a6690a7 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -458,6 +459,10 @@ LogicalResult TmaAsyncStoreOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPU_TmaCreateDescriptorOp
+//===----------------------------------------------------------------------===//
+
 LogicalResult TmaCreateDescriptorOp::verify() {
   if (getBoxDimensions().size() > kMaxTMATensorDimension) {
     return emitError() << "Maximum " << kMaxTMATensorDimension
@@ -472,6 +477,48 @@ LogicalResult TmaCreateDescriptorOp::verify() {
   return success();
 }
 
+static Value
+TmaCreateDescriptorFoldBoxConstant(TmaCreateDescriptorOp op,
+                                   TmaCreateDescriptorOp::FoldAdaptor adaptor) {
+  std::vector<int64_t> staticBoxDimensions = op.getStaticBoxDimensions().vec();
+  OperandRange dynamicBoxDimensions = op.getBoxDimensions();
+  SmallVector<Value> operands = {op.getTensor()};
+  ArrayRef<Attribute> dynamicBoxDimensionAttrs = adaptor.getBoxDimensions();
+  if (staticBoxDimensions.empty())
+    return {};
+
+  // `opChange` is a flag. If it is true, it means to update `op` in place.
+  bool opChange = false;
+  unsigned idx = 0;
+
+  for (unsigned i = 0, e = staticBoxDimensions.size(); i < e; ++i) {
+    if (!ShapedType::isDynamic(staticBoxDimensions[i]))
+      continue;
+    Attribute dynamicBoxDimensionAttr = dynamicBoxDimensionAttrs[idx];
+    Value dynamicDimension = dynamicBoxDimensions[idx++];
+    if (auto attr =
+            mlir::dyn_cast_if_present<IntegerAttr>(dynamicBoxDimensionAttr)) {
+      staticBoxDimensions[i] = attr.getInt();
+      opChange = true;
+      continue;
+    }
+    operands.push_back(dynamicDimension);
+  }
+
+  if (opChange) {
+    op.setStaticBoxDimensions(staticBoxDimensions);
+    op.getOperation()->setOperands(operands);
+    return op.getResult();
+  }
+  return {};
+}
+
+OpFoldResult TmaCreateDescriptorOp::fold(FoldAdaptor adaptor) {
+  if (auto val = TmaCreateDescriptorFoldBoxConstant(*this, adaptor))
+    return val;
+  return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPU_WarpgroupGenerateDescriptorOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 556922a64b093..cce9a59d4a00c 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -962,6 +962,7 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
   SmallVector<Value> sizes =
       getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
 
+  SmallVector<int64_t> static_dims(sizes.size(), ShapedType::kDynamic);
   auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
   Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
       loc,
@@ -972,7 +973,7 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
           TensorMapSwizzleKind::SWIZZLE_NONE,
           TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
           TensorMapInterleaveKind::INTERLEAVE_NONE),
-      unrankedMemRef, sizes);
+      unrankedMemRef, sizes, static_dims);
   return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
 }
 
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index d0bc806e0aa8c..ffb7e62f250b0 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -813,6 +813,32 @@ func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : m
   func.return
 }
 
+func.func @create_tensor_map_constant_box_dim(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : memref<128xf32>) {
+  %devicePtr2d_unranked = memref.cast %devicePtr2d : memref<64x128xf32> to memref<*xf32>
+  // CHECK: %[[C5_0:.*]] = llvm.mlir.constant(5 : i32) : i64
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[C5_0]] x i64 : (i64) -> !llvm.ptr
+  // CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : i32) : i64
+  // CHECK: %[[GEP_0:.*]] = llvm.getelementptr %[[ALLOCA]]{{\[}}%[[C0_0]]] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.ptr
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : i32) : i64
+  // CHECK: llvm.store %[[C64]], %[[GEP_0]] : i64, !llvm.ptr
+  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i64
+  // CHECK: %[[GEP_1:.*]] = llvm.getelementptr %[[ALLOCA]]{{\[}}%[[C1]]] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.ptr
+  // CHECK: %[[C128_0:.*]] = llvm.mlir.constant(128 : i32) : i64
+  // CHECK: llvm.store %[[C128_0]], %[[GEP_1]] : i64, !llvm.ptr
+  // CHECK: llvm.call @mgpuTensorMapEncodeTiledMemref({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[ALLOCA]])
+  %tensorMap2d = nvgpu.tma.create.descriptor %devicePtr2d_unranked box[64, 128] : memref<*xf32> -> !tensorMap2d
+  %devicePtr1d_unranked = memref.cast %devicePtr1d : memref<128xf32> to memref<*xf32>
+  // CHECK: %[[C5_1:.*]] = llvm.mlir.constant(5 : i32) : i64
+  // CHECK: %[[ALLOCA_1:.*]] = llvm.alloca %[[C5_1]] x i64 : (i64) -> !llvm.ptr
+  // CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i32) : i64
+  // CHECK: %[[GEP_2:.*]] = llvm.getelementptr %[[ALLOCA_1]]{{\[}}%[[C0_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.ptr
+  // CHECK: %[[C128_1:.*]] = llvm.mlir.constant(128 : i32) : i64
+  // CHECK: llvm.store %[[C128_1]], %[[GEP_2]] : i64, !llvm.ptr
+  // CHECK: llvm.call @mgpuTensorMapEncodeTiledMemref({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[ALLOCA_1]])
+  %tensorMap1d = nvgpu.tma.create.descriptor %devicePtr1d_unranked box[128] : memref<*xf32> -> !tensorMap1d
+  func.return
+}
+
 // CHECK-LABEL: @tma_prefetch(
 // CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.tensormap.descriptor<tensor = memref<128xf32, 3>, swizzle = none, l2promo = none, oob = nan, interleave = none>, %[[arg1:[a-zA-Z0-9_]+]]: i1
 func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
diff --git a/mlir/test/Dialect/NVGPU/canonicalization.mlir b/mlir/test/Dialect/NVGPU/canonicalization.mlir
index a7fbfd8067395..9939461769c30 100644
--- a/mlir/test/Dialect/NVGPU/canonicalization.mlir
+++ b/mlir/test/Dialect/NVGPU/canonicalization.mlir
@@ -27,4 +27,19 @@ gpu.module @main_kernel {
     nvvm.cp.async.bulk.wait_group 0
     gpu.return
   }
-}
\ No newline at end of file
+}
+
+// -----
+
+!descriptor = !nvgpu.tensormap.descriptor<tensor = memref<64x16xf16, 3>, swizzle = none, l2promo=none, oob=zero, interleave=none>
+
+func.func @main() {
+  %a_host = memref.alloc() : memref<64x16xf16>
+  %c16 = arith.constant 16 : index
+  %c64 = arith.constant 64 : index
+  %a_device = gpu.alloc() : memref<64x16xf16>
+  %a_device_unranked = memref.cast %a_device : memref<64x16xf16> to memref<*xf16>
+  // CHECK: nvgpu.tma.create.descriptor %{{.*}} box [64, 16]
+  %a_device_map = nvgpu.tma.create.descriptor %a_device_unranked box[%c64, %c16] : memref<*xf16> -> !descriptor
+  return
+}
diff --git a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
index 40acd82cd0558..aa981b2688b81 100644
--- a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
+++ b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
@@ -18,12 +18,12 @@ func.func @main() {
   //      CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x32xf32> to memref<*xf32>
   //      CHECK: %[[c64:.*]] = arith.constant 64 : index
   //      CHECK: %[[c32:.*]] = arith.constant 32 : index
-  //      CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c32]]]
+  //      CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box [%[[c64]], %[[c32]]]
   // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
   //      CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x32xf32> to memref<*xf32>
   //      CHECK: %[[c8_2:.*]] = arith.constant 8 : index
   //      CHECK: %[[c32_2:.*]] = arith.constant 32 : index
-  //      CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c32_2]]]
+  //      CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box [%[[c8_2]], %[[c32_2]]]
   // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
   // CHECK: gpu.launch
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)

@grypp
Member

grypp commented Apr 13, 2025

Thanks for implementing that. I'm definitely not against using static constants.

Using only dynamic SSA variables keeps the codebase simpler, while using static attributes makes the IR easier to read. I’d prefer to keep the codebase simple.

The generated code is the same either way.

@grypp
Member

grypp commented Apr 13, 2025

Can you check whether your change affects the Python builders? We don't test them pre-merge.
https://github.com/llvm/llvm-project/blob/main/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py#L68

If all is good, we can land this PR.

@linuxlonelyeagle
Member Author

Originally, this code was meant to solve the problem:

if (descMemref.getRank() > 1 &&

The check here doesn't look right; it should check the shape of the box instead, so I expect the last box dimension to be a static constant. But this PR doesn't fully solve the problem, because a dynamic constant only becomes a static one after running -canonicalize.
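
A minimal sketch of the fold being described, mirroring the canonicalization test added in this PR (the function name and descriptor type are illustrative):

```mlir
!descriptor = !nvgpu.tensormap.descriptor<tensor = memref<64x16xf16, 3>,
    swizzle = none, l2promo = none, oob = zero, interleave = none>

func.func @fold_box_dims(%unranked: memref<*xf16>) {
  // Before -canonicalize: dynamic box dimensions defined by arith.constant.
  %c64 = arith.constant 64 : index
  %c16 = arith.constant 16 : index
  %d = nvgpu.tma.create.descriptor %unranked box [%c64, %c16]
      : memref<*xf16> -> !descriptor
  // After `mlir-opt -canonicalize`, the folder rewrites this op to:
  //   nvgpu.tma.create.descriptor %unranked box [64, 16]
  //       : memref<*xf16> -> !descriptor
  return
}
```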

@linuxlonelyeagle
Member Author

Can you check whether your change affects the Python builders? We don't test them pre-merge. https://github.com/llvm/llvm-project/blob/main/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py#L68

If all is good, we can land this PR.

I do see that writing the code is indeed a little easier now, though the PR has grown a bit. I'll check the Python builder soon.

@linuxlonelyeagle
Member Author

Originally, this code was meant to solve the problem:

if (descMemref.getRank() > 1 &&

The check here doesn't look right; it should check the shape of the box instead, so I expect the last box dimension to be a static constant. But this PR doesn't fully solve the problem, because a dynamic constant only becomes a static one after running -canonicalize.

I'll fix it in a follow-up PR.

@linuxlonelyeagle
Member Author

 python matmul.py
===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
PASS 
===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
PASS 
===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
PASS 
===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
PASS 
===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
PASS 
===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
PASS 
===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
PASS 
===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
PASS 
===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
PASS 
===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
PASS 
===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
PASS 
===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
PASS 
===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
PASS 
===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
PASS 
===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
PASS 
===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
PASS 

But the nvgpu toy example reports an error for unrelated reasons:

python Ch2.py
...
ValueError: Operand 3 of operation "nvgpu.tma.async.load" must be a Sequence of Values (std::bad_cast)


github-actions bot commented Apr 13, 2025

✅ With the latest revision this PR passed the Python code formatter.

@linuxlonelyeagle
Member Author

But the nvgpu toy example reports an error for unrelated reasons.

This problem also exists on the current main branch.
