Commit 6dc7717

[MLIR][NVGPU] Change name wgmma.descriptor to warpgroup.descriptor (NFC) (#67526)
The NVGPU dialect is gaining broad support for warpgroup-level operations, and their names always start with `warpgroup...`. This PR renames the op and type from `wgmma.descriptor` to `warpgroup.descriptor` for the sake of consistency.
Parent: 5ef904b
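
For illustration, a minimal sketch of how the renamed op and type read after this change, adapted from the updated test in mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (the `@dynamicShmem` global and the `!tensorMap` type alias are defined in that test):

```mlir
// Reinterpret a dynamically sized shared-memory buffer as a 128x64 f16 tile.
%dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
%lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128, 64], strides: [64, 1]
  : memref<0xf16, 3> to memref<128x64xf16, 3>

// Build the (renamed) warpgroup matrix descriptor from the tile and a TMA descriptor.
%descA = nvgpu.warpgroup.generate.descriptor %lhsShmem, %tensorMap
  : memref<128x64xf16, 3>, !tensorMap -> !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
```

The resulting `!nvgpu.warpgroup.descriptor` value is then consumed by `nvgpu.warpgroup.mma`, as shown in the op documentation example updated in NVGPU.td below.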

File tree: 5 files changed, +37 -36 lines


mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

+9 -8

@@ -174,7 +174,7 @@ def NVGPU_TensorMapDescriptor : NVGPU_Type<"TensorMapDescriptor", "tensormap.des
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
-def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "wgmma.descriptor", []> {
+def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "warpgroup.descriptor", []> {
   let summary = "Warpgroup matrix descriptor type";
   let description = [{
   The descriptor specifies the properties of the matrix in shared memory that
@@ -667,11 +667,12 @@ def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
   let hasVerifier = 1;
 }
 
-def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
-  let summary = "Generate a wgmma matrix descriptor";
+def NVGPU_GenerateWarpgroupDescriptorOp : NVGPU_Op<"warpgroup.generate.descriptor", []> {
+  let summary = "Generate a warpgroup matrix descriptor";
   let description = [{
-    This Op builds a `nvgpu.wgmma.descriptor` that is used by warpgroup-level
-    matrix multiply and accumulate.
+    This Op builds a `nvgpu.warpgroup.descriptor` that is used by
+    `nvgpu.warpgroup.mma` to perform warpgroup-level matrix multiply and
+    accumulate.
 
     The descriptor specifies the properties of the matrix in shared memory that
     is a multiplicand in the matrix multiply and accumulate operation.
@@ -702,9 +703,9 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
 
     Example:
     ```mlir
-      %r1,%r2 = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc1, %acc2:
-                 !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
-                 !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+      %r1,%r2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2:
+                 !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
+                 !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
                  !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
                  !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
                 ->

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

+7 -7

@@ -967,13 +967,13 @@ struct NVGPUTmaAsyncLoadOpLowering
     return success();
   }
 };
-struct NVGPUGenerateGmmaDescriptorLowering
-    : public ConvertOpToLLVMPattern<nvgpu::GenerateGmmaDescriptorOp> {
+struct NVGPUGenerateWarpgroupDescriptorLowering
+    : public ConvertOpToLLVMPattern<nvgpu::GenerateWarpgroupDescriptorOp> {
   using ConvertOpToLLVMPattern<
-      nvgpu::GenerateGmmaDescriptorOp>::ConvertOpToLLVMPattern;
+      nvgpu::GenerateWarpgroupDescriptorOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(nvgpu::GenerateGmmaDescriptorOp op, OpAdaptor adaptor,
+  matchAndRewrite(nvgpu::GenerateWarpgroupDescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
@@ -1035,7 +1035,7 @@ struct NVGPUGenerateGmmaDescriptorLowering
     // // [0,14) start_address
     dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
 
-    LLVM_DEBUG(DBGS() << "Generating wgmma.descriptor: "
+    LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
                       << "leading_off:" << leadDimVal << "\t"
                       << "stride_off :" << strideDimVal << "\t"
                       << "base_offset:" << offsetVal << "\t"
@@ -1309,8 +1309,8 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUTmaAsyncLoadOpLowering,              // nvgpu.tma.async.load
       NVGPUTmaCreateDescriptorOpLowering,       // nvgpu.tma.create.descriptor
       NVGPUMBarrierArriveExpectTxLowering,      // nvgpu.mbarrier.arrive.expect_tx
-      NVGPUGenerateGmmaDescriptorLowering,      // nvgpu.wgmma.generate.descriptor
-      NVGPUWarpgroupMmaOpLowering,              // nvgpu.warpgroup.mma
+      NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
+      NVGPUWarpgroupMmaOpLowering,              // nvgpu.warpgroup.mma
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

+2 -2

@@ -367,10 +367,10 @@ LogicalResult TmaCreateDescriptorOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// NVGPU_GenerateGmmaDescriptorOp
+// NVGPU_GenerateWarpgroupDescriptorOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult GenerateGmmaDescriptorOp::verify() {
+LogicalResult GenerateWarpgroupDescriptorOp::verify() {
   MemRefType memrefType = getTensor().getType();
   MemRefType tensorMapType = getTensorMap().getType().getTensor();
 
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

+11 -11

@@ -674,7 +674,7 @@ module @mymodule {
 !tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16,3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
 memref.global "private" @dynamicShmem : memref<0xf16,3>
 // CHECK-LABEL: func @create_wgmma_descriptor(
-func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>{
+func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16,3>>{
   %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
   %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
   // CHECK: %[[S0:.+]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
@@ -706,22 +706,22 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.desc
   // CHECK: %[[S25:.+]] = llvm.mlir.constant(0 : i64) : i64
   // CHECK: %[[S26:.+]] = llvm.shl %[[S7]], %[[S25]] : i64
   // CHECK: %[[S27:.+]] = llvm.or %[[S24]], %[[S26]] : i64
-  // CHECK: %[[ret:.+]] = builtin.unrealized_conversion_cast %[[S27]] : i64 to !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
+  // CHECK: %[[ret:.+]] = builtin.unrealized_conversion_cast %[[S27]] : i64 to !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
   // CHECK: return %[[ret]]
-  %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap -> !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
-  func.return %descA : !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
+  %descA = nvgpu.warpgroup.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16,3>>
+  func.return %descA : !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16,3>>
 }
 
 // CHECK-LABEL: @warpgroup_mma_128_128_64(
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
 func.func @warpgroup_mma_128_128_64(
-    %descA: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
-    %descB: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+    %descA: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
+    %descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
     %acc1: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
    %acc2: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>)
 {
-  // CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>> to i64
-  // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>> to i64
+  // CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
+  // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
   // CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
   // CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[arg3]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
   // CHECK: nvvm.wgmma.fence.aligned
@@ -762,8 +762,8 @@ func.func @warpgroup_mma_128_128_64(
   // CHECK: nvvm.wgmma.commit.group.sync.aligned
   // CHECK: nvvm.wgmma.wait.group.sync.aligned 1
   %wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 {transposeB}:
-      !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
-      !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+      !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
+      !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
       !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
       !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
       ->

mlir/test/Dialect/NVGPU/invalid.mlir

+8 -8

@@ -225,8 +225,8 @@ func.func @async_cp_size_invalid_f64(
 // -----
 
 !tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
-!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x121xf16, 3>>
+!tDescA = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.warpgroup.descriptor<tensor = memref<64x121xf16, 3>>
 
 func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
   // expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}}
@@ -237,8 +237,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !t
 // -----
 
 !tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<128xf32>>
-!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>
+!tDescA = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>
 func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
   // expected-error @+1 {{'nvgpu.warpgroup.mma' op has matrices A, B, C and D, they must be 2 dimensional}}
   %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
@@ -247,8 +247,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !t
 
 // -----
 !tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
-!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf32, 3>>
+!tDescA = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf32, 3>>
 func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
   // expected-error @+1 {{'nvgpu.warpgroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}}
   %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
@@ -258,8 +258,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !t
 // -----
 
 !tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
-!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
-!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x512xf16, 3>>
+!tDescA = !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.warpgroup.descriptor<tensor = memref<64x512xf16, 3>>
 func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
   // expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 512 ) != 2nd dim matrix-C ( 128 )}}
   %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
