
Commit 2388222

[MLIR][NVGPU] Adding nvgpu.warpgroup.mma Op for Hopper GPUs (#65440)
This work introduces a new operation called `warpgroup.mma` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate warpgroup-level matrix multiply and accumulate (WGMMA) operations on Hopper GPUs with the sm_90a architecture.

Previously, the `nvvm.wgmma.mma_async` operation was introduced to support warpgroup-level matrix operations in the NVVM dialect; multiple instances of `nvvm.wgmma.mma_async` are needed to achieve a desired computation shape. The new `nvgpu.warpgroup.mma` operation abstracts this complexity and provides a higher-level interface for performing warpgroup-level matrix operations. The `nvgpu.warpgroup.mma` op does the following:

1) Corresponds to multiple `wgmma` instructions.
2) Iterates the input matrix descriptors to achieve the desired computation shape.
3) Groups and runs the `wgmma` instructions asynchronously, and eventually waits on them. This is done by `wgmma.fence.aligned`, `wgmma.commit.group.sync.aligned`, and `wgmma.wait.group.sync.aligned`.
4) Results in fragmented matrices.

Here's an example usage of the `nvgpu.warpgroup.mma` operation:
```
%wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 {transposeB}:
      !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
      !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
      ->
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
```

The op results in the following PTX:
```
wgmma.fence.sync.aligned;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f1, %f2, 62 more registers}, %descA,     %descB,     p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f1, %f2, 62 more registers}, %descA+2,   %descB+128, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f1, %f2, 62 more registers}, %descA+4,   %descB+256, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f1, %f2, 62 more registers}, %descA+6,   %descB+384, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f500, %f501, 62 more registers}, %descA+512, %descB,     p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f500, %f501, 62 more registers}, %descA+514, %descB+128, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f500, %f501, 62 more registers}, %descA+516, %descB+256, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f500, %f501, 62 more registers}, %descA+518, %descB+384, p, 1, 1, 0, 1;
wgmma.commit_group.sync.aligned;
wgmma.wait_group.sync.aligned 1;
```

The op keeps:
- the first 64 registers (`{%f1, %f2, 62 more registers}`) -> `%acc1`
- the second 64 registers (`{%f500, %f501, 62 more registers}`) -> `%acc2`
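For intuition, the per-instruction descriptor offsets above follow from the iteration logic in the lowering (`iterateDescA`/`iterateDescB` in NVGPUToNVVM.cpp below): a byte offset is computed for each (m, k) iteration and then shifted right by 4, since the low 4 bits are excluded from wgmma matrix descriptors. A minimal standalone sketch of that arithmetic, assuming the 128x64 (A) and 64x128 (B) f16 example above:

```cpp
#include <cstdio>

int main() {
  // Shapes from the example: A = 128x64xf16, B = 64x128xf16, wgmma = m64n128k16.
  const int sizeM = 128, sizeK = 64, wgmmaShapeM = 64, wgmmaShapeK = 16;
  const int tileShapeA = 64; // second dimension of A's tile
  const int rowsB = 64;      // first dimension of B's tile
  const int byte = 2;        // bytes per f16 element
  const int exclude4LSB = 4; // low 4 bits are dropped from descriptor offsets

  for (int iterM = 0; iterM < sizeM / wgmmaShapeM; ++iterM)     // 2 iterations
    for (int iterK = 0; iterK < sizeK / wgmmaShapeK; ++iterK) { // 4 iterations
      int offA = (((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte)
                 >> exclude4LSB;
      int offB = (rowsB * wgmmaShapeK * iterK * byte) >> exclude4LSB;
      std::printf("wgmma: %%descA+%-3d %%descB+%-3d\n", offA, offB);
    }
  // Prints the 8 offset pairs seen in the PTX above:
  // (0,0) (2,128) (4,256) (6,384) (512,0) (514,128) (516,256) (518,384).
  return 0;
}
```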
1 parent 6d26799 commit 2388222


8 files changed, +474 -11 lines changed


mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 2 additions & 2 deletions
@@ -1610,9 +1610,9 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
     PredOpTrait<"input struct and result struct must be the same type",
                 TCresIsSameAsOpBase<0, 0>>,]>
 {
-  let results = (outs LLVM_AnyAggregate:$results);
+  let results = (outs LLVM_AnyStruct:$results);
   let arguments = (ins
-    LLVM_AnyAggregate:$inouts,
+    LLVM_AnyStruct:$inouts,
     I64:$descriptorA,
     I64:$descriptorB,
     NVVM_MMAShapeAttr:$shape,

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

Lines changed: 56 additions & 0 deletions
@@ -192,6 +192,19 @@ def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "w
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+def NVGPU_WarpgroupAccumulator : NVGPU_Type<"WarpgroupAccumulator", "warpgroup.accumulator", []> {
+  let parameters = (ins "VectorType":$fragmented);
+  let assemblyFormat = "`<` struct(params) `>`";
+  let description = [{
+    This type represents the result matrix obtained from `nvgpu.warpgroup.mma`.
+    The `$fragmented` type signifies the distributed or fragmented result
+    vector that is collectively owned by all the threads in the warp-group
+    that executed `nvgpu.warpgroup.mma`.
+    [See the details of register fragment layout for accumulator matrix D]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPU Op Definitions
 //===----------------------------------------------------------------------===//
@@ -664,5 +677,48 @@ def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
+  let description = [{
+    The `nvgpu.warpgroup.mma` op performs the warpgroup-level (4 warps)
+    matrix-multiply-and-accumulate (mma) operation that results in
+    `nvvm.wgmma.mma_async`.
+
+    The operands are `descriptorA` and `descriptorB`, wgmma matrix descriptors
+    that describe the properties of the matrices in shared memory. The results
+    are thread-level ownership of the warpgroup-level mma operation shape. The
+    shape is deduced from the descriptor types and the output vector.
+
+    The op corresponds to multiple `nvvm.wgmma.mma_async` operations that
+    complete the given shape. Since `nvvm.wgmma.mma_async` is asynchronous,
+    this op groups the `nvvm.wgmma.mma_async` ops and surrounds them with
+    `wgmma.fence.aligned`, `wgmma.commit.group.sync.aligned`, and
+    `wgmma.wait.group.sync.aligned` ops.
+
+    Example:
+    ```mlir
+      %r1,%r2 = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc1, %acc2:
+                 !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
+                 !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+                 !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+                 !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
+                 ->
+                 !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+                 !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
+    ```
+  }];
+
+  let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA,
+                       NVGPU_WarpgroupMatrixDescriptor:$descriptorB,
+                       DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
+                       OptionalAttr<UnitAttr>:$transposeA,
+                       OptionalAttr<UnitAttr>:$transposeB,
+                       Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
+  let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixD);
+  let assemblyFormat = [{
+    $descriptorA`,` $descriptorB`,` $matrixC attr-dict
+    `:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
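For reference, the lowering below deduces the GEMM shape from these operand types: M from `descriptorA`'s first tile dimension, N from `descriptorB`'s second, and K from `descriptorA`'s second; M is then covered by one accumulator operand per 64 rows. A small sketch of that deduction under the example's shapes (hypothetical helper, mirroring `matchAndRewrite` in NVGPUToNVVM.cpp below):

```cpp
struct GemmShape {
  int m, n, k;
};

// A's tile is aRows x aCols, B's tile is aCols x bCols (the op reads these
// from the wgmma descriptor types).
GemmShape deduceGemmShape(int aRows, int aCols, int bCols) {
  return {aRows, bCols, aCols};
}

// For the example: A = 128x64, B = 64x128 gives M=128, N=128, K=64, and with
// wgmmaShapeM = 64 the op carries M/64 = 2 accumulators (%acc1, %acc2).
```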

mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,8 @@
 
 #include "mlir/Dialect/NVGPU/IR/NVGPUEnums.h.inc"
 
+constexpr int kWarpSize = 32;
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
 
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 162 additions & 3 deletions
@@ -17,10 +17,12 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 
 #define DEBUG_TYPE "nvgpu-to-nvvm"
@@ -34,6 +36,10 @@ namespace mlir {
 
 using namespace mlir;
 
+/// Number of bits that need to be excluded when building the matrix
+/// descriptor for wgmma operations.
+constexpr int exclude4LSB = 4;
+
 /// GPU has 32 bit registers, this function truncates values when larger width
 /// is not needed.
 static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc,
@@ -419,6 +425,15 @@ struct ConvertNVGPUToNVVMPass
     converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
       return converter.convertType(IntegerType::get(type.getContext(), 32));
     });
+    converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
+      VectorType vtype = type.getFragmented();
+      SmallVector<Type> structBody;
+      for (unsigned i = 0; i < vtype.getDimSize(0); i++)
+        structBody.push_back(vtype.getElementType());
+      auto convertedType =
+          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
+      return converter.convertType(convertedType);
+    });
     converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
       return converter.convertType(IntegerType::get(type.getContext(), 64));
     });
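The conversion above flattens a fragmented accumulator into a literal LLVM struct with one element-type member per entry of the vector's leading dimension; for the example's `vector<64x128xf32>` that is a struct of 64 `f32` members, matching the 64 registers per accumulator noted in the commit message. A minimal sketch of the same mapping (hypothetical helper, string form only):

```cpp
#include <string>

// Illustration of the type conversion above: an accumulator fragmented as
// vector<MxNxELEM> becomes !llvm.struct<(ELEM, ..., ELEM)> with M members.
std::string convertAccumulatorType(int dim0, const std::string &elem) {
  std::string s = "!llvm.struct<(";
  for (int i = 0; i < dim0; ++i) {
    s += elem;
    if (i + 1 < dim0)
      s += ", ";
  }
  return s + ")>";
}

// convertAccumulatorType(64, "f32") -> a struct with 64 f32 members.
```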
@@ -438,6 +453,8 @@ struct ConvertNVGPUToNVVMPass
     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
     target.addLegalDialect<::mlir::memref::MemRefDialect>();
     target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
+    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+        converter, patterns, target);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
@@ -984,10 +1001,9 @@ struct NVGPUGenerateGmmaDescriptorLowering
                        shiftLeft(val, startBit));
     };
 
-    int ex4LSB = 4;
     int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
-    uint64_t strideDimVal = (layout << 3) >> ex4LSB;
-    uint64_t leadDimVal = (sizeN * layout) >> ex4LSB;
+    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
+    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
     uint64_t offsetVal = 0;
 
     Value strideDim = makeConst(strideDimVal);
@@ -1141,6 +1157,148 @@ struct NVGPUTmaCreateDescriptorOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
+  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
+                              int &wgmmaShapeM, int &wgmmaShapeN,
+                              int &wgmmaShapeK) const {
+    wgmmaShapeM = 64;
+    wgmmaShapeN = sizeN;
+    if (inputElemType.isTF32()) {
+      wgmmaShapeK = 8;
+    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+      wgmmaShapeK = 16;
+    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
+               inputElemType.isInteger(16)) {
+      wgmmaShapeK = 32;
+    } else if (inputElemType.isInteger(1)) {
+      wgmmaShapeK = 256;
+    } else {
+      llvm_unreachable("msg: not supported K shape");
+    }
+    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
+                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
+                      << "]\n");
+    return success();
+  }
+
+  Value generateNVVMWgmmaOp(MLIRContext *ctx,
+                            ConversionPatternRewriter &rewriter, Location loc,
+                            int m, int n, int k, Type resultStructType,
+                            Value inout, Value descriptorA,
+                            Value descriptorB) const {
+    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
+    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
+    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
+    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
+    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
+    // todo: handle other input and output types
+    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
+    auto overflow =
+        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
+    Value res = rewriter.create<NVVM::WgmmaMmaAsyncOp>(
+        loc, resultStructType, inout, descriptorA, descriptorB, shape, itype,
+        itype, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+    return res;
+  }
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
+    int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
+    int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+
+    LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
+                      << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
+                      << sizeN << "] ---===\n");
+
+    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
+    if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
+                             wgmmaShapeN, wgmmaShapeK))) {
+      return failure();
+    }
+
+    Value descriptorA = adaptor.getDescriptorA();
+    Value descriptorB = adaptor.getDescriptorB();
+
+    // Generate wgmma group
+
+    auto loc = op->getLoc();
+    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
+    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+
+    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    };
+
+    auto iterateDescA = [&](Value desc, int iterM, int iterN,
+                            int iterK) -> Value {
+      // todo : Handle column major
+      int byte = typeTensorA.getElementTypeBitWidth() / 8;
+      int tileShapeA = typeTensorA.getDimSize(1);
+      int incrementVal =
+          ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+      incrementVal = incrementVal >> exclude4LSB;
+      LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
+                        << iterK << "] [wgmma descriptors] Descriptor A + "
+                        << incrementVal << " | \t ");
+      if (!incrementVal)
+        return desc;
+      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+    };
+
+    auto iterateDescB = [&](Value desc, int iterM, int iterN,
+                            int iterK) -> Value {
+      // todo : Handle row major
+      int byte = typeTensorB.getElementTypeBitWidth() / 8;
+      int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+      incrementVal = incrementVal >> exclude4LSB;
+      LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
+      if (!incrementVal)
+        return desc;
+      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+    };
+
+    rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
+
+    SmallVector<Value> wgmmaResults;
+    for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
+      Value matrixC = adaptor.getMatrixC()[iterM];
+      Value matrixD = op.getMatrixD()[iterM];
+      Type structType = getTypeConverter()->convertType(matrixD.getType());
+      LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
+                        << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
+                        << ":" << wgmmaShapeN << "] += \n");
+      for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
+        Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
+        Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
+        LLVM_DEBUG(DBGS() << "\t wgmma."
+                          << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
+                          << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
+                          << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
+                          << (iterK * wgmmaShapeK) << ":"
+                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
+                          << " B[" << (iterK * wgmmaShapeK) << ":"
+                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
+                          << ":" << wgmmaShapeN << "])\n");
+        matrixC = generateNVVMWgmmaOp(op->getContext(), rewriter, loc,
+                                      wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
+                                      structType, matrixC, descA, descB);
+      }
+      wgmmaResults.push_back(matrixC);
+    }
+    rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
+    rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
+
+    ValueRange myres(wgmmaResults);
+    rewriter.replaceOp(op, myres);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1156,6 +1314,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUTmaCreateDescriptorOpLowering,  // nvgpu.tma.create.descriptor
       NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
       NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor
+      NVGPUWarpgroupMmaOpLowering,         // nvgpu.warpgroup.mma
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
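One detail worth noting in `getWgmmaShape` above: for the floating-point element types, the chosen K extent corresponds to 256 bits of the K dimension per `wgmma` instruction. A minimal restatement of that pattern (hypothetical helper, not part of the commit):

```cpp
#include <cassert>

// Restates getWgmmaShape's K table for the floating-point cases:
// K * elementBitWidth == 256.
int wgmmaShapeKForBits(int elementBitWidth) {
  assert(elementBitWidth > 0 && 256 % elementBitWidth == 0);
  return 256 / elementBitWidth; // tf32 (32b) -> 8, f16/bf16 (16b) -> 16,
                                // f8 (8b) -> 32, b1 (1b) -> 256
}
```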
