Skip to content

Commit 5576eba

Browse files
committed
[MLIR][NVVM] Add Op for TMA Prefetch
PR llvm#115527 adds intrinsics for TMA prefetch. This patch adds an NVVM Dialect Op for the same. Lit tests are added to verify the lowering to the LLVM intrinsics, as well as verifier tests for invalid cases.

Signed-off-by: Durgadoss R <[email protected]>
1 parent 9685681 commit 5576eba

File tree

4 files changed

+195
-1
lines changed

4 files changed

+195
-1
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1949,6 +1949,74 @@ def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
19491949
}];
19501950
}
19511951

1952+
def NVVM_CpAsyncBulkTensorPrefetchOp :
  NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
  let arguments = (ins
    LLVM_AnyPointer:$tmaDescriptor,
    Variadic<I32>:$coordinates,
    Variadic<I16>:$im2colOffsets,
    Optional<I64>:$l2CacheHint);

  let description = [{
    Initiates an asynchronous prefetch operation on the tensor data from global
    memory to L2 cache.

    The Op has two modes:
    1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
       layout is preserved at the destination.

    2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
       The elements in the Bounding Box of the source tensor are rearranged into
       columns at the destination. In this mode, the tensor has to be at least
       3-dimensional.

    The `l2CacheHint` operand is optional, and it is used to specify cache
    eviction policy that may be used during the memory access.

    [For more information, see PTX ISA]
    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor)
  }];

  let assemblyFormat = [{
    $tmaDescriptor `,`
    `box` `[`$coordinates `]`
    (`im2col` `[` $im2colOffsets^ `]` )?
    (`l2_cache_hint` `=` $l2CacheHint^ )?
    attr-dict `:` type($tmaDescriptor)
  }];

  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col);
  }];

  let hasVerifier = 1;

  string llvmBuilder = [{
    // Arguments to the intrinsic:
    // tmaDesc, tensorDims, im2colOffsets
    // cache_hint(if applicable) and flag(boolean)
    llvm::SmallVector<llvm::Value *> translatedOperands;
    translatedOperands.push_back($tmaDescriptor);

    for (auto v : op.getCoordinates())
      translatedOperands.push_back(moduleTranslation.lookupValue(v));

    for (auto v : op.getIm2colOffsets())
      translatedOperands.push_back(moduleTranslation.lookupValue(v));

    // The intrinsic always takes a cache-hint slot; pass undef plus an i1
    // "false" flag when the optional operand is absent.
    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
    auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64));

    bool isCacheHint = op.getL2CacheHint() ? true : false;
    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef);
    translatedOperands.push_back(builder.getInt1(isCacheHint));

    // Select the tile/im2col intrinsic that matches the tensor rank.
    auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID(
        op.getCoordinates().size(), op.getIm2colOffsets().size() > 0);
    createIntrinsicCall(builder, intId, translatedOperands);
  }];
}
2019+
19522020
//===----------------------------------------------------------------------===//
19532021
// NVVM Wgmma Ops
19542022
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,22 @@ LogicalResult CpAsyncOp::verify() {
108108
return success();
109109
}
110110

