Skip to content

[mlir][nvgpu] Allow TmaCreateDescriptorOp to use static box dimensions and add a folder for it. #135497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -546,12 +546,14 @@ def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
}];

let arguments = (ins AnyUnrankedMemRef:$tensor,
Variadic<Index>:$boxDimensions);
Variadic<Index>:$boxDimensions,
DenseI64ArrayAttr:$static_boxDimensions);
let results = (outs NVGPU_TensorMapDescriptor:$tensorMap);
let assemblyFormat = [{
$tensor `box` `[` $boxDimensions `]` attr-dict `:` type($tensor) `->` type($tensorMap)
$tensor `box` custom<DynamicIndexList>($boxDimensions, $static_boxDimensions) attr-dict `:` type($tensor) `->` type($tensorMap)
}];
let hasVerifier = 1;
let hasFolder = 1;
}

def NVGPU_WarpgroupGenerateDescriptorOp : NVGPU_Op<"warpgroup.generate.descriptor", []> {
Expand Down
10 changes: 9 additions & 1 deletion mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1183,9 +1183,17 @@ struct NVGPUTmaCreateDescriptorOpLowering

Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
makeI64Const(b, 5));
for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
unsigned idx = 0;
ValueRange dynamicDim = adaptor.getBoxDimensions();
for (auto [index, shape] :
llvm::enumerate(adaptor.getStaticBoxDimensions())) {
Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
boxArrayPtr, makeI64Const(b, index));
Value value;
if (ShapedType::isDynamic(shape))
value = dynamicDim[idx++];
else
value = makeI64Const(b, shape);
b.create<LLVM::StoreOp>(value, gep);
}

Expand Down
47 changes: 47 additions & 0 deletions mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
Expand Down Expand Up @@ -458,6 +459,10 @@ LogicalResult TmaAsyncStoreOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_TmaCreateDescriptorOp
//===----------------------------------------------------------------------===//

LogicalResult TmaCreateDescriptorOp::verify() {
if (getBoxDimensions().size() > kMaxTMATensorDimension) {
return emitError() << "Maximum " << kMaxTMATensorDimension
Expand All @@ -472,6 +477,48 @@ LogicalResult TmaCreateDescriptorOp::verify() {
return success();
}

/// Folds constant dynamic box dimensions of `op` into its
/// `static_boxDimensions` attribute. On success the op is updated in place
/// (both the attribute and the operand list) and its result is returned to
/// signal the in-place change; otherwise a null Value is returned.
static Value
TmaCreateDescriptorFoldBoxConstant(TmaCreateDescriptorOp op,
                                   TmaCreateDescriptorOp::FoldAdaptor adaptor) {
  std::vector<int64_t> boxDims = op.getStaticBoxDimensions().vec();
  OperandRange dynDims = op.getBoxDimensions();
  // The rebuilt operand list: the tensor, followed by every dynamic
  // dimension that could not be folded to a constant.
  SmallVector<Value> newOperands = {op.getTensor()};
  ArrayRef<Attribute> dynDimAttrs = adaptor.getBoxDimensions();
  if (boxDims.empty())
    return {};

  bool folded = false;
  // Tracks the position within the dynamic operands / fold attributes, which
  // only cover the entries marked dynamic in `boxDims`.
  unsigned dynIdx = 0;

  for (int64_t &dim : boxDims) {
    if (!ShapedType::isDynamic(dim))
      continue;
    Attribute dimAttr = dynDimAttrs[dynIdx];
    Value dimValue = dynDims[dynIdx++];
    if (auto intAttr = mlir::dyn_cast_if_present<IntegerAttr>(dimAttr)) {
      // Constant operand: promote it into the static dimension list.
      dim = intAttr.getInt();
      folded = true;
      continue;
    }
    // Still dynamic: keep the SSA operand.
    newOperands.push_back(dimValue);
  }

  if (!folded)
    return {};
  op.setStaticBoxDimensions(boxDims);
  op.getOperation()->setOperands(newOperands);
  return op.getResult();
}

OpFoldResult TmaCreateDescriptorOp::fold(FoldAdaptor adaptor) {
  // Try to promote constant box dimensions into the static attribute; a
  // non-null value means the op was rewritten in place.
  Value folded = TmaCreateDescriptorFoldBoxConstant(*this, adaptor);
  if (folded)
    return folded;
  return {};
}

//===----------------------------------------------------------------------===//
// NVGPU_WarpgroupGenerateDescriptorOp
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,7 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
SmallVector<Value> sizes =
getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

SmallVector<int64_t> static_dims(sizes.size(), ShapedType::kDynamic);
auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
loc,
Expand All @@ -972,7 +973,7 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
TensorMapSwizzleKind::SWIZZLE_NONE,
TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
TensorMapInterleaveKind::INTERLEAVE_NONE),
unrankedMemRef, sizes);
unrankedMemRef, sizes, static_dims);
return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
}

Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,32 @@ func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : m
func.return
}

