
Commit b038dc2

[MLIR][NVVM] Add TMA linear prefetch Op (#141211)
This patch adds an Op for the TMA prefetch (non-tensor) variant. llvm-lit tests are added to verify the lowering to the intrinsics.

Signed-off-by: Durgadoss R <[email protected]>
1 parent 38cec04 commit b038dc2
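
At the MLIR level, the patch introduces `nvvm.cp.async.bulk.prefetch`, which maps to the non-tensor (linear) `cp.async.bulk.prefetch` PTX instruction. A minimal usage sketch, mirroring the examples in the op description and the llvm-lit test below (the wrapper function name is illustrative, not part of the patch):

```mlir
llvm.func @bulk_prefetch_example(%src : !llvm.ptr<1>, %size : i32, %ch : i64) {
  // Prefetch %size bytes starting at the global-memory address %src into L2.
  nvvm.cp.async.bulk.prefetch %src, %size : !llvm.ptr<1>

  // The same prefetch, with an optional L2 cache-eviction hint.
  nvvm.cp.async.bulk.prefetch %src, %size l2_cache_hint = %ch : !llvm.ptr<1>
  llvm.return
}
```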

File tree

3 files changed: +72 -0 lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 43 additions & 0 deletions
@@ -2344,6 +2344,49 @@ def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
   }];
 }
 
+def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> {
+  let summary = "Async bulk prefetch from global memory to L2 cache";
+  let description = [{
+    Initiates an asynchronous prefetch of data from the location
+    specified by `srcMem` to the L2 cache.
+
+    The `l2CacheHint` operand is optional, and it is used to specify cache
+    eviction policy that may be used during the memory access.
+
+    Example:
+    ```mlir
+      nvvm.cp.async.bulk.prefetch %src, %size : !llvm.ptr<1>
+
+      // with l2_cache_hint
+      nvvm.cp.async.bulk.prefetch %src, %size l2_cache_hint = %ch : !llvm.ptr<1>
+    ```
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch)
+  }];
+
+  let arguments = (ins
+    LLVM_PointerGlobal:$srcMem,
+    I32:$size,
+    Optional<I64>:$l2CacheHint);
+
+  let assemblyFormat = [{
+    $srcMem `,` $size (`l2_cache_hint` `=` $l2CacheHint^ )?
+    attr-dict `:` type($srcMem)
+  }];
+
+  let extraClassDeclaration = [{
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase& builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
+        *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, id, args);
+  }];
+}
+
 def NVVM_CpAsyncBulkTensorPrefetchOp :
   NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
   let arguments = (ins
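
The `llvmBuilder` snippet above is spliced into the generated MLIR-to-LLVM-IR translation: it defers intrinsic selection and argument marshalling to `getIntrinsicIDAndArgs`, which is implemented in NVVMDialect.cpp below. For orientation, the target intrinsic has the following shape (a sketch inferred from the intrinsic name and the CHECK lines in the test; the `declare` is not shown verbatim in this diff):

```llvm
; llvm.nvvm.cp.async.bulk.prefetch.L2(srcMem, size, l2CacheHint, hasCacheHint)
declare void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1), i32, i64, i1)
```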

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 20 additions & 0 deletions
@@ -1254,6 +1254,26 @@ CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
   return id;
 }
 
+mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+  llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
+
+  // Fill the Intrinsic Args
+  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+  args.push_back(mt.lookupValue(thisOp.getSize()));
+
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Unused =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+  args.push_back(builder.getInt1(hasCacheHint));
+
+  return {id, std::move(args)};
+}
+
 mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
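
Because the intrinsic's signature is fixed, the translation above always pushes four arguments: when no `l2CacheHint` operand is present, it substitutes a placeholder `i64 0` and sets the trailing `i1` flag to `false` so the backend ignores the hint slot. The two resulting call shapes, matching the CHECK lines in the test below (`%src`, `%size`, and `%ch` stand for the translated operand values):

```llvm
; no hint: placeholder 0, flag false
call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %src, i32 %size, i64 0, i1 false)
; explicit hint: hint value, flag true
call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %src, i32 %size, i64 %ch, i1 true)
```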

mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir

Lines changed: 9 additions & 0 deletions
@@ -1,5 +1,14 @@
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
 
+// CHECK-LABEL: @tma_bulk_prefetch
+llvm.func @tma_bulk_prefetch(%src : !llvm.ptr<1>, %size : i32, %ch : i64) {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  nvvm.cp.async.bulk.prefetch %src, %size : !llvm.ptr<1>
+  nvvm.cp.async.bulk.prefetch %src, %size l2_cache_hint = %ch : !llvm.ptr<1>
+  llvm.return
+}
+
 // CHECK-LABEL: @tma_prefetch_1d
 llvm.func @tma_prefetch_1d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
   // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 0, i1 false)
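
For local verification, the RUN line above is what llvm-lit executes: it translates the MLIR to LLVM IR and matches the emitted intrinsic calls with FileCheck, roughly `mlir-translate -mlir-to-llvmir tma_prefetch.mlir | FileCheck tma_prefetch.mlir` when run by hand from the test's directory (llvm-lit performs the `%s` substitution automatically as part of the usual check-mlir workflow).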
