Skip to content

Commit 2ca1556

Browse files
committed
rebase and add more comment
1 parent a119578 commit 2ca1556

File tree

1 file changed

+34
-25
lines changed

1 file changed

+34
-25
lines changed

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,15 +1160,29 @@ struct NVGPUWarpgroupMmaOpLowering
11601160
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
11611161
using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
11621162

1163-
/// This class assists in generating WgmmaMmaAsyncOp instructions to complete
1164-
/// a specified shape. If the GEMM shape is larger than the shape of a wgmma
1165-
/// instrution, it can generate multiple wgmma instructions, group and execute
1166-
/// them asynchronously. The class also handles waiting for instruction
1167-
/// completion and iterates through GenerateGmmaDescriptor to create
1168-
/// descriptors for each instruction.
1163+
/// This is a helper class to generate required NVVM Ops for warp-group level
1164+
/// matrix multiplication.
1165+
/// When the given GEMM shape is larger than the shape of
1166+
/// a wgmma instruction in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
1167+
/// Op(s), group and execute them asynchronously. The class also handles
1168+
/// waiting for completion and iterates through WarpgroupMatrixDescriptor to
1169+
/// create descriptors for each instruction.
1170+
///
1171+
/// For example, this is the sequence generated when the GEMM shape is 128x128x128:
1172+
///
1173+
/// nvvm.wgmma.fence.aligned
1174+
///
1175+
/// nvvm.wgmma.mma.async descA, descB
1176+
/// iterate(descA, descB)
1177+
/// nvvm.wgmma.mma.async descA, descB
1178+
/// [6 more times]
1179+
///
1180+
/// nvvm.wgmma.group.sync.aligned
1181+
/// nvvm.wgmma.wait.group.sync [groupId]
1182+
///
11691183
class WarpgroupGemm {
11701184
nvgpu::WarpgroupMmaOp op;
1171-
ConversionPatternRewriter &rewriter;
1185+
ImplicitLocOpBuilder b;
11721186
OpAdaptor adaptor;
11731187
const LLVMTypeConverter &typeConverter;
11741188

@@ -1253,8 +1267,7 @@ struct NVGPUWarpgroupMmaOpLowering
12531267

12541268
/// Basic function to generate Add
12551269
Value makeAdd(Value lhs, Value rhs) {
1256-
return rewriter.create<LLVM::AddOp>(op->getLoc(), lhs.getType(), lhs,
1257-
rhs);
1270+
return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
12581271
};
12591272

12601273
/// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
@@ -1287,7 +1300,7 @@ struct NVGPUWarpgroupMmaOpLowering
12871300
<< incrementVal << " | \t ");
12881301
if (!incrementVal)
12891302
return desc;
1290-
return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
1303+
return makeAdd(desc, makeI64Const(b, incrementVal));
12911304
}
12921305

12931306
/// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
@@ -1310,7 +1323,7 @@ struct NVGPUWarpgroupMmaOpLowering
13101323
LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
13111324
if (!incrementVal)
13121325
return desc;
1313-
return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
1326+
return makeAdd(desc, makeI64Const(b, incrementVal));
13141327
}
13151328

13161329
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
@@ -1346,10 +1359,9 @@ struct NVGPUWarpgroupMmaOpLowering
13461359

13471360
Type resultStructType = typeConverter.convertType(matrixD.getType());
13481361

1349-
return rewriter.create<NVVM::WgmmaMmaAsyncOp>(
1350-
op->getLoc(), resultStructType, matrixC, descriptorA, descriptorB,
1351-
shape, itypeA, itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1352-
overflow);
1362+
return b.create<NVVM::WgmmaMmaAsyncOp>(
1363+
resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
1364+
itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
13531365
}
13541366

13551367
/// Generates multiple wgmma instructions to complete the given GEMM shape
@@ -1370,10 +1382,9 @@ struct NVGPUWarpgroupMmaOpLowering
13701382
}
13711383

13721384
public:
1373-
WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ConversionPatternRewriter &rewriter,
1385+
WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
13741386
OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
1375-
: op(op), rewriter(rewriter), adaptor(adaptor),
1376-
typeConverter(typeConverter) {
1387+
: op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
13771388
// Find the entire GEMM Shape
13781389
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
13791390
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
@@ -1399,22 +1410,20 @@ struct NVGPUWarpgroupMmaOpLowering
13991410
/// (WgmmaGroupSyncAlignedOp) for group synchronization
14001411
/// (WgmmaWaitGroupSyncOp) after the instructions.
14011412
SmallVector<Value> generateWarpgroupMma() {
1402-
Location loc = op->getLoc();
1403-
rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
1413+
b.create<NVVM::WgmmaFenceAlignedOp>();
14041414
SmallVector<Value> wgmmaResults = generateWgmmaGroup();
1405-
rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
1406-
rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
1407-
1415+
b.create<NVVM::WgmmaGroupSyncAlignedOp>();
1416+
b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
14081417
return wgmmaResults;
14091418
}
14101419
};
14111420

14121421
LogicalResult
14131422
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
14141423
ConversionPatternRewriter &rewriter) const override {
1424+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
14151425
// Step 1. Build a helper class
1416-
WarpgroupGemm warpgroupGemm(op, rewriter, adaptor,
1417-
*this->getTypeConverter());
1426+
WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
14181427

14191428
// Step 2. Generate the wgmma instructions for the entire GEMM shape
14201429
SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();

0 commit comments

Comments
 (0)