
Commit 70c2e06

[mlir][nvgpu] Add nvgpu.tma.async.load and nvgpu.tma.descriptor
This work adds an `nvgpu.tma.async.load` op that requests a TMA load asynchronously, using an mbarrier object for completion. It also adds the `nvgpu.tensormap.descriptor` type. The descriptor is meant to be created by the `cuTensorMapEncodeTiled` CUDA driver API.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D155453
1 parent 2469cdd commit 70c2e06
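As a minimal sketch of the new surface area, adapted from the test added in this commit (the tile shape, attribute values, and function name are illustrative, not part of the commit):

!tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32, 3>, swizzle = swizzle_32b, l2promo = none, oob = zero, interleave = none>
!mbarrier = !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>

func.func @tma_load_sketch(%desc: !tensorMap, %barrier: !mbarrier, %dst: memref<32x32xf32, 3>) {
  %c0 = arith.constant 0 : index
  // Request an asynchronous TMA load of the tile at coordinates (0, 0);
  // completion is signaled through the mbarrier object.
  nvgpu.tma.async.load %desc[%c0, %c0], %barrier to %dst : !tensorMap, !mbarrier -> memref<32x32xf32, 3>
  func.return
}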

7 files changed: +213 additions, -13 deletions

mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt
Lines changed: 15 additions & 0 deletions
@@ -1,2 +1,17 @@
 add_mlir_dialect(NVGPU nvgpu)
 add_mlir_doc(NVGPU NVGPU Dialects/ -gen-dialect-doc)
+
+set(LLVM_TARGET_DEFINITIONS NVGPU.td)
+mlir_tablegen(NVGPUEnums.h.inc -gen-enum-decls)
+mlir_tablegen(NVGPUEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRNVGPUEnumsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS NVGPU.td)
+mlir_tablegen(NVGPUAttrDefs.h.inc -gen-attrdef-decls)
+mlir_tablegen(NVGPUAttrDefs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRNVGPUAttributesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS NVGPU.td)
+mlir_tablegen(NVGPUAttrTypes.h.inc -gen-typedef-decls)
+mlir_tablegen(NVGPUAttrTypes.cpp.inc -gen-typedef-defs)
+add_public_tablegen_target(MLIRNVGPUTypesIncGen)

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

Lines changed: 91 additions & 0 deletions
@@ -23,6 +23,7 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/OpBase.td"
+include "mlir/IR/EnumAttr.td"
 
 def NVGPU_Dialect : Dialect {
   let name = "nvgpu";
@@ -61,6 +62,58 @@ def NVGPU_Dialect : Dialect {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPU Attribute Definitions
+//===----------------------------------------------------------------------===//
+
+def TensorMapSwizzleNone : I32EnumAttrCase<"SWIZZLE_NONE", 0, "none">;
+def TensorMapSwizzle32B : I32EnumAttrCase<"SWIZZLE_32B", 1, "swizzle_32b">;
+def TensorMapSwizzle64B : I32EnumAttrCase<"SWIZZLE_64B", 2, "swizzle_64b">;
+def TensorMapSwizzle128B : I32EnumAttrCase<"SWIZZLE_128B", 3, "swizzle_128b">;
+def TensorMapSwizzleKind : I32EnumAttr<"TensorMapSwizzleKind",
+    "Tensor map swizzling mode of shared memory banks",
+    [ TensorMapSwizzleNone, TensorMapSwizzle32B, TensorMapSwizzle64B,
+      TensorMapSwizzle128B]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::nvgpu";
+}
+
+def TensorMapL2PromoNone : I32EnumAttrCase<"L2PROMO_NONE", 0, "none">;
+def TensorMapL2Promo64B : I32EnumAttrCase<"L2PROMO_64B", 1, "l2promo_64b">;
+def TensorMapL2Promo128B : I32EnumAttrCase<"L2PROMO_128B", 2, "l2promo_128b">;
+def TensorMapL2Promo256B : I32EnumAttrCase<"L2PROMO_256B", 3, "l2promo_256b">;
+def TensorMapL2PromoKind : I32EnumAttr<"TensorMapL2PromoKind",
+    "Tensor map L2 promotion type",
+    [ TensorMapL2PromoNone, TensorMapL2Promo64B, TensorMapL2Promo128B,
+      TensorMapL2Promo256B]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::nvgpu";
+}
+
+def TensorMapOOBZero : I32EnumAttrCase<"OOB_ZERO", 0, "zero">;
+def TensorMapOOBNaN : I32EnumAttrCase<"OOB_NAN", 1, "nan">;
+def TensorMapOOBKind : I32EnumAttr<"TensorMapOOBKind",
+    "Tensor map out-of-bounds fill type",
+    [ TensorMapOOBZero, TensorMapOOBNaN]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::nvgpu";
+}
+
+def TensorMapInterleaveNone : I32EnumAttrCase<"INTERLEAVE_NONE", 0, "none">;
+def TensorMapInterleave16B : I32EnumAttrCase<"INTERLEAVE_16B", 1, "interleave_16b">;
+def TensorMapInterleave32B : I32EnumAttrCase<"INTERLEAVE_32B", 2, "interleave_32b">;
+def TensorMapInterleaveKind : I32EnumAttr<"TensorMapInterleaveKind",
+    "Tensor map interleave layout type",
+    [ TensorMapInterleaveNone, TensorMapInterleave16B, TensorMapInterleave32B]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::nvgpu";
+}
+
+def TensorMapSwizzleAttr : EnumAttr<NVGPU_Dialect, TensorMapSwizzleKind, "swizzle">;
+def TensorMapL2PromoAttr : EnumAttr<NVGPU_Dialect, TensorMapL2PromoKind, "l2promo">;
+def TensorMapOOBAttr : EnumAttr<NVGPU_Dialect, TensorMapOOBKind, "oob">;
+def TensorMapInterleaveAttr : EnumAttr<NVGPU_Dialect, TensorMapInterleaveKind, "interleave">;
+
 //===----------------------------------------------------------------------===//
 // NVGPU Type Definitions
 //===----------------------------------------------------------------------===//
@@ -100,6 +153,21 @@ def NVGPU_MBarrier : NVGPU_Type<"MBarrier", "mbarrier.barrier", []> {
 
 def NVGPU_MBarrierToken : NVGPU_Type<"MBarrierToken", "mbarrier.token", []> { }
 
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-map
+def NVGPU_TensorMapDescriptor : NVGPU_Type<"TensorMapDescriptor", "tensormap.descriptor", []> {
+  let summary = "TensorMap descriptor";
+  let parameters = (ins "MemRefType":$tensor,
+                        EnumParameter<TensorMapSwizzleKind>:$swizzle,
+                        EnumParameter<TensorMapL2PromoKind>:$l2promo,
+                        EnumParameter<TensorMapOOBKind>:$oob,
+                        EnumParameter<TensorMapInterleaveKind>:$interleave);
+  let description = [{
+    `nvgpu.tensormap.descriptor` is a type that represents a TMA descriptor:
+    a 128-byte object that lives either in constant space or in a kernel
+    parameter.
+  }];
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPU Op Definitions
 //===----------------------------------------------------------------------===//
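For reference, the test added in this commit spells the new descriptor type like this for a 32x32 f32 tile in shared memory (address space 3):

!tensorMap2d = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32, 3>, swizzle = swizzle_32b, l2promo = none, oob = zero, interleave = none>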
@@ -509,4 +577,27 @@ def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
   let assemblyFormat = "$barrier `,` $phase `,` $ticks attr-dict `:` type($barrier)";
 }
 
+def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", []> {
+  let summary = "TMA asynchronous load";
+  let description = [{
+    The op loads a tile memory region from global memory to shared memory
+    using the Tensor Memory Accelerator (TMA).
+
+    `$tensorMapDescriptor` is a tensor map descriptor that carries the tile
+    shape information. The descriptor is created by
+    `nvgpu.tma.create.descriptor`.
+
+    The op uses the mbarrier-based completion mechanism through `$barrier`.
+  }];
+  let arguments = (ins Arg<AnyMemRef, "", [MemWrite]>:$dst,
+                       NVGPU_MBarrier:$barrier,
+                       NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
+                       Variadic<Index>:$coordinates);
+  let assemblyFormat = [{
+    $tensorMapDescriptor `[` $coordinates `]` `,` $barrier `to` $dst
+      attr-dict `:` type($tensorMapDescriptor) `,` type($barrier) `->` type($dst)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
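With this assembly format, the coordinates go in brackets after the descriptor, followed by the barrier and the destination, as in the test added below:

%crd0 = arith.constant 0 : index
nvgpu.tma.async.load %tensorMap2d[%crd0, %crd0], %mbarrier to %buffer2d : !tensorMap2d, !mbarrier -> memref<32x32xf32, 3>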

mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h

Lines changed: 5 additions & 0 deletions
@@ -19,6 +19,11 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
+
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPUTypes.h.inc"

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 27 additions & 10 deletions
@@ -413,6 +413,9 @@ struct ConvertNVGPUToNVVMPass
     converter.addConversion([&](nvgpu::MBarrierType type) -> Type {
       return converter.convertType(createMBarrierMemrefType(rewriter, type));
     });
+    converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
+      return converter.getPointerType(type.getTensor().getElementType());
+    });
     populateNVGPUToNVVMConversionPatterns(converter, patterns);
     LLVMConversionTarget target(getContext());
     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
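So on the LLVM side a descriptor value is represented as a bare pointer to the tensor's element type. A rough sketch of the effect (whether this prints as a typed !llvm.ptr<f32> or an opaque !llvm.ptr depends on the converter's pointer configuration):

// The nvgpu type
!nvgpu.tensormap.descriptor<tensor = memref<128xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
// converts to a pointer to the element type (typed-pointer form shown):
!llvm.ptr<f32>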
@@ -770,11 +773,7 @@ struct NVGPUMBarrierInitLowering
     Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(),
                                    op.getBarrier(), adaptor.getBarrier());
 
-    Value count = adaptor.getCount();
-    if (!adaptor.getCount().getType().isInteger(32)) {
-      count = rewriter.create<LLVM::TruncOp>(op->getLoc(),
-                                             rewriter.getI32Type(), count);
-    }
+    Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
 
     if (isMbarrierShared(op.getBarrier().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
@@ -822,11 +821,7 @@ struct NVGPUMBarrierArriveNoCompleteLowering
         op.getBarrier(), adaptor.getBarrier());
     Type tokenType = getTypeConverter()->convertType(
         nvgpu::MBarrierTokenType::get(op->getContext()));
-    Value count = adaptor.getCount();
-    if (!adaptor.getCount().getType().isInteger(32)) {
-      count = rewriter.create<LLVM::TruncOp>(op->getLoc(),
-                                             rewriter.getI32Type(), count);
-    }
+    Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
     if (isMbarrierShared(op.getBarrier().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
           op, tokenType, barrier, count);
@@ -910,6 +905,27 @@ struct NVGPUMBarrierTryWaitParityLowering
   }
 };
 
+struct NVGPUTmaAsyncLoadOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::TmaAsyncLoadOp> {
+  using ConvertOpToLLVMPattern<nvgpu::TmaAsyncLoadOp>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Extract the aligned data pointer from the destination memref descriptor.
+    auto dest = rewriter.create<LLVM::ExtractValueOp>(op->getLoc(),
+                                                      adaptor.getDst(), 1);
+    Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(),
+                                   op.getBarrier(), adaptor.getBarrier());
+
+    // The NVVM op expects i32 coordinates.
+    SmallVector<Value> coords = adaptor.getCoordinates();
+    for (auto [index, value] : llvm::enumerate(coords)) {
+      coords[index] = truncToI32(rewriter, op->getLoc(), value);
+    }
+
+    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
+        op, dest, adaptor.getTensorMapDescriptor(), barrier, coords);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -922,6 +938,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUMBarrierTestWaitLowering,       // nvgpu.mbarrier.test_wait_parity
       NVGPUMBarrierTryWaitParityLowering,  // nvgpu.mbarrier.try_wait_parity
       NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
+      NVGPUTmaAsyncLoadOpLowering,         // nvgpu.tma.async.load
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
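With the pattern registered, the net effect on the IR looks roughly like this for a 2-D load (it mirrors the CHECK lines in the updated test below; SSA names are illustrative):

// nvgpu level:
nvgpu.tma.async.load %desc[%crd0, %crd1], %barrier to %buffer : !tensorMap2d, !mbarrier -> memref<32x32xf32, 3>
// lowers to the NVVM op, with the coordinates truncated to i32:
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDesc, %barrier, box[%crd0, %crd1]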

mlir/lib/Dialect/NVGPU/IR/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,9 @@ add_mlir_dialect_library(MLIRNVGPUDialect
 
   DEPENDS
   MLIRNVGPUIncGen
+  MLIRNVGPUEnumsIncGen
+  MLIRNVGPUAttributesIncGen
+  MLIRNVGPUTypesIncGen
 
   LINK_LIBS PUBLIC
   MLIRGPUDialect

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 39 additions & 1 deletion
@@ -14,21 +14,31 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Verifier.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::nvgpu;
 
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
+
 void nvgpu::NVGPUDialect::initialize() {
   addTypes<
 #define GET_TYPEDEF_LIST
 #include "mlir/Dialect/NVGPU/IR/NVGPUTypes.cpp.inc"
       >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
+      >();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
@@ -320,11 +330,39 @@ LogicalResult LdMatrixOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPU_TmaAsyncLoadOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TmaAsyncLoadOp::verify() {
+  // Destination memref
+  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());
+  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
+    return emitError()
+           << "The operation stores data to shared memory, but "
+              "the destination memref does not have a memory space of "
+           << NVGPUDialect::kSharedMemoryAddressSpace;
+  }
+  if (getCoordinates().size() > 5) {
+    return emitError() << "Maximum 5 coordinates are supported.";
+  }
+  if (getCoordinates().size() != size_t(dstMemref.getRank())) {
+    return emitError() << "Destination memref rank is "
+                       << size_t(dstMemref.getRank()) << " but there are "
+                       << getCoordinates().size()
+                       << " coordinates. They must match.";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
+
+#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
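As a hypothetical illustration of these checks, both of the following ops would be rejected by the verifier (value names assumed for the example):

// error: destination memref is not in shared memory
nvgpu.tma.async.load %desc1d[%c0], %barrier to %gmemBuffer : !tensorMap1d, !mbarrier -> memref<128xf32>
// error: destination memref rank is 2 but there is 1 coordinate
nvgpu.tma.async.load %desc2d[%c0], %barrier to %smemBuffer : !tensorMap2d, !mbarrier -> memref<32x32xf32, 3>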

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 33 additions & 2 deletions
@@ -559,7 +559,6 @@ func.func @mbarrier_nocomplete() {
   func.return
 }
 
-
 // -----
 !barrierType = !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
 !tokenType = !nvgpu.mbarrier.token
@@ -603,4 +602,36 @@ func.func @mbarrier_txcount() {
   nvgpu.mbarrier.try_wait.parity %barrier, %phase, %ticks : !barrierType
 
   func.return
-}
+}
+
+// -----
+
+// CHECK-LABEL: func @async_tma_load
+!tensorMap1d = !nvgpu.tensormap.descriptor<tensor = memref<128xf32,3>, swizzle=none, l2promo = none, oob = nan, interleave = interleave_16b>
+!tensorMap2d = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!tensorMap3d = !nvgpu.tensormap.descriptor<tensor = memref<2x32x32xf32,3>, swizzle=swizzle_64b, l2promo = l2promo_64b, oob = zero, interleave = none>
+!tensorMap4d = !nvgpu.tensormap.descriptor<tensor = memref<2x2x32x32xf32,3>, swizzle=swizzle_128b,l2promo = l2promo_128b,oob = zero, interleave = none>
+!tensorMap5d = !nvgpu.tensormap.descriptor<tensor = memref<2x2x2x32x32xf32,3>, swizzle=none, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+func.func @async_tma_load(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
+                          %buffer1d: memref<128xf32,3>,
+                          %buffer2d: memref<32x32xf32,3>,
+                          %buffer3d: memref<2x32x32xf32,3>,
+                          %buffer4d: memref<2x2x32x32xf32,3>,
+                          %buffer5d: memref<2x2x2x32x32xf32,3>,
+                          %mbarrier: !mbarrier) {
+  %crd0 = arith.constant 0 : index
+  %crd1 = arith.constant 0 : index
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}]
+  nvgpu.tma.async.load %tensorMap1d[%crd0], %mbarrier to %buffer1d : !tensorMap1d, !mbarrier -> memref<128xf32,3>
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
+  nvgpu.tma.async.load %tensorMap2d[%crd0, %crd1], %mbarrier to %buffer2d : !tensorMap2d, !mbarrier -> memref<32x32xf32,3>
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}]
+  nvgpu.tma.async.load %tensorMap3d[%crd0, %crd1, %crd0], %mbarrier to %buffer3d : !tensorMap3d, !mbarrier -> memref<2x32x32xf32,3>
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
+  nvgpu.tma.async.load %tensorMap4d[%crd0, %crd1, %crd1, %crd0], %mbarrier to %buffer4d : !tensorMap4d, !mbarrier -> memref<2x2x32x32xf32,3>
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
+  nvgpu.tma.async.load %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], %mbarrier to %buffer5d : !tensorMap5d, !mbarrier -> memref<2x2x2x32x32xf32,3>
+  func.return
+}