111+
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  // Validate dimension constraints for the TMA tensor-prefetch op.
  size_t numDims = getCoordinates().size();
  size_t numIm2colOffsets = getIm2colOffsets().size();

  // The intrinsic family only covers 1-d through 5-d tensors.
  if (numDims < 1 || numDims > 5)
    return emitError("expects coordinates between 1 to 5 dimension");

  // Im2col mode is requested by supplying im2col offsets.
  if (numIm2colOffsets > 0) {
    if (numDims < 3)
      return emitError(
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    // In im2col mode there is exactly one offset per dimension beyond the
    // first two, i.e. numDims - 2 offsets.
    if (numDims != numIm2colOffsets + 2)
      return emitError(
          "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}
126+
111127
// Given the element type of an operand and whether or not it is an accumulator,
112128
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
113129
// operand's element type.
@@ -1055,6 +1071,30 @@ LogicalResult NVVM::BarrierOp::verify() {
10551071
return success();
10561072
}
10571073

1074+
llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                bool isIm2Col) {
  // Tile-mode intrinsics, indexed by (rank - 1).
  static constexpr llvm::Intrinsic::ID tileIDs[] = {
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d};
  // Im2col-mode intrinsics exist only for ranks 3-5, indexed by (rank - 3).
  static constexpr llvm::Intrinsic::ID im2colIDs[] = {
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d};

  if (tensorDims < 1 || tensorDims > 5)
    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");

  // NOTE(review): for ranks 1-2 the tile intrinsic is returned regardless of
  // isIm2Col, matching the original switch; the op verifier rejects im2col
  // with rank < 3 before translation.
  if (isIm2Col && tensorDims >= 3)
    return im2colIDs[tensorDims - 3];
  return tileIDs[tensorDims - 1];
}
1097+
10581098
//===----------------------------------------------------------------------===//
10591099
// NVVMDialect initialization, type parsing, and registration.
10601100
//===----------------------------------------------------------------------===//

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,28 @@ llvm.func @nvvm_fence_proxy_release() {
3030
// expected-error @below {{'nvvm.fence.proxy.release' op uni-directional proxies only support tensormap for to_proxy attribute}}
3131
nvvm.fence.proxy.release #nvvm.mem_scope<cta> from_proxy=#nvvm.proxy_kind<generic> to_proxy=#nvvm.proxy_kind<generic>
3232
llvm.return
33-
}
33+
}
34+
35+
// -----
36+
37+
// Verifier test: an empty coordinate list (0-d tensor) must be rejected.
llvm.func @tma_prefetch_0d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
38+
// expected-error @below {{expects coordinates between 1 to 5 dimension}}
39+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[] : !llvm.ptr
40+
llvm.return
41+
}
42+
43+
// -----
44+
45+
// Verifier test: im2col offsets on a 2-d tensor must be rejected (rank >= 3 required).
llvm.func @tma_prefetch_2d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %off0 : i16, %ch : i64) {
46+
// expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
47+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
48+
llvm.return
49+
}
50+
51+
// -----
52+
53+
// Verifier test: a 5-d tensor needs exactly 3 im2col offsets (rank - 2); 2 must be rejected.
llvm.func @tma_prefetch_5d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
54+
// expected-error @below {{im2col offsets must be 2 less than number of coordinates}}
55+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] : !llvm.ptr
56+
llvm.return
57+
}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,3 +715,65 @@ llvm.func @nvvm_breakpoint() {
715715
nvvm.breakpoint
716716
llvm.return
717717
}
718+
719+
// -----
720+
721+
// CHECK-LABEL: @tma_prefetch_1d
722+
// 1-d tile-mode lowering, with and without the optional l2 cache hint.
llvm.func @tma_prefetch_1d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
723+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 undef, i1 false)
724+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 %{{.*}}, i1 true)
725+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] : !llvm.ptr
726+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] l2_cache_hint = %ch : !llvm.ptr
727+
llvm.return
728+
}
729+
730+
// CHECK-LABEL: @tma_prefetch_2d
731+
// 2-d tile-mode lowering, with and without the optional l2 cache hint.
llvm.func @tma_prefetch_2d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %ch : i64) {
732+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
733+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
734+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] : !llvm.ptr
735+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] l2_cache_hint = %ch : !llvm.ptr
736+
llvm.return
737+
}
738+
739+
// CHECK-LABEL: @tma_prefetch_3d
740+
// 3-d lowering: tile mode and im2col mode (1 offset = rank - 2), each with/without cache hint.
llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %ch : i64) {
741+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
742+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
743+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] : !llvm.ptr
744+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] l2_cache_hint = %ch : !llvm.ptr
745+
746+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
747+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
748+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] : !llvm.ptr
749+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
750+
llvm.return
751+
}
752+
753+
// CHECK-LABEL: @tma_prefetch_4d
754+
// 4-d lowering: tile mode and im2col mode (2 offsets = rank - 2), each with/without cache hint.
llvm.func @tma_prefetch_4d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
755+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
756+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
757+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] : !llvm.ptr
758+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch : !llvm.ptr
759+
760+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
761+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
762+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] : !llvm.ptr
763+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch : !llvm.ptr
764+
llvm.return
765+
}
766+
767+
// CHECK-LABEL: @tma_prefetch_5d
768+
// 5-d lowering: tile mode and im2col mode (3 offsets = rank - 2), each with/without cache hint.
llvm.func @tma_prefetch_5d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
769+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
770+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
771+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] : !llvm.ptr
772+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] l2_cache_hint = %ch : !llvm.ptr
773+
774+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
775+
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
776+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] : !llvm.ptr
777+
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] l2_cache_hint = %ch : !llvm.ptr
778+
llvm.return
779+
}

0 commit comments

Comments
 (0)