Skip to content

[MLIR][NVVM] Add Op for TMA Prefetch #116232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1949,6 +1949,74 @@ def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
}];
}

def NVVM_CpAsyncBulkTensorPrefetchOp :
NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
let arguments = (ins
LLVM_AnyPointer:$tmaDescriptor,
Variadic<I32>:$coordinates,
Variadic<I16>:$im2colOffsets,
Optional<I64>:$l2CacheHint);

let description = [{
Initiates an asynchronous prefetch operation on the tensor data from global
memory to L2 cache.

The Op has two modes:
1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
layout is preserved at the destination.

2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
the elements in the Bounding Box of the source tensor are rearranged into
columns at the destination. In this mode, the tensor has to be at least
3-dimensional.

The `l2CacheHint` operand is optional, and it is used to specify cache
eviction policy that may be used during the memory access.

[For more information, see PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor)
}];

let assemblyFormat = [{
$tmaDescriptor `,`
`box` `[`$coordinates `]`
(`im2col` `[` $im2colOffsets^ `]` )?
(`l2_cache_hint` `=` $l2CacheHint^ )?
attr-dict `:` type($tmaDescriptor)
}];

let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col);
}];

let hasVerifier = 1;

string llvmBuilder = [{
// Arguments to the intrinsic:
// tmaDesc, tensorDims, im2colOffsets
// cache_hint(if applicable) and flag(boolean)
llvm::SmallVector<llvm::Value *> translatedOperands;
translatedOperands.push_back($tmaDescriptor);

for (auto v : op.getCoordinates())
translatedOperands.push_back(moduleTranslation.lookupValue(v));

for (auto v : op.getIm2colOffsets())
translatedOperands.push_back(moduleTranslation.lookupValue(v));

llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64));

bool isCacheHint = op.getL2CacheHint() ? true : false;
translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef);
translatedOperands.push_back(builder.getInt1(isCacheHint));

auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID(
op.getCoordinates().size(), op.getIm2colOffsets().size() > 0);
createIntrinsicCall(builder, intId, translatedOperands);
}];
}

//===----------------------------------------------------------------------===//
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
Expand Down
57 changes: 48 additions & 9 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,22 +75,32 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {

void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
if (getCoordinates().empty() || getCoordinates().size() > 5)
return emitError("expects coordinates between 1 to 5 dimension");

// Check for im2col mode
if (!getIm2colOffsets().empty()) {
if (getCoordinates().size() < 3)
// This verifier is shared across:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load) and
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch) Ops.
static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
size_t numIm2ColOffsets,
Location loc) {
if (tensorDims < 1 || tensorDims > 5)
return emitError(loc, "expects coordinates between 1 to 5 dimension");

if (numIm2ColOffsets) {
if (tensorDims < 3)
return emitError(
loc,
"to use im2col mode, the tensor has to be at least 3-dimensional");
if (getCoordinates().size() != (getIm2colOffsets().size() + 2))
if (tensorDims != (numIm2ColOffsets + 2))
return emitError(
"im2col offsets must be 2 less than number of coordinates");
loc, "im2col offsets must be 2 less than number of coordinates");
}
return success();
}

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
getIm2colOffsets().size(), getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
if (getCoordinates().size() > 5)
return emitError("Maximum 5 coordinates and dimension is supported.");
Expand All @@ -108,6 +118,11 @@ LogicalResult CpAsyncOp::verify() {
return success();
}

LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
getIm2colOffsets().size(), getLoc());
}

// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
Expand Down Expand Up @@ -1055,6 +1070,30 @@ LogicalResult NVVM::BarrierOp::verify() {
return success();
}

llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
bool isIm2Col) {
switch (tensorDims) {
case 1:
return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
case 2:
return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
case 3:
return isIm2Col
? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
: llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
case 4:
return isIm2Col
? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
: llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
case 5:
return isIm2Col
? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
: llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
default:
llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
}
}

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
Expand Down
26 changes: 25 additions & 1 deletion mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,28 @@ llvm.func @nvvm_fence_proxy_release() {
// expected-error @below {{'nvvm.fence.proxy.release' op uni-directional proxies only support tensormap for to_proxy attribute}}
nvvm.fence.proxy.release #nvvm.mem_scope<cta> from_proxy=#nvvm.proxy_kind<generic> to_proxy=#nvvm.proxy_kind<generic>
llvm.return
}
}

// -----

llvm.func @tma_prefetch_0d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
// expected-error @below {{expects coordinates between 1 to 5 dimension}}
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[] : !llvm.ptr
llvm.return
}

// -----

llvm.func @tma_prefetch_2d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %off0 : i16, %ch : i64) {
// expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
llvm.return
}

// -----

llvm.func @tma_prefetch_5d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
// expected-error @below {{im2col offsets must be 2 less than number of coordinates}}
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] : !llvm.ptr
llvm.return
}
62 changes: 62 additions & 0 deletions mlir/test/Target/LLVMIR/nvvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -715,3 +715,65 @@ llvm.func @nvvm_breakpoint() {
nvvm.breakpoint
llvm.return
}

// -----

// CHECK-LABEL: @tma_prefetch_1d
llvm.func @tma_prefetch_1d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 undef, i1 false)
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 %{{.*}}, i1 true)
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] l2_cache_hint = %ch : !llvm.ptr
llvm.return
}

// CHECK-LABEL: @tma_prefetch_2d
llvm.func @tma_prefetch_2d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %ch : i64) {
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] l2_cache_hint = %ch : !llvm.ptr
llvm.return
}

// CHECK-LABEL: @tma_prefetch_3d
llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %ch : i64) {
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] l2_cache_hint = %ch : !llvm.ptr

// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
llvm.return
}

// CHECK-LABEL: @tma_prefetch_4d
llvm.func @tma_prefetch_4d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch : !llvm.ptr

// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch : !llvm.ptr
llvm.return
}

// CHECK-LABEL: @tma_prefetch_5d
llvm.func @tma_prefetch_5d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] l2_cache_hint = %ch : !llvm.ptr

// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
// CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] l2_cache_hint = %ch : !llvm.ptr
llvm.return
}
Loading