
Commit 1b23ebe

[MLIR][NVVM] Add Op for TMA Prefetch (#116232)
PR #115527 adds intrinsics for TMA prefetch. This patch adds an NVVM Dialect Op for the same. Lit tests verifying the lowering to the LLVM intrinsics, as well as verifier tests for the invalid cases, are added.

PTX Spec reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor

Signed-off-by: Durgadoss R <[email protected]>
1 parent 7b54976 commit 1b23ebe
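
For a quick illustration (not part of the diff; operand names are made up and mirror the lit tests added below), the new op in its tiled and im2col forms, with the optional L2 cache hint:

llvm.func @tma_prefetch_example(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %ch : i64) {
  // Tiled mode (default): the source tensor layout is preserved at the destination.
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] : !llvm.ptr
  // Im2col mode: selected by the presence of im2col offsets; the tensor must be at least 3-D.
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] : !llvm.ptr
  // Optional L2 cache-eviction hint.
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] l2_cache_hint = %ch : !llvm.ptr
  llvm.return
}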

File tree: 4 files changed, +203 −10 lines

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 68 additions & 0 deletions
@@ -1949,6 +1949,74 @@ def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
   }];
 }
 
+def NVVM_CpAsyncBulkTensorPrefetchOp :
+  NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
+  let arguments = (ins
+    LLVM_AnyPointer:$tmaDescriptor,
+    Variadic<I32>:$coordinates,
+    Variadic<I16>:$im2colOffsets,
+    Optional<I64>:$l2CacheHint);
+
+  let description = [{
+    Initiates an asynchronous prefetch operation on the tensor data from global
+    memory to L2 cache.
+
+    The Op has two modes:
+    1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
+       layout is preserved at the destination.
+
+    2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
+       The elements in the Bounding Box of the source tensor are rearranged into
+       columns at the destination. In this mode, the tensor has to be at least
+       3-dimensional.
+
+    The `l2CacheHint` operand is optional, and it is used to specify the cache
+    eviction policy that may be used during the memory access.
+
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor)
+  }];
+
+  let assemblyFormat = [{
+    $tmaDescriptor `,`
+    `box` `[`$coordinates `]`
+    (`im2col` `[` $im2colOffsets^ `]` )?
+    (`l2_cache_hint` `=` $l2CacheHint^ )?
+    attr-dict `:` type($tmaDescriptor)
+  }];
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col);
+  }];
+
+  let hasVerifier = 1;
+
+  string llvmBuilder = [{
+    // Arguments to the intrinsic:
+    // tmaDesc, tensorDims, im2colOffsets,
+    // cache_hint (if applicable) and flag (boolean)
+    llvm::SmallVector<llvm::Value *> translatedOperands;
+    translatedOperands.push_back($tmaDescriptor);
+
+    for (auto v : op.getCoordinates())
+      translatedOperands.push_back(moduleTranslation.lookupValue(v));
+
+    for (auto v : op.getIm2colOffsets())
+      translatedOperands.push_back(moduleTranslation.lookupValue(v));
+
+    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
+    auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64));
+
+    bool isCacheHint = op.getL2CacheHint() ? true : false;
+    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef);
+    translatedOperands.push_back(builder.getInt1(isCacheHint));
+
+    auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID(
+        op.getCoordinates().size(), op.getIm2colOffsets().size() > 0);
+    createIntrinsicCall(builder, intId, translatedOperands);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM Wgmma Ops
 //===----------------------------------------------------------------------===//
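
A minimal sketch (illustrative, not part of the diff) of what the llvmBuilder above produces for the optional l2CacheHint: when the hint is absent, an i64 undef is passed and the trailing i1 flag is false; when present, the hint value is forwarded and the flag is true:

llvm.func @prefetch_hint_lowering(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
  // No hint: lowers to ...prefetch.tile.1d(ptr, i32, i64 undef, i1 false).
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] : !llvm.ptr
  // With hint: lowers to ...prefetch.tile.1d(ptr, i32, i64 %ch, i1 true).
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] l2_cache_hint = %ch : !llvm.ptr
  llvm.return
}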

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 48 additions & 9 deletions
@@ -75,22 +75,32 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
 
 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
 
-LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
-  if (getCoordinates().empty() || getCoordinates().size() > 5)
-    return emitError("expects coordinates between 1 to 5 dimension");
-
-  // Check for im2col mode
-  if (!getIm2colOffsets().empty()) {
-    if (getCoordinates().size() < 3)
+// This verifier is shared across:
+// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load) and
+// CpAsyncBulkTensorPrefetchOp (TMA Prefetch) Ops.
+static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
+                                                     size_t numIm2ColOffsets,
+                                                     Location loc) {
+  if (tensorDims < 1 || tensorDims > 5)
+    return emitError(loc, "expects coordinates between 1 to 5 dimension");
+
+  if (numIm2ColOffsets) {
+    if (tensorDims < 3)
       return emitError(
+          loc,
           "to use im2col mode, the tensor has to be at least 3-dimensional");
-    if (getCoordinates().size() != (getIm2colOffsets().size() + 2))
+    if (tensorDims != (numIm2ColOffsets + 2))
       return emitError(
-          "im2col offsets must be 2 less than number of coordinates");
+          loc, "im2col offsets must be 2 less than number of coordinates");
   }
   return success();
 }
 
+LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
+  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
+                                         getIm2colOffsets().size(), getLoc());
+}
+
 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
   if (getCoordinates().size() > 5)
     return emitError("Maximum 5 coordinates and dimension is supported.");
@@ -108,6 +118,11 @@ LogicalResult CpAsyncOp::verify() {
   return success();
 }
 
+LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
+  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
+                                         getIm2colOffsets().size(), getLoc());
+}
+
 // Given the element type of an operand and whether or not it is an accumulator,
 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
 // operand's element type.
@@ -1055,6 +1070,30 @@ LogicalResult NVVM::BarrierOp::verify() {
   return success();
 }
 
+llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
+                                                                bool isIm2Col) {
+  switch (tensorDims) {
+  case 1:
+    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
+  case 2:
+    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
+  case 3:
+    return isIm2Col
+               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
+               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
+  case 4:
+    return isIm2Col
+               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
+               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
+  case 5:
+    return isIm2Col
+               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
+               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
+  default:
+    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
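
To make the getIntrinsicID mapping above concrete, a small sketch (illustrative names, not part of the diff): the number of box coordinates selects the dimensionality, and the presence of im2col offsets selects the im2col variant of the intrinsic:

llvm.func @prefetch_intrinsic_selection(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16) {
  // 3 coordinates, no offsets -> nvvm_cp_async_bulk_tensor_prefetch_tile_3d
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] : !llvm.ptr
  // 3 coordinates, 1 im2col offset -> nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] : !llvm.ptr
  llvm.return
}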

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 25 additions & 1 deletion
@@ -30,4 +30,28 @@ llvm.func @nvvm_fence_proxy_release() {
   // expected-error @below {{'nvvm.fence.proxy.release' op uni-directional proxies only support tensormap for to_proxy attribute}}
   nvvm.fence.proxy.release #nvvm.mem_scope<cta> from_proxy=#nvvm.proxy_kind<generic> to_proxy=#nvvm.proxy_kind<generic>
   llvm.return
-}
+}
+
+// -----
+
+llvm.func @tma_prefetch_0d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
+  // expected-error @below {{expects coordinates between 1 to 5 dimension}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[] : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_2d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %off0 : i16, %ch : i64) {
+  // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_5d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
+  // expected-error @below {{im2col offsets must be 2 less than number of coordinates}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] : !llvm.ptr
+  llvm.return
+}
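
For contrast with the invalid cases above, a sketch (illustrative, not part of the diff) of a 5-D im2col prefetch that the shared verifier accepts: five coordinates with exactly three im2col offsets, i.e. coordinates minus two:

llvm.func @tma_prefetch_5d_im2col_ok(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16) {
  // 5 coordinates and 3 im2col offsets satisfy the "offsets == coordinates - 2" rule.
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] : !llvm.ptr
  llvm.return
}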

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 62 additions & 0 deletions
@@ -715,3 +715,65 @@ llvm.func @nvvm_breakpoint() {
   nvvm.breakpoint
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: @tma_prefetch_1d
+llvm.func @tma_prefetch_1d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] l2_cache_hint = %ch : !llvm.ptr
+  llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_2d
+llvm.func @tma_prefetch_2d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %ch : i64) {
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] l2_cache_hint = %ch : !llvm.ptr
+  llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_3d
+llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %ch : i64) {
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] l2_cache_hint = %ch : !llvm.ptr
+
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
+  llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_4d
+llvm.func @tma_prefetch_4d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch : !llvm.ptr
+
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch : !llvm.ptr
+  llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_5d
+llvm.func @tma_prefetch_5d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] l2_cache_hint = %ch : !llvm.ptr
+
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
+  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] l2_cache_hint = %ch : !llvm.ptr
+  llvm.return
+}
