@@ -1160,15 +1160,29 @@ struct NVGPUWarpgroupMmaOpLowering
1160
1160
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1161
1161
using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
1162
1162
1163
- // / This class assists in generating WgmmaMmaAsyncOp instructions to complete
1164
- // / a specified shape. If the GEMM shape is larger than the shape of a wgmma
1165
- // / instrution, it can generate multiple wgmma instructions, group and execute
1166
- // / them asynchronously. The class also handles waiting for instruction
1167
- // / completion and iterates through GenerateGmmaDescriptor to create
1168
- // / descriptors for each instruction.
1163
+ // / This is a helper class to generate required NVVM Ops for warp-group level
1164
+ // / matrix multiplication.
1165
+ // / When the given GEMM shape is larger than the shape of
1166
+ // / a wgmma instrution in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
1167
+ // / Op(s), group and execute them asynchronously. The class also handles
1168
+ // / waiting for completion and iterates through WarpgroupMatrixDescriptor to
1169
+ // / create descriptors for each instruction.
1170
+ // /
1171
+ // / For example this is the case when the shape of GEMM is 128x128x128
1172
+ // /
1173
+ // / nvvm.wgmma.fence.aligned
1174
+ // /
1175
+ // / nvvm.wgmma.mma.async descA, descB
1176
+ // / iterate(descA, descB)
1177
+ // / nvvm.wgmma.mma.async descA, descB
1178
+ // / [6x times more]
1179
+ // /
1180
+ // / nvvm.wgmma.group.sync.aligned
1181
+ // / nvvm.wgmma.wait.group.sync [groupId]
1182
+ // /
1169
1183
class WarpgroupGemm {
1170
1184
nvgpu::WarpgroupMmaOp op;
1171
- ConversionPatternRewriter &rewriter ;
1185
+ ImplicitLocOpBuilder b ;
1172
1186
OpAdaptor adaptor;
1173
1187
const LLVMTypeConverter &typeConverter;
1174
1188
@@ -1253,8 +1267,7 @@ struct NVGPUWarpgroupMmaOpLowering
1253
1267
1254
1268
// / Basic function to generate Add
1255
1269
Value makeAdd (Value lhs, Value rhs) {
1256
- return rewriter.create <LLVM::AddOp>(op->getLoc (), lhs.getType (), lhs,
1257
- rhs);
1270
+ return b.create <LLVM::AddOp>(lhs.getType (), lhs, rhs);
1258
1271
};
1259
1272
1260
1273
// / Moves the descriptor pointer of matrix-A for the next wgmma instruction.
@@ -1287,7 +1300,7 @@ struct NVGPUWarpgroupMmaOpLowering
1287
1300
<< incrementVal << " | \t " );
1288
1301
if (!incrementVal)
1289
1302
return desc;
1290
- return makeAdd (desc, makeI64Const (rewriter, op , incrementVal));
1303
+ return makeAdd (desc, makeI64Const (b , incrementVal));
1291
1304
}
1292
1305
1293
1306
// / Moves the descriptor pointer of matrix-B for the next wgmma instruction.
@@ -1310,7 +1323,7 @@ struct NVGPUWarpgroupMmaOpLowering
1310
1323
LLVM_DEBUG (DBGSE () << " Descriptor B + " << incrementVal << " \n " );
1311
1324
if (!incrementVal)
1312
1325
return desc;
1313
- return makeAdd (desc, makeI64Const (rewriter, op , incrementVal));
1326
+ return makeAdd (desc, makeI64Const (b , incrementVal));
1314
1327
}
1315
1328
1316
1329
// / This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
@@ -1346,10 +1359,9 @@ struct NVGPUWarpgroupMmaOpLowering
1346
1359
1347
1360
Type resultStructType = typeConverter.convertType (matrixD.getType ());
1348
1361
1349
- return rewriter.create <NVVM::WgmmaMmaAsyncOp>(
1350
- op->getLoc (), resultStructType, matrixC, descriptorA, descriptorB,
1351
- shape, itypeA, itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1352
- overflow);
1362
+ return b.create <NVVM::WgmmaMmaAsyncOp>(
1363
+ resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
1364
+ itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
1353
1365
}
1354
1366
1355
1367
// / Generates multiple wgmma instructions to complete the given GEMM shape
@@ -1370,10 +1382,9 @@ struct NVGPUWarpgroupMmaOpLowering
1370
1382
}
1371
1383
1372
1384
public:
1373
- WarpgroupGemm (nvgpu::WarpgroupMmaOp op, ConversionPatternRewriter &rewriter ,
1385
+ WarpgroupGemm (nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b ,
1374
1386
OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
1375
- : op(op), rewriter(rewriter), adaptor(adaptor),
1376
- typeConverter (typeConverter) {
1387
+ : op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
1377
1388
// Find the entire GEMM Shape
1378
1389
totalM = op.getDescriptorA ().getType ().getTensor ().getDimSize (0 );
1379
1390
totalN = op.getDescriptorB ().getType ().getTensor ().getDimSize (1 );
@@ -1399,22 +1410,20 @@ struct NVGPUWarpgroupMmaOpLowering
1399
1410
// / (WgmmaGroupSyncAlignedOp) for group synchronization
1400
1411
// / (WgmmaWaitGroupSyncOp) after the instructions.
1401
1412
SmallVector<Value> generateWarpgroupMma () {
1402
- Location loc = op->getLoc ();
1403
- rewriter.create <NVVM::WgmmaFenceAlignedOp>(loc);
1413
+ b.create <NVVM::WgmmaFenceAlignedOp>();
1404
1414
SmallVector<Value> wgmmaResults = generateWgmmaGroup ();
1405
- rewriter.create <NVVM::WgmmaGroupSyncAlignedOp>(loc);
1406
- rewriter.create <NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup ());
1407
-
1415
+ b.create <NVVM::WgmmaGroupSyncAlignedOp>();
1416
+ b.create <NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup ());
1408
1417
return wgmmaResults;
1409
1418
}
1410
1419
};
1411
1420
1412
1421
LogicalResult
1413
1422
matchAndRewrite (nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1414
1423
ConversionPatternRewriter &rewriter) const override {
1424
+ ImplicitLocOpBuilder b (op->getLoc (), rewriter);
1415
1425
// Step 1. Build a helper class
1416
- WarpgroupGemm warpgroupGemm (op, rewriter, adaptor,
1417
- *this ->getTypeConverter ());
1426
+ WarpgroupGemm warpgroupGemm (op, b, adaptor, *this ->getTypeConverter ());
1418
1427
1419
1428
// Step 2. Get the entire GEMM Shape
1420
1429
SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma ();
0 commit comments