// Checks lowering of nvgpu.tma.create.descriptor with fully static box
// dimensions: each constant dimension is materialized as an i64 constant and
// stored into the stack-allocated box array, with no dynamic SSA operands.
func.func @create_tensor_map_constant_box_dim(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : memref<128xf32>) {
%devicePtr2d_unranked = memref.cast %devicePtr2d : memref<64x128xf32> to memref<*xf32>
// CHECK: %[[C5_0:.*]] = llvm.mlir.constant(5 : i32) : i64
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[C5_0]] x i64 : (i64) -> !llvm.ptr
// CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : i32) : i64
// CHECK: %[[GEP_0:.*]] = llvm.getelementptr %[[ALLOCA]]{{\[}}%[[C0_0]]] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.ptr
// CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : i32) : i64
// CHECK: llvm.store %[[C64]], %[[GEP_0]] : i64, !llvm.ptr
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i64
// CHECK: %[[GEP_1:.*]] = llvm.getelementptr %[[ALLOCA]]{{\[}}%[[C1]]] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.ptr
// CHECK: %[[C128_0:.*]] = llvm.mlir.constant(128 : i32) : i64
// CHECK: llvm.store %[[C128_0]], %[[GEP_1]] : i64, !llvm.ptr
// CHECK: llvm.call @mgpuTensorMapEncodeTiledMemref({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[ALLOCA]])
%tensorMap2d = nvgpu.tma.create.descriptor %devicePtr2d_unranked box[64, 128] : memref<*xf32> -> !tensorMap2d
// Same check for a 1-D descriptor with a single static box dimension.
%devicePtr1d_unranked = memref.cast %devicePtr1d : memref<128xf32> to memref<*xf32>
// CHECK: %[[C5_1:.*]] = llvm.mlir.constant(5 : i32) : i64
// CHECK: %[[ALLOCA_1:.*]] = llvm.alloca %[[C5_1]] x i64 : (i64) -> !llvm.ptr
// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i32) : i64
// CHECK: %[[GEP_2:.*]] = llvm.getelementptr %[[ALLOCA_1]]{{\[}}%[[C0_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.ptr
// CHECK: %[[C128_1:.*]] = llvm.mlir.constant(128 : i32) : i64
// CHECK: llvm.store %[[C128_1]], %[[GEP_2]] : i64, !llvm.ptr
// CHECK: llvm.call @mgpuTensorMapEncodeTiledMemref({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[ALLOCA_1]])
%tensorMap1d = nvgpu.tma.create.descriptor %devicePtr1d_unranked box[128] : memref<*xf32> -> !tensorMap1d
func.return
}

// CHECK-LABEL: @tma_prefetch(
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.tensormap.descriptor<tensor = memref<128xf32, 3>, swizzle = none, l2promo = none, oob = nan, interleave = none>, %[[arg1:[a-zA-Z0-9_]+]]: i1
func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
Expand Down
17 changes: 16 additions & 1 deletion mlir/test/Dialect/NVGPU/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,19 @@ gpu.module @main_kernel {
nvvm.cp.async.bulk.wait_group 0
gpu.return
}
}
}

// -----

!descriptor = !nvgpu.tensormap.descriptor<tensor = memref<64x16xf16, 3>, swizzle = none, l2promo=none, oob=zero, interleave=none>

// Folder test: the constant box operands (%c64, %c16) are folded into the
// op's static_boxDimensions attribute, so after canonicalization the
// dimensions print as inline literals rather than SSA values.
func.func @main() {
%a_host = memref.alloc() : memref<64x16xf16>
%c16 = arith.constant 16 : index
%c64 = arith.constant 64 : index
%a_device = gpu.alloc() : memref<64x16xf16>
%a_device_unranked = memref.cast %a_device : memref<64x16xf16> to memref<*xf16>
// CHECK: nvgpu.tma.create.descriptor %{{.*}} box [64, 16]
%a_device_map = nvgpu.tma.create.descriptor %a_device_unranked box[%c64, %c16] : memref<*xf16> -> !descriptor
return
}
4 changes: 2 additions & 2 deletions mlir/test/Dialect/NVGPU/tmaload-transform.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ func.func @main() {
// CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x32xf32> to memref<*xf32>
// CHECK: %[[c64:.*]] = arith.constant 64 : index
// CHECK: %[[c32:.*]] = arith.constant 32 : index
// CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c32]]]
// CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box [%[[c64]], %[[c32]]]
// CHECK-SAME: : memref<*xf32> -> <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
// CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x32xf32> to memref<*xf32>
// CHECK: %[[c8_2:.*]] = arith.constant 8 : index
// CHECK: %[[c32_2:.*]] = arith.constant 32 : index
// CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c32_2]]]
// CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box [%[[c8_2]], %[[c32_2]]]
// CHECK-SAME: : memref<*xf32> -> <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
// CHECK: gpu.launch
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
Expand Down
6 changes: 5 additions & 1 deletion mlir/test/Examples/NVGPU/tools/nvdsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,12 @@ def create_descriptor(self, device_ptr):
),
device_ptr,
)
box_static_dim = [MLIR_DYNAMIC] * len(self.tma_box_shape)
self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
tma_descriptor_ty, device_unranked_memref, map(const, self.tma_box_shape)
tma_descriptor_ty,
device_unranked_memref,
map(const, self.tma_box_shape),
box_static_dim,
)
return self.tma_descriptor.result

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,12 @@ def tma_descriptor_op(self, device_ptr):
),
device_ptr,
)
box_static_dim = [MLIR_DYNAMIC] * len(self.tma_box_shape)
tma_descriptor_op = nvgpu.TmaCreateDescriptorOp(
tma_descriptor_ty, device_unranked_memref, map(c, self.tma_box_shape)
tma_descriptor_ty,
device_unranked_memref,
map(c, self.tma_box_shape),
box_static_dim,
)
return tma_descriptor_op.result

Expand Down
Loading