@@ -39,7 +39,7 @@ namespace mlir {

 using namespace mlir;

-/// Number of bits that needs to excluded when building matrix descriptor for
+/// Number of bits that needs to be excluded when building matrix descriptor for
 /// wgmma operations.
 constexpr int exclude4LSB = 4;
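
The 4 excluded bits reflect that wgmma matrix descriptors encode shared-memory offsets in 16-byte units, so byte offsets are shifted right by 4 before being folded into a descriptor (my reading of the PTX shared-memory descriptor layout, not something this patch states). A minimal standalone sketch of that arithmetic:

#include <cassert>
#include <cstdint>

constexpr int exclude4LSB = 4;

// Convert a byte offset into the 16-byte units used inside a descriptor.
int64_t toDescriptorOffset(int64_t byteOffset) {
  return byteOffset >> exclude4LSB;
}

int main() {
  assert(toDescriptorOffset(32) == 2);     // 32 B = two 16-byte units
  assert(toDescriptorOffset(8192) == 512); // the "descA+512" jump drawn later
  return 0;
}
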
@@ -1160,137 +1160,267 @@ struct NVGPUWarpgroupMmaOpLowering
     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
   using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;

-  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
-                              int &wgmmaShapeM, int &wgmmaShapeN,
-                              int &wgmmaShapeK) const {
-    wgmmaShapeM = 64;
-    wgmmaShapeN = sizeN;
-    if (inputElemType.isTF32()) {
-      wgmmaShapeK = 8;
-    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
-      wgmmaShapeK = 16;
-    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
-               inputElemType.isInteger(16)) {
-      wgmmaShapeK = 32;
-    } else if (inputElemType.isInteger(1)) {
-      wgmmaShapeK = 256;
-    } else {
-      llvm_unreachable("msg: not supported K shape");
+  /// This class assists in generating WgmmaMmaAsyncOp instructions to
+  /// complete a specified GEMM shape. If the shape is larger than what a
+  /// single wgmma instruction covers, it generates multiple wgmma
+  /// instructions, groups them, and executes them asynchronously. The class
+  /// also handles waiting for instruction completion and iterates through
+  /// GenerateGmmaDescriptor to create descriptors for each instruction.
+  class WarpgroupGemm {
+    nvgpu::WarpgroupMmaOp op;
+    ConversionPatternRewriter &rewriter;
+    OpAdaptor adaptor;
+    const LLVMTypeConverter &typeConverter;
+
+    // Entire shape of the given Op
+    int64_t totalM, totalN, totalK;
+
+    // Shape of one wgmma instruction
+    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
+
+    // Iteration counts for GEMM
+    int iterationM = 0, iterationN = 0, iterationK = 0;
+
+    /// Finds the shape of the wgmma instruction, as defined in the PTX
+    /// programming guide:
+    /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
+    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
+      wgmmaM = 64;
+      wgmmaN = sizeN;
+      if (inputElemType.isTF32()) {
+        wgmmaK = 8;
+      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+        wgmmaK = 16;
+      } else if (inputElemType.isFloat8E4M3FN() ||
+                 inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
+        wgmmaK = 32;
+      } else if (inputElemType.isInteger(1)) {
+        wgmmaK = 256;
+      } else {
+        llvm_unreachable("msg: not supported K shape");
+      }
+      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
     }
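
For reference, a standalone mock of the K-selection rule above, with string tags standing in for the MLIR type predicates (not the real API). K grows as elements shrink, so one wgmma instruction consumes a fixed-width chunk of the K dimension:

#include <cassert>
#include <string>

int pickWgmmaK(const std::string &elemType) {
  if (elemType == "tf32")
    return 8;
  if (elemType == "f16" || elemType == "bf16")
    return 16;
  if (elemType == "e4m3" || elemType == "e5m2" || elemType == "i16")
    return 32;
  if (elemType == "i1")
    return 256;
  return -1; // unsupported, mirroring the llvm_unreachable above
}

int main() {
  assert(pickWgmmaK("tf32") == 8);
  assert(pickWgmmaK("f16") == 16);
  assert(pickWgmmaK("e5m2") == 32);
  assert(pickWgmmaK("i1") == 256);
  return 0;
}
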
-    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
-                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
-                      << "]\n");
-    return success();
-  }

-  Value generateNVVMWgmmaOp(ImplicitLocOpBuilder &b, int m, int n, int k,
-                            Type resultStructType, Value inout,
-                            Value descriptorA, Value descriptorB) const {
-    MLIRContext *ctx = b.getContext();
-    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
-    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
-    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
-    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
-    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
-    // todo: handle other input and output types
-    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
-    auto overflow =
-        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
-    Value res = b.create<NVVM::WgmmaMmaAsyncOp>(
-        resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
-        scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
-    return res;
-  }
-
-  LogicalResult
-  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-    int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
-    int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
-    int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
-
-    LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
-                      << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
-                      << sizeN << "] ---===\n");
+    /// Generates WGMMATypesAttr from MLIR Type
+    NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
+      auto getWgmmaType = [](Type elemType) {
+        if (elemType.isF32() || elemType.isTF32())
+          return NVVM::WGMMATypes::tf32;
+        if (elemType.isF16())
+          return NVVM::WGMMATypes::f16;
+        if (elemType.isBF16())
+          return NVVM::WGMMATypes::bf16;
+        if (elemType.isFloat8E4M3FN())
+          return NVVM::WGMMATypes::e4m3;
+        if (elemType.isFloat8E5M2())
+          return NVVM::WGMMATypes::e5m2;
+        if (elemType.isInteger(1))
+          return NVVM::WGMMATypes::b1;
+        if (elemType.isInteger(8))
+          return NVVM::WGMMATypes::s8;
+        if (elemType.isUnsignedInteger(8))
+          return NVVM::WGMMATypes::u8;
+        llvm_unreachable("unsupported type");
+      };
+      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
+    }

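
A hedged sketch of the mapping, with a local enum standing in for NVVM::WGMMATypes. One detail worth calling out: the isF32() branch also yields tf32, presumably because wgmma offers no full-precision f32 operand type for A/B:

#include <cassert>

enum class WgmmaType { tf32, f16, bf16, e4m3, e5m2, b1, s8, u8, invalid };

// Mirrors the lambda above; the flags stand in for the MLIR Type predicates.
WgmmaType mapElemType(bool isF32, bool isTF32, bool isF16) {
  if (isF32 || isTF32)
    return WgmmaType::tf32; // f32 inputs are narrowed to tf32
  if (isF16)
    return WgmmaType::f16;
  return WgmmaType::invalid; // remaining branches elided
}

int main() {
  assert(mapElemType(true, false, false) == WgmmaType::tf32);
  assert(mapElemType(false, false, true) == WgmmaType::f16);
  return 0;
}
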
-    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
-    if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
-                             wgmmaShapeN, wgmmaShapeK))) {
-      return failure();
+    /// Generates layout attribute for the input matrix for wgmma instruction
+    NVVM::MMALayoutAttr
+    generateWgmmaLayout(std::optional<bool> transpose) const {
+      if (transpose.value_or(false))
+        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
+      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
     }

-    Value descriptorA = adaptor.getDescriptorA();
-    Value descriptorB = adaptor.getDescriptorB();
+    /// Generates shape attribute for wgmma instruction
+    NVVM::MMAShapeAttr generateWgmmaShape() const {
+      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
+    }

-    // Generate wgmma group
-    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
-    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+    /// Generates scale attributes of output matrix for wgmma instruction
+    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
+      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
+                                          NVVM::WGMMAScaleOut::one);
+    }
+    /// Generates scale attributes of input matrix for wgmma instruction
+    NVVM::WGMMAScaleInAttr generateScaleIn() const {
+      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
+                                         NVVM::WGMMAScaleIn::one);
+    }

-    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
-      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
+    /// Basic function to generate Add
+    Value makeAdd(Value lhs, Value rhs) {
+      return rewriter.create<LLVM::AddOp>(op->getLoc(), lhs.getType(), lhs,
+                                          rhs);
     };

-    auto iterateDescA = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle column major
-      int byte = typeTensorA.getElementTypeBitWidth() / 8;
-      int tileShapeA = typeTensorA.getDimSize(1);
-      int incrementVal =
-          ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+    /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
+    /// Currently, it only handles row-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    ///                     +2 +4 +6
+    ///                      ↓  ↓  ↓
+    /// descA     --->  +--+--+--+--+
+    ///                 |->|->|->|->|
+    ///                 |  |  |  |  |
+    ///                 |  |  |  |  |
+    ///                 |  |  |  |  |
+    /// descA+512 --->  +-----------+
+    ///                 |  |  |  |  |
+    ///                 |  |  |  |  |
+    ///                 |  |  |  |  |
+    ///                 |  |  |  |  |
+    ///                 +-----------+
+    ///
+    Value iterateDescriptorA(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
+      Type elemA = matrixTypeA.getElementType();
+      int byte = elemA.getIntOrFloatBitWidth() / 8;
+      int tileShapeA = matrixTypeA.getDimSize(1);
+      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
       incrementVal = incrementVal >> exclude4LSB;
-      LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
-                        << iterK << "] [wgmma descriptors] Descriptor A + "
+      LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
+                        << "] [wgmma descriptors] Descriptor A + "
                         << incrementVal << " | \t");
       if (!incrementVal)
         return desc;
-      return makeAdd(desc, makeI64Const(b, incrementVal));
-    };
+      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+    }

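
A standalone check of the increment arithmetic, plugging in the [128][64] f16 matrix-A from the diagram above (tileShapeA = 64, totalK = 64, wgmmaK = 16, 2-byte elements); the results reproduce the +2/+4 steps and the descA+512 jump in the drawing:

#include <cassert>

constexpr int exclude4LSB = 4;

int incrementA(int i, int k, int wgmmaK = 16, int totalK = 64,
               int tileShapeA = 64, int byte = 2) {
  return (((wgmmaK * k) + (totalK * tileShapeA * i)) * byte) >> exclude4LSB;
}

int main() {
  assert(incrementA(0, 1) == 2);   // the "+2" step in the diagram
  assert(incrementA(0, 2) == 4);   // the "+4" step
  assert(incrementA(1, 0) == 512); // the "descA+512" jump to the next M tile
  return 0;
}
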
-    auto iterateDescB = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle row major
-      int byte = typeTensorB.getElementTypeBitWidth() / 8;
-      int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+    /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
+    /// Currently, it only handles column-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    /// descB ---> +--+--+--+--+--+--+--+--+
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            +--+--+--+--+--+--+--+--+
+    ///
+    Value iterateDescriptorB(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
+      Type elemB = matrixTypeB.getElementType();
+      int byte = elemB.getIntOrFloatBitWidth() / 8;
+      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
       incrementVal = incrementVal >> exclude4LSB;
       LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
       if (!incrementVal)
         return desc;
-      return makeAdd(desc, makeI64Const(b, incrementVal));
-    };
+      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+    }
+
+    /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
+    /// descriptors and arranges them based on induction variables: i, j, and k.
+    Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+      LLVM_DEBUG(DBGS() << "\t wgmma."
+                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
+                        << "(A[" << (iterationM * wgmmaM) << ":"
+                        << (iterationM * wgmmaM) + wgmmaM << "]["
+                        << (iterationK * wgmmaK) << ":"
+                        << (iterationK * wgmmaK + wgmmaK) << "] * "
+                        << " B[" << (iterationK * wgmmaK) << ":"
+                        << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
+                        << wgmmaN << "])\n");
+
+      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
+      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
+
+      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
+
+      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
+
+      NVVM::MMAShapeAttr shape = generateWgmmaShape();
+      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
+      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
+      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
+      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());
+
+      auto overflow = NVVM::MMAIntOverflowAttr::get(
+          op->getContext(), NVVM::MMAIntOverflow::wrapped);
+
+      Type resultStructType = typeConverter.convertType(matrixD.getType());
+
+      return rewriter.create<NVVM::WgmmaMmaAsyncOp>(
+          op->getLoc(), resultStructType, matrixC, descriptorA, descriptorB,
+          shape, itypeA, itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
+          overflow);
+    }

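
As an annotation rather than part of the patch: each WgmmaMmaAsyncOp built here should lower to a single PTX wgmma.mma_async instruction; for f16 inputs with an f32 accumulator at m64n128k16, roughly:

// Hedged sketch of the resulting PTX, shown schematically:
//
//   wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16
//       {d0, ..., d63}, descA, descB, scale-d, scale-a, scale-b,
//       trans-a, trans-b;
//
// The 64x128 f32 accumulator is distributed across the 128 threads of the
// warpgroup (64 f32 registers per thread), which is the result struct this
// lowering threads through matrixC.
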
- b.create <NVVM::WgmmaFenceAlignedOp>();
1263
-
1264
- SmallVector<Value> wgmmaResults;
1265
- for (int iterM = 0 ; iterM < (sizeM / wgmmaShapeM); iterM++) {
1266
- Value matrixC = adaptor.getMatrixC ()[iterM];
1267
- Value matrixD = op.getMatrixD ()[iterM];
1268
- Type structType = getTypeConverter ()->convertType (matrixD.getType ());
1269
- LLVM_DEBUG (DBGS () << " D[" << (iterM * wgmmaShapeM) << " :"
1270
- << (iterM * wgmmaShapeM) + wgmmaShapeM << " ][" << 0
1271
- << " :" << wgmmaShapeN << " ] += \n " );
1272
- for (int iterK = 0 ; iterK < (sizeK / wgmmaShapeK); iterK++) {
1273
- Value descA = iterateDescA (descriptorA, iterM, 0 , iterK);
1274
- Value descB = iterateDescB (descriptorB, iterM, 0 , iterK);
1275
- LLVM_DEBUG (DBGS () << " \t wgmma."
1276
- << " m" << wgmmaShapeM << " n" << wgmmaShapeN << " k"
1277
- << wgmmaShapeK << " (A[" << (iterM * wgmmaShapeM)
1278
- << " :" << (iterM * wgmmaShapeM) + wgmmaShapeM << " ]["
1279
- << (iterK * wgmmaShapeK) << " :"
1280
- << (iterK * wgmmaShapeK + wgmmaShapeK) << " ] * "
1281
- << " B[" << (iterK * wgmmaShapeK) << " :"
1282
- << (iterK * wgmmaShapeK + wgmmaShapeK) << " ][" << 0
1283
- << " :" << wgmmaShapeN << " ])\n " );
1284
- matrixC = generateNVVMWgmmaOp (b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
1285
- structType, matrixC, descA, descB);
1355
+ // / Generates multiple wgmma instructions to complete the given GEMM shape
1356
+ SmallVector<Value> generateWgmmaGroup () {
1357
+ SmallVector<Value> wgmmaResults;
1358
+
1359
+ // Perform GEMM
1360
+ for (int i = 0 ; i < iterationM; ++i) {
1361
+ Value matrixC = adaptor.getMatrixC ()[i];
1362
+ Value matrixD = op.getMatrixD ()[i];
1363
+ for (int j = 0 ; j < iterationN; ++j)
1364
+ for (int k = 0 ; k < iterationK; ++k)
1365
+ matrixC = generateWgmma (i, j, k, matrixC, matrixD);
1366
+ wgmmaResults.push_back (matrixC);
1286
1367
}
1287
- wgmmaResults.push_back (matrixC);
1368
+
1369
+ return wgmmaResults;
1370
+ }
1371
+
+  public:
+    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ConversionPatternRewriter &rewriter,
+                  OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
+        : op(op), rewriter(rewriter), adaptor(adaptor),
+          typeConverter(typeConverter) {
+      // Find the entire GEMM Shape
+      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
+      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
+      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
+                        << "] += A[" << totalM << "][" << totalK << "] * B["
+                        << totalK << "][" << totalN << "] ---===\n");
+
+      // Find the shape for one wgmma instruction
+      findWgmmaShape(
+          totalM, totalN,
+          op.getDescriptorA().getType().getTensor().getElementType());
+
+      // Iteration counts to complete the given shape with wgmma shape
+      iterationM = totalM / wgmmaM;
+      iterationN = totalN / wgmmaN;
+      iterationK = totalK / wgmmaK;
     }
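
Plugging the running 128x128x64 f16 example into this constructor gives the tiling counts; a standalone sketch of the arithmetic:

#include <cassert>

int main() {
  int totalM = 128, totalN = 128, totalK = 64; // whole GEMM shape
  int wgmmaM = 64, wgmmaN = 128, wgmmaK = 16;  // one f16 wgmma tile
  int iterationM = totalM / wgmmaM; // 2
  int iterationN = totalN / wgmmaN; // 1
  int iterationK = totalK / wgmmaK; // 4
  // generateWgmmaGroup() emits one WgmmaMmaAsyncOp per (i, j, k) triple and
  // returns one accumulator struct per i.
  assert(iterationM * iterationN * iterationK == 8);
  assert(iterationM == 2);
  return 0;
}
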
-    b.create<NVVM::WgmmaGroupSyncAlignedOp>();
-    b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());

-    ValueRange myres(wgmmaResults);
-    rewriter.replaceOp(op, myres);
+    /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
+    /// also generates a fence Op (WgmmaFenceAlignedOp) before the
+    /// instructions, then commits them as a group (WgmmaGroupSyncAlignedOp)
+    /// and waits for the group's completion (WgmmaWaitGroupSyncOp) after the
+    /// instructions.
+    SmallVector<Value> generateWarpgroupMma() {
+      Location loc = op->getLoc();
+      rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
+      SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+      rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
+      rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
+
+      return wgmmaResults;
+    }
+  };
+
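
An annotated sketch of the sequence this emits, in rough NVVM dialect order rather than verbatim IR:

// nvvm.wgmma.fence.aligned              // order prior accumulator writes
// %0 = nvvm.wgmma.mma.async ...         // one op per (i, j, k) iteration,
// %1 = nvvm.wgmma.mma.async ...         // chained through the accumulator
// ...
// nvvm.wgmma.commit.group.sync.aligned  // close the async group
// nvvm.wgmma.wait.group.sync.aligned N  // N taken from op.getWaitGroup()
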
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Step 1. Build a helper class
+    WarpgroupGemm warpgroupGemm(op, rewriter, adaptor,
+                                *this->getTypeConverter());
+
+    // Step 2. Generate wgmma Ops covering the entire GEMM shape
+    SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+
+    // Step 3. Replace fragmented result struct with the op results
+    rewriter.replaceOp(op, wgmmaResults);
     return success();
   }
 };
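
Finally, a hedged sketch of how such a pattern is typically registered with the NVGPU-to-NVVM conversion; the populate function is assumed from this file's conventions, and the exact call site sits outside this hunk:

// Assumed registration site, for illustration only:
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                 RewritePatternSet &patterns) {
  patterns.add<NVGPUWarpgroupMmaOpLowering>(converter);
  // ... other NVGPU lowering patterns elided ...
}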