Commit a119578
[mlir][nvgpu] Improve nvgpu->nvvm transformation of nvgpu.warpgroup.mma Op (NFC)
This PR substantially improves the readability and maintainability of the `nvgpu.warpgroup.mma` Op transformation from nvgpu->nvvm. The transformation plays a crucial role in GEMM lowering and manages complex tasks such as generating multiple wgmma ops and iterating over their descriptors. The prior code lacked clarity; this PR addresses that by introducing a helper class, `WarpgroupGemm`, which encapsulates the necessary functionality and makes the code cleaner and easier to follow. Each function within the helper class is documented to make its purpose and behavior clear.
1 parent ba149f6 commit a119578
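
For orientation, the lowering entry point after this change simply delegates to the new helper. The sketch below is condensed from the diff that follows (not a verbatim excerpt):

    // Condensed view of the reworked matchAndRewrite: all wgmma generation
    // now lives in the WarpgroupGemm helper introduced by this commit.
    LogicalResult
    matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter) const override {
      // Step 1. Build the helper around the op, adaptor, and type converter.
      WarpgroupGemm warpgroupGemm(op, rewriter, adaptor, *getTypeConverter());
      // Step 2. Emit the fence, the wgmma group, and group sync/wait.
      SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
      // Step 3. Replace the op with the accumulated wgmma results.
      rewriter.replaceOp(op, wgmmaResults);
      return success();
    }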

File tree

1 file changed (+240, -110 lines)

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

@@ -39,7 +39,7 @@ namespace mlir {
 
 using namespace mlir;
 
-/// Number of bits that needs to excluded when building matrix descriptor for
+/// Number of bits that needs to be excluded when building matrix descriptor for
 /// wgmma operations.
 constexpr int exclude4LSB = 4;
 
@@ -1160,137 +1160,267 @@ struct NVGPUWarpgroupMmaOpLowering
     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
   using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
 
-  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
-                              int &wgmmaShapeM, int &wgmmaShapeN,
-                              int &wgmmaShapeK) const {
-    wgmmaShapeM = 64;
-    wgmmaShapeN = sizeN;
-    if (inputElemType.isTF32()) {
-      wgmmaShapeK = 8;
-    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
-      wgmmaShapeK = 16;
-    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
-               inputElemType.isInteger(16)) {
-      wgmmaShapeK = 32;
-    } else if (inputElemType.isInteger(1)) {
-      wgmmaShapeK = 256;
-    } else {
-      llvm_unreachable("msg: not supported K shape");
+  /// This class assists in generating WgmmaMmaAsyncOp instructions to complete
+  /// a specified shape. If the GEMM shape is larger than the shape of a wgmma
+  /// instrution, it can generate multiple wgmma instructions, group and execute
+  /// them asynchronously. The class also handles waiting for instruction
+  /// completion and iterates through GenerateGmmaDescriptor to create
+  /// descriptors for each instruction.
+  class WarpgroupGemm {
+    nvgpu::WarpgroupMmaOp op;
+    ConversionPatternRewriter &rewriter;
+    OpAdaptor adaptor;
+    const LLVMTypeConverter &typeConverter;
+
+    // Entire shape of the given Op
+    int64_t totalM, totalN, totalK;
+
+    // Shape of one wgmma instruction
+    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
+
+    // Iteration counts for GEMM
+    int iterationM = 0, iterationN = 0, iterationK = 0;
+
+    /// The function returns the shape of wgmma instruction that is defined in
+    /// PTX programming guide.
+    /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
+    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
+      wgmmaM = 64;
+      wgmmaN = sizeN;
+      if (inputElemType.isTF32()) {
+        wgmmaK = 8;
+      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+        wgmmaK = 16;
+      } else if (inputElemType.isFloat8E4M3FN() ||
+                 inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
+        wgmmaK = 32;
+      } else if (inputElemType.isInteger(1)) {
+        wgmmaK = 256;
+      } else {
+        llvm_unreachable("msg: not supported K shape");
+      }
+      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
     }
-    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
-                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
-                      << "]\n");
-    return success();
-  }
 
-  Value generateNVVMWgmmaOp(ImplicitLocOpBuilder &b, int m, int n, int k,
-                            Type resultStructType, Value inout,
-                            Value descriptorA, Value descriptorB) const {
-    MLIRContext *ctx = b.getContext();
-    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
-    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
-    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
-    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
-    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
-    // todo: handle other input and output types
-    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
-    auto overflow =
-        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
-    Value res = b.create<NVVM::WgmmaMmaAsyncOp>(
-        resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
-        scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
-    return res;
-  }
-
-  LogicalResult
-  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-    int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
-    int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
-    int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
-
-    LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
-                      << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
-                      << sizeN << "] ---===\n");
+    /// Generates WGMMATypesAttr from MLIR Type
+    NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
+      auto getWgmmaType = [](Type elemType) {
+        if (elemType.isF32() || elemType.isTF32())
+          return NVVM::WGMMATypes::tf32;
+        if (elemType.isF16())
+          return NVVM::WGMMATypes::f16;
+        if (elemType.isBF16())
+          return NVVM::WGMMATypes::bf16;
+        if (elemType.isFloat8E4M3FN())
+          return NVVM::WGMMATypes::e4m3;
+        if (elemType.isFloat8E5M2())
+          return NVVM::WGMMATypes::e5m2;
+        if (elemType.isInteger(1))
+          return NVVM::WGMMATypes::b1;
+        if (elemType.isInteger(8))
+          return NVVM::WGMMATypes::s8;
+        if (elemType.isUnsignedInteger(8))
+          return NVVM::WGMMATypes::u8;
+        llvm_unreachable("unsupported type");
+      };
+      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
+    }
 
-    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
-    if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
-                             wgmmaShapeN, wgmmaShapeK))) {
-      return failure();
+    /// Generates layout attribute for the input matrix for wgmma instruction
+    NVVM::MMALayoutAttr
+    generateWgmmaLayout(std::optional<bool> transpose) const {
+      if (transpose.value_or(false))
+        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
+      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
     }
 
-    Value descriptorA = adaptor.getDescriptorA();
-    Value descriptorB = adaptor.getDescriptorB();
+    /// Generates shape attribute for wgmma instruction
+    NVVM::MMAShapeAttr generateWgmmaShape() const {
+      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
+    }
 
-    // Generate wgmma group
-    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
-    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+    /// Generates scale attributes of output matrix for wgmma instruction
+    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
+      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
+                                          NVVM::WGMMAScaleOut::one);
+    }
+    /// Generates scale attributes of input matrix for wgmma instruction
+    NVVM::WGMMAScaleInAttr generateScaleIn() const {
+      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
+                                         NVVM::WGMMAScaleIn::one);
+    }
 
-    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
-      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
+    /// Basic function to generate Add
+    Value makeAdd(Value lhs, Value rhs) {
+      return rewriter.create<LLVM::AddOp>(op->getLoc(), lhs.getType(), lhs,
+                                          rhs);
     };
 
-    auto iterateDescA = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle column major
-      int byte = typeTensorA.getElementTypeBitWidth() / 8;
-      int tileShapeA = typeTensorA.getDimSize(1);
-      int incrementVal =
-          ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+    /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
+    /// Currently, it only handles row-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    ///                  +2  +4  +6
+    ///                  ↓   ↓   ↓
+    /// descA  --->      +--+--+--+--+
+    ///                  |->|->|->|->|
+    ///                  |  |  |  |  |
+    ///                  |  |  |  |  |
+    ///                  |  |  |  |  |
+    /// descA+512--->    +-----------+
+    ///                  |  |  |  |  |
+    ///                  |  |  |  |  |
+    ///                  |  |  |  |  |
+    ///                  |  |  |  |  |
+    ///                  +-----------+
+    ///
+    Value iterateDescriptorA(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
+      Type elemA = matrixTypeA.getElementType();
+      int byte = elemA.getIntOrFloatBitWidth() / 8;
+      int tileShapeA = matrixTypeA.getDimSize(1);
+      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
       incrementVal = incrementVal >> exclude4LSB;
-      LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
-                        << iterK << "] [wgmma descriptors] Descriptor A + "
+      LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
+                        << "] [wgmma descriptors] Descriptor A + "
                         << incrementVal << " | \t ");
       if (!incrementVal)
         return desc;
-      return makeAdd(desc, makeI64Const(b, incrementVal));
-    };
+      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+    }
 
-    auto iterateDescB = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle row major
-      int byte = typeTensorB.getElementTypeBitWidth() / 8;
-      int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+    /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
+    /// Currently, it only handles column-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    /// descB  --->  +--+--+--+--+--+--+--+--+
+    ///              |↓ |  |  |  |  |  |  |  |
+    ///              |↓ |  |  |  |  |  |  |  |
+    ///              |↓ |  |  |  |  |  |  |  |
+    ///              |↓ |  |  |  |  |  |  |  |
+    ///              +--+--+--+--+--+--+--+--+
+    ///
+    Value iterateDescriptorB(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
+      Type elemB = matrixTypeB.getElementType();
+      int byte = elemB.getIntOrFloatBitWidth() / 8;
+      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
       incrementVal = incrementVal >> exclude4LSB;
       LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
       if (!incrementVal)
         return desc;
-      return makeAdd(desc, makeI64Const(b, incrementVal));
-    };
+      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+    }
+
+    /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
+    /// descriptors and arranges them based on induction variables: i, j, and k.
+    Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+      LLVM_DEBUG(DBGS() << "\t wgmma."
+                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
+                        << "(A[" << (iterationM * wgmmaM) << ":"
+                        << (iterationM * wgmmaM) + wgmmaM << "]["
+                        << (iterationK * wgmmaK) << ":"
+                        << (iterationK * wgmmaK + wgmmaK) << "] * "
+                        << " B[" << (iterationK * wgmmaK) << ":"
+                        << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
+                        << wgmmaN << "])\n");
+
+      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
+      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
+
+      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
+
+      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
+
+      NVVM::MMAShapeAttr shape = generateWgmmaShape();
+      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
+      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
+      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
+      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());
+
+      auto overflow = NVVM::MMAIntOverflowAttr::get(
+          op->getContext(), NVVM::MMAIntOverflow::wrapped);
+
+      Type resultStructType = typeConverter.convertType(matrixD.getType());
+
+      return rewriter.create<NVVM::WgmmaMmaAsyncOp>(
+          op->getLoc(), resultStructType, matrixC, descriptorA, descriptorB,
+          shape, itypeA, itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
+          overflow);
+    }
 
-    b.create<NVVM::WgmmaFenceAlignedOp>();
-
-    SmallVector<Value> wgmmaResults;
-    for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
-      Value matrixC = adaptor.getMatrixC()[iterM];
-      Value matrixD = op.getMatrixD()[iterM];
-      Type structType = getTypeConverter()->convertType(matrixD.getType());
-      LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
-                        << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
-                        << ":" << wgmmaShapeN << "] += \n");
-      for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
-        Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
-        Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
-        LLVM_DEBUG(DBGS() << "\t wgmma."
-                          << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
-                          << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
-                          << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
-                          << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
-                          << " B[" << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
-                          << ":" << wgmmaShapeN << "])\n");
-        matrixC = generateNVVMWgmmaOp(b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
-                                      structType, matrixC, descA, descB);
+    /// Generates multiple wgmma instructions to complete the given GEMM shape
+    SmallVector<Value> generateWgmmaGroup() {
+      SmallVector<Value> wgmmaResults;
+
+      // Perform GEMM
+      for (int i = 0; i < iterationM; ++i) {
+        Value matrixC = adaptor.getMatrixC()[i];
+        Value matrixD = op.getMatrixD()[i];
+        for (int j = 0; j < iterationN; ++j)
+          for (int k = 0; k < iterationK; ++k)
+            matrixC = generateWgmma(i, j, k, matrixC, matrixD);
+        wgmmaResults.push_back(matrixC);
       }
-      wgmmaResults.push_back(matrixC);
+
+      return wgmmaResults;
+    }
+
+  public:
+    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ConversionPatternRewriter &rewriter,
+                  OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
+        : op(op), rewriter(rewriter), adaptor(adaptor),
+          typeConverter(typeConverter) {
+      // Find the entire GEMM Shape
+      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
+      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
+      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
+                        << "] += A[" << totalM << "][" << totalK << "] * B["
+                        << totalK << "][" << totalN << "] ---===\n");
+
+      // Find the shape for one wgmma instruction
+      findWgmmaShape(
+          totalM, totalN,
+          op.getDescriptorA().getType().getTensor().getElementType());
+
+      // Iterations counts to complete the given shape with wgmma shape
+      iterationM = totalM / wgmmaM;
+      iterationN = totalN / wgmmaN;
+      iterationK = totalK / wgmmaK;
     }
-    b.create<NVVM::WgmmaGroupSyncAlignedOp>();
-    b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
 
-    ValueRange myres(wgmmaResults);
-    rewriter.replaceOp(op, myres);
+    /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
+    /// includes generating a fence Op (WgmmaFenceAlignedOp) before the
+    /// instructions and group synchronization, as well as waiting
+    /// (WgmmaGroupSyncAlignedOp) for group synchronization
+    /// (WgmmaWaitGroupSyncOp) after the instructions.
+    SmallVector<Value> generateWarpgroupMma() {
+      Location loc = op->getLoc();
+      rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
+      SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+      rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
+      rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
+
+      return wgmmaResults;
+    }
+  };
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Step 1. Build a helper class
+    WarpgroupGemm warpgroupGemm(op, rewriter, adaptor,
+                                *this->getTypeConverter());
+
+    // Step 2. Get the entire GEMM Shape
+    SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+
+    // Step 3. Replace fragmented result struct with the op results
+    rewriter.replaceOp(op, wgmmaResults);
     return success();
   }
 };
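
As a sanity check on the descriptor math above, the standalone sketch below replays the matrix-A increment formula from iterateDescriptorA (`((wgmmaK * k) + (totalK * tileShapeA * i)) * byte`, then dropping the low 4 bits via exclude4LSB). The concrete values (f16 elements, a [128][64] A tile, wgmmaK = 16, two M iterations) are hypothetical, chosen only to make the arithmetic visible:

    #include <cstdio>

    // Low 4 bits are excluded when building wgmma matrix descriptors,
    // i.e. descriptor offsets are expressed in 16-byte units.
    constexpr int exclude4LSB = 4;

    int main() {
      // Hypothetical inputs: f16 (2 bytes/element), A tile [128][64].
      int byte = 16 / 8;    // element bit-width / 8
      int tileShapeA = 64;  // matrixTypeA.getDimSize(1)
      int totalK = 64;      // K extent of the whole GEMM
      int wgmmaK = 16;      // K extent of one wgmma instruction
      int iterationM = 2, iterationK = totalK / wgmmaK;

      for (int i = 0; i < iterationM; ++i)
        for (int k = 0; k < iterationK; ++k) {
          // Same formula as iterateDescriptorA in the diff above.
          int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
          printf("i=%d k=%d  descriptor A += %d\n", i, k,
                 incrementVal >> exclude4LSB);
        }
      return 0;
    }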
