-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[MLIR][NVVM] Add Op for TMA Prefetch #116232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Add Op for TMA Prefetch #116232
Conversation
@grypp , Kindly help with review. |
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Durgadoss R (durga4github) ChangesPR #115527 adds intrinsics for TMA prefetch. Lit tests to verify the lowering to LLVM intrinsics as well as PTX Spec reference: Full diff: https://github.com/llvm/llvm-project/pull/116232.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7cb4b5c346ad97..6b462de144d1ff 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1949,6 +1949,74 @@ def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
}];
}
+def NVVM_CpAsyncBulkTensorPrefetchOp :
+ NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
+ let arguments = (ins
+ LLVM_AnyPointer:$tmaDescriptor,
+ Variadic<I32>:$coordinates,
+ Variadic<I16>:$im2colOffsets,
+ Optional<I64>:$l2CacheHint);
+
+ let description = [{
+ Initiates an asynchronous prefetch operation on the tensor data from global
+ memory to L2 cache.
+
+ The Op has two modes:
+ 1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
+ layout is preserved at the destination.
+
+ 2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
+ the elements in the Bounding Box of the source tensor are rearranged into
+ columns at the destination. In this mode, the tensor has to be at least
+ 3-dimensional.
+
+ The `l2CacheHint` operand is optional, and it is used to specify cache
+ eviction policy that may be used during the memory access.
+
+ [For more information, see PTX ISA]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor)
+ }];
+
+ let assemblyFormat = [{
+ $tmaDescriptor `,`
+ `box` `[`$coordinates `]`
+ (`im2col` `[` $im2colOffsets^ `]` )?
+ (`l2_cache_hint` `=` $l2CacheHint^ )?
+ attr-dict `:` type($tmaDescriptor)
+ }];
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col);
+ }];
+
+ let hasVerifier = 1;
+
+ string llvmBuilder = [{
+ // Arguments to the intrinsic:
+ // tmaDesc, tensorDims, im2colOffsets
+ // cache_hint(if applicable) and flag(boolean)
+ llvm::SmallVector<llvm::Value *> translatedOperands;
+ translatedOperands.push_back($tmaDescriptor);
+
+ for (auto v : op.getCoordinates())
+ translatedOperands.push_back(moduleTranslation.lookupValue(v));
+
+ for (auto v : op.getIm2colOffsets())
+ translatedOperands.push_back(moduleTranslation.lookupValue(v));
+
+ llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
+ auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64));
+
+ bool isCacheHint = op.getL2CacheHint() ? true : false;
+ translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef);
+ translatedOperands.push_back(builder.getInt1(isCacheHint));
+
+ auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID(
+ op.getCoordinates().size(), op.getIm2colOffsets().size() > 0);
+ createIntrinsicCall(builder, intId, translatedOperands);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 5ab64ea1b2097a..4a96b2b97c817d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -108,6 +108,22 @@ LogicalResult CpAsyncOp::verify() {
return success();
}
+LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
+ if (getCoordinates().empty() || getCoordinates().size() > 5)
+ return emitError("expects coordinates between 1 to 5 dimension");
+
+ // Check for im2col mode
+ if (!getIm2colOffsets().empty()) {
+ if (getCoordinates().size() < 3)
+ return emitError(
+ "to use im2col mode, the tensor has to be at least 3-dimensional");
+ if (getCoordinates().size() != (getIm2colOffsets().size() + 2))
+ return emitError(
+ "im2col offsets must be 2 less than number of coordinates");
+ }
+ return success();
+}
+
// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
@@ -1055,6 +1071,30 @@ LogicalResult NVVM::BarrierOp::verify() {
return success();
}
+llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
+ bool isIm2Col) {
+ switch (tensorDims) {
+ case 1:
+ return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
+ case 2:
+ return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
+ case 3:
+ return isIm2Col
+ ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
+ : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
+ case 4:
+ return isIm2Col
+ ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
+ : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
+ case 5:
+ return isIm2Col
+ ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
+ : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
+ default:
+ llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
+ }
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 0e563808da970b..58282adf4dda85 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -30,4 +30,28 @@ llvm.func @nvvm_fence_proxy_release() {
// expected-error @below {{'nvvm.fence.proxy.release' op uni-directional proxies only support tensormap for to_proxy attribute}}
nvvm.fence.proxy.release #nvvm.mem_scope<cta> from_proxy=#nvvm.proxy_kind<generic> to_proxy=#nvvm.proxy_kind<generic>
llvm.return
-}
\ No newline at end of file
+}
+
+// -----
+
+llvm.func @tma_prefetch_0d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
+ // expected-error @below {{expects coordinates between 1 to 5 dimension}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[] : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_2d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %off0 : i16, %ch : i64) {
+ // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_5d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
+ // expected-error @below {{im2col offsets must be 2 less than number of coordinates}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] : !llvm.ptr
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 75ce958b43fd34..e5ea03ff7e0017 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -715,3 +715,65 @@ llvm.func @nvvm_breakpoint() {
nvvm.breakpoint
llvm.return
}
+
+// -----
+
+// CHECK-LABEL: @tma_prefetch_1d
+llvm.func @tma_prefetch_1d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] l2_cache_hint = %ch : !llvm.ptr
+ llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_2d
+llvm.func @tma_prefetch_2d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %ch : i64) {
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] l2_cache_hint = %ch : !llvm.ptr
+ llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_3d
+llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %ch : i64) {
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] l2_cache_hint = %ch : !llvm.ptr
+
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
+ llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_4d
+llvm.func @tma_prefetch_4d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch : !llvm.ptr
+
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch : !llvm.ptr
+ llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_5d
+llvm.func @tma_prefetch_5d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] l2_cache_hint = %ch : !llvm.ptr
+
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] l2_cache_hint = %ch : !llvm.ptr
+ llvm.return
+}
|
PR llvm#115527 adds intrinsics for TMA prefetch. This patch adds an NVVM Dialect Op for the same. Lit tests to verify the lowering to LLVM intrinsics as well as verifier tests (for invalid cases) are added. Signed-off-by: Durgadoss R <[email protected]>
5576eba
to
d94f0fd
Compare
Refactored to a common function and builds are clean. Merging the change now. |
PR #115527 adds intrinsics for TMA prefetch.
This patch adds an NVVM Dialect Op for the same.
Lit tests to verify the lowering to LLVM intrinsics as well as
verifier tests (for invalid cases) are added.
PTX Spec reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor