Commit dc24a32

[MLIR][NVGPU] Introduce nvgpu.warpgroup.mma.store Op for Hopper GPUs
This work introduces a new operation called `warpgroup.mma.store` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate storing the fragmented results of a WGMMA to a given memref. An example of fragmentation is given here: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d

The `warpgroup.mma.store` op does the following:
1) Takes one or more fragmented result matrices.
2) Calculates per-thread indices within the warp group and stores the data into the given memref.

Here's an example usage of the `nvgpu.warpgroup.mma.store` operation:
```
// Performs matmul; the results are fragmented and held in registers
%res1, %res2 = nvgpu.warpgroup.mma ...

// Stores the fragmented results to the given memory
nvgpu.warpgroup.mma.store [%res1, %res2], %matrixD :
    !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>,
    !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>
    to memref<128x128xf32,3>
```

Depends on llvm#65440
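To make the per-thread index calculation concrete, here is a minimal standalone sketch (not part of the commit) that mirrors the index arithmetic the lowering in NVGPUToNVVM.cpp below emits for a single f32 fragment at offset 0. `printThreadStores` and its output format are illustrative, not an MLIR or CUDA API.

```cpp
#include <cstdio>

constexpr int kWarpSize = 32;

// Prints which struct elements a thread extracts and the [row][col]
// destination it stores them to, following the same formulas as
// NVGPUWarpgroupMmaStoreOpLowering::storeFragmentedMatrix below.
void printThreadStores(int tidx) {
  int laneId = tidx % kWarpSize;   // lane within the warp
  int warpId = tidx / kWarpSize;   // warp within the warp group
  int lane4Id = laneId / 4;        // row group: 4 lanes share two columns
  int lane4modId = laneId % 4;     // position within the 4-lane group
  int ti = lane4Id + warpId * 16;  // base row owned by this thread
  int tj = lane4modId * 2;         // base column (pairs of consecutive f32s)
  for (int i = 0; i < 2; ++i)      // two row tiles, 8 rows apart
    for (int j = 0; j < 16; ++j) { // sixteen column tiles, 8 columns apart
      int s = i * 2 + j * 4;       // first of two adjacent struct elements
      std::printf("t%03d: struct[%02d..%02d] -> [%d][%d..%d]\n", tidx, s,
                  s + 1, ti + i * 8, tj + j * 8, tj + j * 8 + 1);
    }
}

int main() {
  for (int tidx = 0; tidx < 128; ++tidx) // a warp group is 128 threads
    printThreadStores(tidx);
}
```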
1 parent b74cfc1 commit dc24a32

3 files changed, +129 −2 lines changed


mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

Lines changed: 19 additions & 0 deletions
```
@@ -728,4 +728,23 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
+  let description = [{
+    The `nvgpu.warpgroup.mma.store` op stores the fragmented result(s) held
+    in $matrixD to the given memref.
+
+    [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
+
+    Note that the op must be run by a full warp group.
+  }];
+
+  let arguments = (ins Variadic<NVGPU_WarpgroupResult>:$matrixD,
+                       Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+
+  let assemblyFormat = [{
+    `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
```
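Per this `assemblyFormat`, the printed form is exactly the one shown in the commit message: the variadic `$matrixD` operands in square brackets, then the destination memref, with the fragment types listed before the `to` keyword and the memref type after it.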

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 81 additions & 2 deletions
```
@@ -11,6 +11,7 @@
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -394,8 +395,8 @@ struct ConvertNVGPUToNVVMPass
   using Base::Base;
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect>();
+    registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
+                    arith::ArithDialect>();
   }
 
   void runOnOperation() override {
@@ -436,6 +437,7 @@ struct ConvertNVGPUToNVVMPass
     populateNVGPUToNVVMConversionPatterns(converter, patterns);
     LLVMConversionTarget target(getContext());
     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
+    target.addLegalDialect<::mlir::arith::ArithDialect>();
     target.addLegalDialect<::mlir::memref::MemRefDialect>();
     target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
@@ -1434,11 +1436,88 @@ struct NVGPUWarpgroupMmaOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaStoreOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
+
+  void storeFragmentedMatrix(Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op,
+                             OpAdaptor adaptor,
+                             ConversionPatternRewriter &rewriter,
+                             int offset) const {
+    Location loc = op->getLoc();
+    Type i32 = rewriter.getI32Type();
+
+    auto makeConst = [&](int32_t index) -> Value {
+      return rewriter.create<LLVM::ConstantOp>(
+          loc, i32, rewriter.getI32IntegerAttr(index));
+    };
+    Value c4 = makeConst(4);
+    Value c32 = makeConst(kWarpSize);
+    Value c8 = makeConst(8);
+    Value c2 = makeConst(2);
+    Value c1 = makeConst(1);
+    Value c16 = makeConst(16);
+
+    auto makeMul = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs);
+    };
+    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    };
+
+    Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32);
+    Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, c32);
+    Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, c32);
+    Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4);
+    Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4);
+
+    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
+                                   TypedValue<::mlir::MemRefType> memref) {
+      Type it = rewriter.getIndexType();
+      Value idx = rewriter.create<arith::IndexCastOp>(loc, it, x);
+      Value idy0 = rewriter.create<arith::IndexCastOp>(loc, it, y);
+      Value idy1 = rewriter.create<arith::IndexCastOp>(loc, it, makeAdd(y, c1));
+      Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i);
+      Value d1 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1);
+      rewriter.create<memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0});
+      rewriter.create<memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1});
+    };
+
+    Value tj = makeMul(lane4modId, c2);
+    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
+    if (offset)
+      ti = makeAdd(ti, makeConst(offset));
+    for (int i = 0; i < 2; ++i) {
+      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
+      for (int j = 0; j < 16; ++j) {
+        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
+        int sIndex = i * 2 + j * 4;
+        makeExtractAndStore(sIndex, wgmmaResult, idx, idy, op.getDstMemref());
+      }
+    }
+  }
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    int offset = 0;
+    for (auto result : adaptor.getMatrixD()) {
+      auto stype = result.getType().cast<LLVM::LLVMStructType>();
+      storeFragmentedMatrix(result, op, adaptor, rewriter, offset);
+      offset += stype.getBody().size();
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
   patterns.add<
+      NVGPUWarpgroupMmaStoreOpLowering,   // nvgpu.warpgroup.mma.store
      NVGPUMBarrierCreateLowering,        // nvgpu.mbarrier.create
      NVGPUMBarrierInitLowering,          // nvgpu.mbarrier.init
      NVGPUMBarrierArriveLowering,        // nvgpu.mbarrier.arrive
```
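As a worked instance of the index arithmetic above (numbers mine, following the formulas in `storeFragmentedMatrix`): thread 37 has `laneId = 5` and `warpId = 1`, hence `lane4Id = 1` and `lane4modId = 1`, giving base row `ti = 1 + 1 * 16 = 17` and base column `tj = 2`. At `i = 1, j = 3` it extracts struct elements 14 and 15 (`sIndex = 1 * 2 + 3 * 4`) and stores them to row `17 + 1 * 8 = 25`, columns `2 + 3 * 8 = 26` and `27` of the destination memref.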

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 29 additions & 0 deletions
```
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -529,6 +530,34 @@ LogicalResult WarpgroupMmaOp::verify() {
   return success();
 }
 
+LogicalResult WarpgroupMmaStoreOp::verify() {
+  Type stype =
+      getMatrixD().front().getType().cast<WarpgroupResultType>().getTensor();
+
+  for (auto result : getMatrixD()) {
+    auto resultStype = result.getType()
+                           .cast<WarpgroupResultType>()
+                           .getTensor()
+                           .dyn_cast<LLVM::LLVMStructType>();
+    if (!resultStype)
+      return emitOpError() << "result is " << result.getType()
+                           << " but must keep type of llvm struct";
+    if (stype != resultStype)
+      return emitOpError() << "all results must be the same type";
+
+    // TODO: Improve this limitation.
+    if (!resultStype.getBody().front().isF32()) {
+      return emitOpError() << "supports only f32 results for the time being";
+    }
+  }
+
+  if (!llvm::all_equal(stype.cast<LLVM::LLVMStructType>().getBody())) {
+    return emitOpError() << "all element types must be equal";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//
```
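To illustrate the checks (example types mine): two operands of type `!nvgpu.warpgroup.result<tensor = !llvm.struct<(f32, f32, f32, f32)>>` verify, since their struct bodies match, are homogeneous, and start with `f32`. An operand whose tensor is not an LLVM struct, operands with differing struct bodies, a non-f32 fragment, or a struct mixing element types would each be rejected by the corresponding diagnostic above.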
