Commit b74cfc1

[mlir][nvgpu] Improve nvgpu->nvvm transformation of warpgroup.mma Op (NFC) (llvm#67325)
This PR substantially improves the readability and maintainability of the `nvgpu.warpgroup.mma` Op transformation from nvgpu->nvvm. This transformation plays a crucial role in GEMM: it manages complex operations such as generating multiple wgmma ops and iterating over their descriptors. The prior code lacked clarity; this PR addresses that.

This PR does the following:

**Introduces a helper class:** The `WarpgroupGemm` class encapsulates the necessary functionality, making the code cleaner and easier to understand.

**Detailed documentation:** Each function within the helper class is thoroughly documented to provide clear insight into its purpose and behavior.
1 parent 7eb2b99 commit b74cfc1
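To make the lowering concrete, here is a schematic sketch (my own illustration, not taken from the commit or its tests; SSA names and types are abbreviated) of what the pattern consumes and produces. The NVVM sequence mirrors the `WarpgroupGemm` class comment in the diff below:

```mlir
// Input (schematic): a warp-group level GEMM. At this revision the accumulator
// is fragmented, one value per 64-row tile, so a 128-row GEMM carries two
// (this is why matchAndRewrite iterates adaptor.getMatrixC()).
%d0, %d1 = nvgpu.warpgroup.mma %descA, %descB, %acc0, %acc1
    : !nvgpu.warpgroup.descriptor<...>, !nvgpu.warpgroup.descriptor<...>, ...

// Output (schematic): the structure emitted by generateWarpgroupMma().
nvvm.wgmma.fence.aligned
// ... one nvvm.wgmma.mma.async per (m, k) tile, with the descriptors advanced
// between instructions by iterateDescriptorA / iterateDescriptorB ...
nvvm.wgmma.group.sync.aligned
nvvm.wgmma.wait.group.sync [groupId]
```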

File tree

1 file changed: +246 −107 lines changed


mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 246 additions & 107 deletions
@@ -39,7 +39,7 @@ namespace mlir {
 
 using namespace mlir;
 
-/// Number of bits that needs to excluded when building matrix descriptor for
+/// Number of bits that needs to be excluded when building matrix descriptor for
 /// wgmma operations.
 constexpr int exclude4LSB = 4;
 
@@ -1160,137 +1160,276 @@ struct NVGPUWarpgroupMmaOpLowering
     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
   using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
 
-  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
-                              int &wgmmaShapeM, int &wgmmaShapeN,
-                              int &wgmmaShapeK) const {
-    wgmmaShapeM = 64;
-    wgmmaShapeN = sizeN;
-    if (inputElemType.isTF32()) {
-      wgmmaShapeK = 8;
-    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
-      wgmmaShapeK = 16;
-    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
-               inputElemType.isInteger(16)) {
-      wgmmaShapeK = 32;
-    } else if (inputElemType.isInteger(1)) {
-      wgmmaShapeK = 256;
-    } else {
-      llvm_unreachable("msg: not supported K shape");
+  /// This is a helper class to generate required NVVM Ops for warp-group level
+  /// matrix multiplication.
+  /// When the given GEMM shape is larger than the shape of
+  /// a wgmma instruction in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
+  /// Op(s), group and execute them asynchronously. The class also handles
+  /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
+  /// create descriptors for each instruction.
+  ///
+  /// For example, this is the case when the shape of the GEMM is 128x128x128:
+  ///
+  ///   nvvm.wgmma.fence.aligned
+  ///
+  ///   nvvm.wgmma.mma.async descA, descB
+  ///   iterate(descA, descB)
+  ///   nvvm.wgmma.mma.async descA, descB
+  ///   [6x times more]
+  ///
+  ///   nvvm.wgmma.group.sync.aligned
+  ///   nvvm.wgmma.wait.group.sync [groupId]
+  ///
+  class WarpgroupGemm {
+    nvgpu::WarpgroupMmaOp op;
+    ImplicitLocOpBuilder b;
+    OpAdaptor adaptor;
+    const LLVMTypeConverter &typeConverter;
+
+    // Entire shape of the given Op
+    int64_t totalM, totalN, totalK;
+
+    // Shape of one wgmma instruction
+    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
+
+    // Iteration counts for GEMM
+    int iterationM = 0, iterationN = 0, iterationK = 0;
+
+    /// The function returns the shape of wgmma instruction that is defined in
+    /// PTX programming guide.
+    /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
+    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
+      wgmmaM = 64;
+      wgmmaN = sizeN;
+      if (inputElemType.isTF32()) {
+        wgmmaK = 8;
+      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+        wgmmaK = 16;
+      } else if (inputElemType.isFloat8E4M3FN() ||
+                 inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
+        wgmmaK = 32;
+      } else if (inputElemType.isInteger(1)) {
+        wgmmaK = 256;
+      } else {
+        llvm_unreachable("msg: not supported K shape");
+      }
+      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
     }
-    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
-                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
-                      << "]\n");
-    return success();
-  }
 
-  Value generateNVVMWgmmaOp(ImplicitLocOpBuilder &b, int m, int n, int k,
-                            Type resultStructType, Value inout,
-                            Value descriptorA, Value descriptorB) const {
-    MLIRContext *ctx = b.getContext();
-    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
-    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
-    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
-    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
-    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
-    // todo: handle other input and output types
-    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
-    auto overflow =
-        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
-    Value res = b.create<NVVM::WgmmaMmaAsyncOp>(
-        resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
-        scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
-    return res;
-  }
-
-  LogicalResult
-  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-    int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
-    int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
-    int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
-
-    LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
-                      << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
-                      << sizeN << "] ---===\n");
+    /// Generates WGMMATypesAttr from MLIR Type
+    NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
+      auto getWgmmaType = [](Type elemType) {
+        if (elemType.isF32() || elemType.isTF32())
+          return NVVM::WGMMATypes::tf32;
+        if (elemType.isF16())
+          return NVVM::WGMMATypes::f16;
+        if (elemType.isBF16())
+          return NVVM::WGMMATypes::bf16;
+        if (elemType.isFloat8E4M3FN())
+          return NVVM::WGMMATypes::e4m3;
+        if (elemType.isFloat8E5M2())
+          return NVVM::WGMMATypes::e5m2;
+        if (elemType.isInteger(1))
+          return NVVM::WGMMATypes::b1;
+        if (elemType.isInteger(8))
+          return NVVM::WGMMATypes::s8;
+        if (elemType.isUnsignedInteger(8))
+          return NVVM::WGMMATypes::u8;
+        llvm_unreachable("unsupported type");
+      };
+      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
+    }
 
-    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
-    if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
-                             wgmmaShapeN, wgmmaShapeK))) {
-      return failure();
+    /// Generates layout attribute for the input matrix for wgmma instruction
+    NVVM::MMALayoutAttr
+    generateWgmmaLayout(std::optional<bool> transpose) const {
+      if (transpose.value_or(false))
+        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
+      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
     }
 
-    Value descriptorA = adaptor.getDescriptorA();
-    Value descriptorB = adaptor.getDescriptorB();
+    /// Generates shape attribute for wgmma instruction
+    NVVM::MMAShapeAttr generateWgmmaShape() const {
+      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
+    }
 
-    // Generate wgmma group
-    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
-    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+    /// Generates scale attributes of output matrix for wgmma instruction
+    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
+      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
+                                          NVVM::WGMMAScaleOut::one);
+    }
+    /// Generates scale attributes of input matrix for wgmma instruction
+    NVVM::WGMMAScaleInAttr generateScaleIn() const {
+      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
+                                         NVVM::WGMMAScaleIn::one);
+    }
 
-    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+    /// Basic function to generate Add
+    Value makeAdd(Value lhs, Value rhs) {
       return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
     };
 
-    auto iterateDescA = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle column major
-      int byte = typeTensorA.getElementTypeBitWidth() / 8;
-      int tileShapeA = typeTensorA.getDimSize(1);
-      int incrementVal =
-          ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+    /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
+    /// Currently, it only handles row-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    ///                 +2 +4 +6
+    ///                  ↓  ↓  ↓
+    /// descA    ---> +--+--+--+--+
+    ///               |->|->|->|->|
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    /// descA+512---> +-----------+
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               |  |  |  |  |
+    ///               +-----------+
+    ///
+    Value iterateDescriptorA(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
+      Type elemA = matrixTypeA.getElementType();
+      int byte = elemA.getIntOrFloatBitWidth() / 8;
+      int tileShapeA = matrixTypeA.getDimSize(1);
+      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
       incrementVal = incrementVal >> exclude4LSB;
-      LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
-                        << iterK << "] [wgmma descriptors] Descriptor A + "
+      LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
+                        << "] [wgmma descriptors] Descriptor A + "
                         << incrementVal << " | \t ");
       if (!incrementVal)
         return desc;
       return makeAdd(desc, makeI64Const(b, incrementVal));
-    };
+    }
 
-    auto iterateDescB = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle row major
-      int byte = typeTensorB.getElementTypeBitWidth() / 8;
-      int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+    /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
+    /// Currently, it only handles column-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    /// descB ---> +--+--+--+--+--+--+--+--+
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            +--+--+--+--+--+--+--+--+
+    ///
+    Value iterateDescriptorB(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
+      Type elemB = matrixTypeB.getElementType();
+      int byte = elemB.getIntOrFloatBitWidth() / 8;
+      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
       incrementVal = incrementVal >> exclude4LSB;
       LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
       if (!incrementVal)
         return desc;
       return makeAdd(desc, makeI64Const(b, incrementVal));
-    };
+    }
+
+    /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
+    /// descriptors and arranges them based on induction variables: i, j, and k.
+    Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+      LLVM_DEBUG(DBGS() << "\t wgmma."
+                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
+                        << "(A[" << (iterationM * wgmmaM) << ":"
+                        << (iterationM * wgmmaM) + wgmmaM << "]["
+                        << (iterationK * wgmmaK) << ":"
+                        << (iterationK * wgmmaK + wgmmaK) << "] * "
+                        << " B[" << (iterationK * wgmmaK) << ":"
+                        << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
+                        << wgmmaN << "])\n");
+
+      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
+      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
+
+      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
+
+      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
+
+      NVVM::MMAShapeAttr shape = generateWgmmaShape();
+      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
+      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
+      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
+      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());
+
+      auto overflow = NVVM::MMAIntOverflowAttr::get(
+          op->getContext(), NVVM::MMAIntOverflow::wrapped);
+
+      Type resultStructType = typeConverter.convertType(matrixD.getType());
+
+      return b.create<NVVM::WgmmaMmaAsyncOp>(
+          resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
+          itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+    }
 
-    b.create<NVVM::WgmmaFenceAlignedOp>();
-
-    SmallVector<Value> wgmmaResults;
-    for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
-      Value matrixC = adaptor.getMatrixC()[iterM];
-      Value matrixD = op.getMatrixD()[iterM];
-      Type structType = getTypeConverter()->convertType(matrixD.getType());
-      LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
-                        << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
-                        << ":" << wgmmaShapeN << "] += \n");
-      for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
-        Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
-        Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
-        LLVM_DEBUG(DBGS() << "\t wgmma."
-                          << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
-                          << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
-                          << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
-                          << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
-                          << " B[" << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
-                          << ":" << wgmmaShapeN << "])\n");
-        matrixC = generateNVVMWgmmaOp(b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
-                                      structType, matrixC, descA, descB);
+    /// Generates multiple wgmma instructions to complete the given GEMM shape
+    SmallVector<Value> generateWgmmaGroup() {
+      SmallVector<Value> wgmmaResults;
+
+      // Perform GEMM
+      for (int i = 0; i < iterationM; ++i) {
+        Value matrixC = adaptor.getMatrixC()[i];
+        Value matrixD = op.getMatrixD()[i];
+        for (int j = 0; j < iterationN; ++j)
+          for (int k = 0; k < iterationK; ++k)
+            matrixC = generateWgmma(i, j, k, matrixC, matrixD);
+        wgmmaResults.push_back(matrixC);
       }
-      wgmmaResults.push_back(matrixC);
+
+      return wgmmaResults;
+    }
+
+  public:
+    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
+                  OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
+        : op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
+      // Find the entire GEMM Shape
+      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
+      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
+      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
+                        << "] += A[" << totalM << "][" << totalK << "] * B["
+                        << totalK << "][" << totalN << "] ---===\n");
+
+      // Find the shape for one wgmma instruction
+      findWgmmaShape(
+          totalM, totalN,
+          op.getDescriptorA().getType().getTensor().getElementType());
+
+      // Iteration counts to complete the given shape with the wgmma shape
+      iterationM = totalM / wgmmaM;
+      iterationN = totalN / wgmmaN;
+      iterationK = totalK / wgmmaK;
     }
-    b.create<NVVM::WgmmaGroupSyncAlignedOp>();
-    b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
 
-    ValueRange myres(wgmmaResults);
-    rewriter.replaceOp(op, myres);
+    /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
+    /// generates a fence Op (WgmmaFenceAlignedOp) before the instructions,
+    /// followed by group synchronization (WgmmaGroupSyncAlignedOp) and a wait
+    /// for the group's completion (WgmmaWaitGroupSyncOp) after the
+    /// instructions.
+    SmallVector<Value> generateWarpgroupMma() {
+      b.create<NVVM::WgmmaFenceAlignedOp>();
+      SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+      b.create<NVVM::WgmmaGroupSyncAlignedOp>();
+      b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
+      return wgmmaResults;
+    }
+  };
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+    // Step 1. Build the helper class
+    WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
+
+    // Step 2. Generate wgmma Ops for the entire GEMM shape
+    SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+
+    // Step 3. Replace fragmented result struct with the op results
+    rewriter.replaceOp(op, wgmmaResults);
     return success();
   }
 };
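As a worked instance of the descriptor arithmetic above (my own example, not part of the commit): for an f16 matrix-A with `tileShapeA = 128` and `totalK = 128`, `iterateDescriptorA` at `i = 0, k = 1` computes `incrementVal = ((16 * 1) + 0) * 2 >> 4 = 2`, which matches the `+2 +4 +6` annotation in the matrix-A diagram. The IR produced by `makeI64Const` and `makeAdd` for that step would look roughly like:

```mlir
// Hypothetical LLVM-dialect output: the i64 matrix-A descriptor advanced by 2
// for the second K iteration (k = 1); SSA names are illustrative.
%offset     = llvm.mlir.constant(2 : i64) : i64
%descA_next = llvm.add %descA, %offset : i64
```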
