@@ -39,7 +39,7 @@ namespace mlir {
 using namespace mlir;
-/// Number of bits that needs to excluded when building matrix descriptor for
+/// Number of bits that needs to be excluded when building matrix descriptor for
 /// wgmma operations.
 constexpr int exclude4LSB = 4;
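
For context: `exclude4LSB` is used later in this patch to shift byte offsets right by 4 before they are added to a WGMMA matrix descriptor, so descriptor offsets end up expressed in 16-byte units. Below is a minimal standalone sketch of that arithmetic, separate from the patch; it assumes 2-byte (f16) elements, and the helper name `encodeDescriptorOffset` is made up for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative only: mirrors `incrementVal >> exclude4LSB` from this patch,
// dropping the 4 least-significant bits so offsets are in 16-byte units.
constexpr int exclude4LSB = 4;

constexpr uint64_t encodeDescriptorOffset(uint64_t byteOffset) {
  return byteOffset >> exclude4LSB;
}

int main() {
  // 32 bytes (one 16-element f16 K-step) becomes 2 descriptor units, and
  // 8192 bytes (one 64x64 f16 tile of A) becomes 512, matching the "+2" and
  // "descA+512" steps drawn in the matrix-A diagram later in this patch.
  assert(encodeDescriptorOffset(32) == 2);
  assert(encodeDescriptorOffset(8192) == 512);
  return 0;
}
```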
@@ -1160,137 +1160,276 @@ struct NVGPUWarpgroupMmaOpLowering
     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
   using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
 
-  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
-                              int &wgmmaShapeM, int &wgmmaShapeN,
-                              int &wgmmaShapeK) const {
-    wgmmaShapeM = 64;
-    wgmmaShapeN = sizeN;
-    if (inputElemType.isTF32()) {
-      wgmmaShapeK = 8;
-    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
-      wgmmaShapeK = 16;
-    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
-               inputElemType.isInteger(16)) {
-      wgmmaShapeK = 32;
-    } else if (inputElemType.isInteger(1)) {
-      wgmmaShapeK = 256;
-    } else {
-      llvm_unreachable("msg: not supported K shape");
+  /// This is a helper class to generate required NVVM Ops for warp-group level
+  /// matrix multiplication.
+  /// When the given GEMM shape is larger than the shape of a wgmma instruction
+  /// in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp Op(s), group them,
+  /// and execute them asynchronously. The class also handles waiting for
+  /// completion and iterates through WarpgroupMatrixDescriptor to create
+  /// descriptors for each instruction.
+  ///
+  /// For example this is the case when the shape of GEMM is 128x128x128
+  ///
+  ///    nvvm.wgmma.fence.aligned
+  ///
+  ///    nvvm.wgmma.mma.async descA, descB
+  ///    iterate(descA, descB)
+  ///    nvvm.wgmma.mma.async descA, descB
+  ///    [6x times more]
+  ///
+  ///    nvvm.wgmma.group.sync.aligned
+  ///    nvvm.wgmma.wait.group.sync [groupId]
+  ///
+  class WarpgroupGemm {
+    nvgpu::WarpgroupMmaOp op;
+    ImplicitLocOpBuilder b;
+    OpAdaptor adaptor;
+    const LLVMTypeConverter &typeConverter;
+
+    // Entire shape of the given Op
+    int64_t totalM, totalN, totalK;
+
+    // Shape of one wgmma instruction
+    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
+
+    // Iteration counts for GEMM
+    int iterationM = 0, iterationN = 0, iterationK = 0;
+
+    /// The function returns the shape of wgmma instruction that is defined in
+    /// PTX programming guide.
+    /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
+    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
+      wgmmaM = 64;
+      wgmmaN = sizeN;
+      if (inputElemType.isTF32()) {
+        wgmmaK = 8;
+      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+        wgmmaK = 16;
+      } else if (inputElemType.isFloat8E4M3FN() ||
+                 inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
+        wgmmaK = 32;
+      } else if (inputElemType.isInteger(1)) {
+        wgmmaK = 256;
+      } else {
+        llvm_unreachable("msg: not supported K shape");
+      }
+      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
     }
-    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
-                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
-                      << "]\n");
-    return success();
-  }
 
-  Value generateNVVMWgmmaOp(ImplicitLocOpBuilder &b, int m, int n, int k,
-                            Type resultStructType, Value inout,
-                            Value descriptorA, Value descriptorB) const {
-    MLIRContext *ctx = b.getContext();
-    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
-    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
-    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
-    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
-    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
-    // todo: handle other input and output types
-    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
-    auto overflow =
-        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
-    Value res = b.create<NVVM::WgmmaMmaAsyncOp>(
-        resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
-        scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
-    return res;
-  }
-
-  LogicalResult
-  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-    int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
-    int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
-    int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
-
-    LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
-                      << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
-                      << sizeN << "] ---===\n");
+    /// Generates WGMMATypesAttr from MLIR Type
+    NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
+      auto getWgmmaType = [](Type elemType) {
+        if (elemType.isF32() || elemType.isTF32())
+          return NVVM::WGMMATypes::tf32;
+        if (elemType.isF16())
+          return NVVM::WGMMATypes::f16;
+        if (elemType.isBF16())
+          return NVVM::WGMMATypes::bf16;
+        if (elemType.isFloat8E4M3FN())
+          return NVVM::WGMMATypes::e4m3;
+        if (elemType.isFloat8E5M2())
+          return NVVM::WGMMATypes::e5m2;
+        if (elemType.isInteger(1))
+          return NVVM::WGMMATypes::b1;
+        if (elemType.isInteger(8))
+          return NVVM::WGMMATypes::s8;
+        if (elemType.isUnsignedInteger(8))
+          return NVVM::WGMMATypes::u8;
+        llvm_unreachable("unsupported type");
+      };
+      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
+    }
 
-    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
-    if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
-                             wgmmaShapeN, wgmmaShapeK))) {
-      return failure();
+    /// Generates layout attribute for the input matrix for wgmma instruction
+    NVVM::MMALayoutAttr
+    generateWgmmaLayout(std::optional<bool> transpose) const {
+      if (transpose.value_or(false))
+        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
+      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
     }
 
-    Value descriptorA = adaptor.getDescriptorA();
-    Value descriptorB = adaptor.getDescriptorB();
+    /// Generates shape attribute for wgmma instruction
+    NVVM::MMAShapeAttr generateWgmmaShape() const {
+      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
+    }
 
-    // Generate wgmma group
-    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
-    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+    /// Generates scale attributes of output matrix for wgmma instruction
+    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
+      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
+                                          NVVM::WGMMAScaleOut::one);
+    }
+    /// Generates scale attributes of input matrix for wgmma instruction
+    NVVM::WGMMAScaleInAttr generateScaleIn() const {
+      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
+                                         NVVM::WGMMAScaleIn::one);
+    }
 
-    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+    /// Basic function to generate Add
+    Value makeAdd(Value lhs, Value rhs) {
       return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
     };
 
-    auto iterateDescA = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle column major
-      int byte = typeTensorA.getElementTypeBitWidth() / 8;
-      int tileShapeA = typeTensorA.getDimSize(1);
-      int incrementVal =
-          ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+    /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
+    /// Currently, it only handles row-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    ///                  +2 +4 +6
+    ///                   ↓  ↓  ↓
+    /// descA      ---> +--+--+--+--+
+    ///                 |->|->|->|->|
+    ///                 |  |  |  |  |
+    ///                 |  |  |  |  |
+    ///                 |  |  |  |  |
+    /// descA+512  ---> +-----------+
+    ///                 |  |  |  |  |
+    ///                 |  |  |  |  |
+    ///                 |  |  |  |  |
+    ///                 |  |  |  |  |
+    ///                 +-----------+
+    ///
+    Value iterateDescriptorA(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
+      Type elemA = matrixTypeA.getElementType();
+      int byte = elemA.getIntOrFloatBitWidth() / 8;
+      int tileShapeA = matrixTypeA.getDimSize(1);
+      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
       incrementVal = incrementVal >> exclude4LSB;
-      LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
-                        << iterK << "] [wgmma descriptors] Descriptor A + "
+      LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
+                        << "] [wgmma descriptors] Descriptor A + "
                         << incrementVal << " | \t ");
       if (!incrementVal)
         return desc;
       return makeAdd(desc, makeI64Const(b, incrementVal));
-    };
+    }
 
-    auto iterateDescB = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle row major
-      int byte = typeTensorB.getElementTypeBitWidth() / 8;
-      int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+    /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
+    /// Currently, it only handles column-major.
+    ///
+    /// It moves the pointer like below for [128][64] size:
+    /// descB ---> +--+--+--+--+--+--+--+--+
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            +--+--+--+--+--+--+--+--+
+    ///
+    Value iterateDescriptorB(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
+      Type elemB = matrixTypeB.getElementType();
+      int byte = elemB.getIntOrFloatBitWidth() / 8;
+      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
       incrementVal = incrementVal >> exclude4LSB;
       LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
       if (!incrementVal)
         return desc;
       return makeAdd(desc, makeI64Const(b, incrementVal));
-    };
+    }
+
+    /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
+    /// descriptors and arranges them based on induction variables: i, j, and k.
+    Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+      LLVM_DEBUG(DBGS() << "\t wgmma."
+                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
+                        << "(A[" << (i * wgmmaM) << ":"
+                        << (i * wgmmaM) + wgmmaM << "]["
+                        << (k * wgmmaK) << ":" << (k * wgmmaK + wgmmaK)
+                        << "] * "
+                        << " B[" << (k * wgmmaK) << ":"
+                        << (k * wgmmaK + wgmmaK) << "][" << 0 << ":"
+                        << wgmmaN << "])\n");
+
+      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
+      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
+
+      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
+
+      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
+      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
+
+      NVVM::MMAShapeAttr shape = generateWgmmaShape();
+      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
+      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
+      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
+      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());
+
+      auto overflow = NVVM::MMAIntOverflowAttr::get(
+          op->getContext(), NVVM::MMAIntOverflow::wrapped);
+
+      Type resultStructType = typeConverter.convertType(matrixD.getType());
+
+      return b.create<NVVM::WgmmaMmaAsyncOp>(
+          resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
+          itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+    }
 
-    b.create<NVVM::WgmmaFenceAlignedOp>();
-
-    SmallVector<Value> wgmmaResults;
-    for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
-      Value matrixC = adaptor.getMatrixC()[iterM];
-      Value matrixD = op.getMatrixD()[iterM];
-      Type structType = getTypeConverter()->convertType(matrixD.getType());
-      LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
-                        << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
-                        << ":" << wgmmaShapeN << "] += \n");
-      for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
-        Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
-        Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
-        LLVM_DEBUG(DBGS() << "\t wgmma."
-                          << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
-                          << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
-                          << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
-                          << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
-                          << " B[" << (iterK * wgmmaShapeK) << ":"
-                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
-                          << ":" << wgmmaShapeN << "])\n");
-        matrixC = generateNVVMWgmmaOp(b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
-                                      structType, matrixC, descA, descB);
+    /// Generates multiple wgmma instructions to complete the given GEMM shape
+    SmallVector<Value> generateWgmmaGroup() {
+      SmallVector<Value> wgmmaResults;
+
+      // Perform GEMM
+      for (int i = 0; i < iterationM; ++i) {
+        Value matrixC = adaptor.getMatrixC()[i];
+        Value matrixD = op.getMatrixD()[i];
+        for (int j = 0; j < iterationN; ++j)
+          for (int k = 0; k < iterationK; ++k)
+            matrixC = generateWgmma(i, j, k, matrixC, matrixD);
+        wgmmaResults.push_back(matrixC);
       }
-      wgmmaResults.push_back(matrixC);
+
+      return wgmmaResults;
+    }
+
+  public:
+    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
+                  OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
+        : op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
+      // Find the entire GEMM Shape
+      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
+      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
+      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
+                        << "] += A[" << totalM << "][" << totalK << "] * B["
+                        << totalK << "][" << totalN << "] ---===\n");
+
+      // Find the shape for one wgmma instruction
+      findWgmmaShape(
+          totalM, totalN,
+          op.getDescriptorA().getType().getTensor().getElementType());
+
+      // Iteration counts to complete the given shape with wgmma shape
+      iterationM = totalM / wgmmaM;
+      iterationN = totalN / wgmmaN;
+      iterationK = totalK / wgmmaK;
     }
-    b.create<NVVM::WgmmaGroupSyncAlignedOp>();
-    b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
 
-    ValueRange myres(wgmmaResults);
-    rewriter.replaceOp(op, myres);
+    /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
+    /// includes generating a fence Op (WgmmaFenceAlignedOp) before the
+    /// instructions, as well as group synchronization
+    /// (WgmmaGroupSyncAlignedOp) and a wait for the group to complete
+    /// (WgmmaWaitGroupSyncOp) after the instructions.
+    SmallVector<Value> generateWarpgroupMma() {
+      b.create<NVVM::WgmmaFenceAlignedOp>();
+      SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+      b.create<NVVM::WgmmaGroupSyncAlignedOp>();
+      b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
+      return wgmmaResults;
+    }
+  };
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+    // Step 1. Build a helper class
+    WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
+
+    // Step 2. Generate wgmma Ops covering the entire GEMM shape
+    SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+
+    // Step 3. Replace fragmented result struct with the op results
+    rewriter.replaceOp(op, wgmmaResults);
 
     return success();
   }
 };
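
To make the tiling concrete, the sketch below replays the integer arithmetic of `findWgmmaShape`, the `WarpgroupGemm` constructor, and the two `iterateDescriptor*` helpers for the 128x128x128 GEMM mentioned in the class comment, assuming f16 operands. It is a standalone illustration, not part of the patch; all names are local to the sketch and nothing here depends on MLIR.

```cpp
#include <cassert>
#include <cstdio>

// Illustrative only: replays the arithmetic this patch performs for a
// 128x128x128 GEMM with f16 operands (2 bytes per element).
constexpr int exclude4LSB = 4;

int main() {
  const int totalM = 128, totalN = 128, totalK = 128;
  const int byte = 2; // f16 element size in bytes

  // findWgmmaShape: m is fixed at 64, n follows the GEMM, k depends on type.
  const int wgmmaM = 64, wgmmaN = totalN, wgmmaK = 16; // k = 16 for f16/bf16

  // WarpgroupGemm constructor: how many wgmma instructions tile the GEMM.
  const int iterationM = totalM / wgmmaM; // 2
  const int iterationN = totalN / wgmmaN; // 1
  const int iterationK = totalK / wgmmaK; // 8
  assert(iterationM * iterationN * iterationK == 16);

  // iterateDescriptorA / iterateDescriptorB: byte offsets per (i, k),
  // shifted into 16-byte descriptor units by exclude4LSB.
  const int tileShapeA = totalK; // matrix A is [totalM][totalK], dim 1
  for (int i = 0; i < iterationM; ++i)
    for (int k = 0; k < iterationK; ++k) {
      int incA =
          (((wgmmaK * k) + (totalK * tileShapeA * i)) * byte) >> exclude4LSB;
      int incB = (totalK * wgmmaK * k * byte) >> exclude4LSB;
      std::printf("i=%d k=%d  descA+%d  descB+%d\n", i, k, incA, incB);
    }
  // Prints descA+0, +2, +4, ... across one row of A tiles, jumps to
  // descA+2048 when the second 64-row block starts, and advances descB by
  // 256 descriptor units per K step.
  return 0;
}
```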