Skip to content

Commit 1d9cfc8

Browse files
committed
address @qcolombet comments
1 parent a0903c9 commit 1d9cfc8

File tree

4 files changed

+101
-43
lines changed

4 files changed

+101
-43
lines changed

mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRNVGPUToNVVM
1717
MLIRLLVMDialect
1818
MLIRNVGPUDialect
1919
MLIRNVVMDialect
20+
MLIRArithDialect
2021
MLIRPass
2122
MLIRSCFTransforms
2223
MLIRTransforms

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 84 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,26 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
5454
return b.create<LLVM::TruncOp>(b.getI32Type(), value);
5555
}
5656

57+
/// Returns warp-size as a value.
58+
static Value getWarpSizeValue(ImplicitLocOpBuilder &b) {
59+
static std::optional<Value> warpSize = std::nullopt;
60+
if (!warpSize.has_value()) {
61+
warpSize = b.create<LLVM::ConstantOp>(IntegerType::get(b.getContext(), 32),
62+
b.getI32IntegerAttr(kWarpSize));
63+
}
64+
return warpSize.value();
65+
}
66+
67+
/// Returns warp-size as a value.
68+
static Value getWarpSizeValue(ImplicitLocOpBuilder &b) {
69+
static std::optional<Value> warpSize = std::nullopt;
70+
if (!warpSize.has_value()) {
71+
warpSize = b.create<LLVM::ConstantOp>(IntegerType::get(b.getContext(), 32),
72+
b.getI32IntegerAttr(kWarpSize));
73+
}
74+
return warpSize.value();
75+
}
76+
5777
/// Returns the type for the intrinsic given the vectorResultType of the
5878
/// `gpu.mma.sync` operation.
5979
static Type inferIntrinsicResultType(Type vectorResultType) {
@@ -1441,47 +1461,80 @@ struct NVGPUWarpgroupMmaStoreOpLowering
14411461
using ConvertOpToLLVMPattern<
14421462
nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
14431463

1444-
void storeFragmentedMatrix(Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op,
1445-
OpAdaptor adaptor,
1446-
ConversionPatternRewriter &rewriter,
1464+
/// This function stores a fragmented register matrix owned by a warp group
1465+
/// (128 threads) into a memref. Each thread has 64 registers, each the size
1466+
/// of a struct.
1467+
/// Here is what each threads (T) holds, each `d` is struct value with a
1468+
/// number.
1469+
///
1470+
/// Threads in warp-group (128 threads) and what they owns in the matriD:
1471+
/// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
1472+
/// 32-63 Warp-1 -> MatrixD[16:31][0:N]
1473+
/// 64-95 Warp-2 -> MatrixD[32:47][0:N]
1474+
/// 96-127 Warp-3 -> MatrixD[48:64][0:N]
1475+
///
1476+
/// Matrix-D:
1477+
/// +______________________________________________________________________+
1478+
/// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1479+
/// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1480+
/// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1481+
/// ..| .........|.........|.........|.........|........|...........|........|
1482+
/// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1483+
/// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1484+
/// ..| .........|.........|.........|.........|........|...........|........|
1485+
/// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1486+
/// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1487+
/// ..| .........|.........|.........|.........|........|...........|........|
1488+
/// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1489+
/// ..| .........|.........|.........|.........|........|...........|........|
1490+
/// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1491+
/// ..| .........|.........|.........|.........|........|...........|........|
1492+
/// +______________________________________________________________________+
1493+
///
1494+
/// \param rewriter: The pattern rewriter.
1495+
/// \param matrixD: Result of the warp-group MMA operation (fragmented
1496+
/// matrix). It is holded by a thread and a struct with 64 elements.
1497+
/// \param dstMemref: The memref where the registers will be stored.
1498+
/// \param offset: the offset within the memref where the registers will be
1499+
/// stored.
1500+
void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
1501+
TypedValue<MemRefType> dstMemref,
14471502
int offset) const {
1448-
Location loc = op->getLoc();
1449-
Type i32 = rewriter.getI32Type();
1503+
Type i32 = b.getI32Type();
14501504

14511505
auto makeConst = [&](int32_t index) -> Value {
1452-
return rewriter.create<LLVM::ConstantOp>(
1453-
loc, i32, rewriter.getI32IntegerAttr(index));
1506+
return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
14541507
};
1508+
Value c1 = makeConst(1);
1509+
Value c2 = makeConst(2);
14551510
Value c4 = makeConst(4);
1456-
Value c32 = makeConst(kWarpSize);
14571511
Value c8 = makeConst(8);
1458-
Value c2 = makeConst(2);
1459-
Value c1 = makeConst(1);
14601512
Value c16 = makeConst(16);
1513+
Value warpSize = getWarpSizeValue(b);
14611514

14621515
auto makeMul = [&](Value lhs, Value rhs) -> Value {
1463-
return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs);
1516+
return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
14641517
};
14651518
auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1466-
return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
1519+
return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
14671520
};
14681521

1469-
Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32);
1470-
Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, c32);
1471-
Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, c32);
1472-
Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4);
1473-
Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4);
1522+
Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
1523+
Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
1524+
Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
1525+
Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
1526+
Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
14741527

14751528
auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
14761529
TypedValue<::mlir::MemRefType> memref) {
1477-
Type it = rewriter.getIndexType();
1478-
Value idx = rewriter.create<arith::IndexCastOp>(loc, it, x);
1479-
Value idy0 = rewriter.create<arith::IndexCastOp>(loc, it, y);
1480-
Value idy1 = rewriter.create<arith::IndexCastOp>(loc, it, makeAdd(y, c1));
1481-
Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i);
1482-
Value d1 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1);
1483-
rewriter.create<memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0});
1484-
rewriter.create<memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1});
1530+
Type it = b.getIndexType();
1531+
Value idx = b.create<arith::IndexCastOp>(it, x);
1532+
Value idy0 = b.create<arith::IndexCastOp>(it, y);
1533+
Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
1534+
Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
1535+
Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
1536+
b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
1537+
b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
14851538
};
14861539

14871540
Value tj = makeMul(lane4modId, c2);
@@ -1493,7 +1546,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
14931546
for (int j = 0; j < 16; ++j) {
14941547
Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
14951548
int sIndex = i * 2 + j * 4;
1496-
makeExtractAndStore(sIndex, wgmmaResult, idx, idy, op.getDstMemref());
1549+
makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
14971550
}
14981551
}
14991552
}
@@ -1502,10 +1555,11 @@ struct NVGPUWarpgroupMmaStoreOpLowering
15021555
matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
15031556
ConversionPatternRewriter &rewriter) const override {
15041557
int offset = 0;
1505-
for (auto result : adaptor.getMatrixD()) {
1506-
auto stype = result.getType().cast<LLVM::LLVMStructType>();
1507-
storeFragmentedMatrix(result, op, adaptor, rewriter, offset);
1508-
offset += stype.getBody().size();
1558+
ImplicitLocOpBuilder lb(op->getLoc(), rewriter);
1559+
for (Value matrixD : adaptor.getMatrixD()) {
1560+
auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
1561+
storeFragmentedMatrix(lb, matrixD, op.getDstMemref(), offset);
1562+
offset += structType.getBody().size();
15091563
}
15101564
rewriter.eraseOp(op);
15111565
return success();

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -780,18 +780,18 @@ func.func @warpgroup_mma_store(
780780
%matrixD: memref<128x128xf32,3>) {
781781
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
782782
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
783+
// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
784+
// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
783785
// CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
784-
// CHECK: %[[S3:.+]] = llvm.mlir.constant(32 : i32) : i32
785786
// CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32
786-
// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
787-
// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
788787
// CHECK: %[[S7:.+]] = llvm.mlir.constant(16 : i32) : i32
788+
// CHECK: %[[WarpSize:.+]] = llvm.mlir.constant(32 : i32) : i32
789789

790790
// ### Store {d0, d1} of each thread ###
791791

792792
// CHECK: %[[S8:.+]] = nvvm.read.ptx.sreg.tid.x : i32
793-
// CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[S3]] : i32
794-
// CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[S3]] : i32
793+
// CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[WarpSize]] : i32
794+
// CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[WarpSize]] : i32
795795
// CHECK: %[[S11:.+]] = llvm.udiv %[[S9]], %[[S2]] : i32
796796
// CHECK: %[[S12:.+]] = llvm.urem %[[S9]], %[[S2]] : i32
797797
// CHECK: %[[S13:.+]] = llvm.mul %[[S12]], %[[S5]] : i32
@@ -856,20 +856,22 @@ func.func @warpgroup_mma_store(
856856

857857
// Pattern continues similarly 28x times until {... d62, d63}
858858

859+
// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
860+
// CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
861+
859862
// ### Store {d64, d65} of each thread ###
860863

864+
// CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32
865+
// CHECK: %[[S312:.+]] = llvm.mlir.constant(2 : i32) : i32
861866
// CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32
862-
// CHECK: %[[S312:.+]] = llvm.mlir.constant(32 : i32) : i32
863867
// CHECK: %[[S313:.+]] = llvm.mlir.constant(8 : i32) : i32
864-
// CHECK: %[[S314:.+]] = llvm.mlir.constant(2 : i32) : i32
865-
// CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32
866868
// CHECK: %[[S316:.+]] = llvm.mlir.constant(16 : i32) : i32
867869
// CHECK: %[[S317:.+]] = nvvm.read.ptx.sreg.tid.x : i32
868-
// CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[S312]] : i32
869-
// CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[S312]] : i32
870-
// CHECK: %[[S320:.+]] = llvm.udiv %[[S318]]
871-
// CHECK: %[[S321:.+]] = llvm.urem %[[S318]]
872-
// CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S314]] : i32
870+
// CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[WarpSize]] : i32
871+
// CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[WarpSize]] : i32
872+
// CHECK: %[[S320:.+]] = llvm.udiv %[[S318]], %[[S311]] : i32
873+
// CHECK: %[[S321:.+]] = llvm.urem %[[S318]], %[[S311]] : i32
874+
// CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S312]] : i32
873875
// CHECK: %[[S323:.+]] = llvm.mul %[[S319]], %[[S316]] : i32
874876
// CHECK: %[[S324:.+]] = llvm.add %[[S320]], %[[S323]] : i32
875877
// CHECK: %[[S325:.+]] = llvm.mlir.constant(64 : i32) : i32

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5308,6 +5308,7 @@ cc_library(
53085308
":LLVMCommonConversion",
53095309
":LLVMDialect",
53105310
":MemRefDialect",
5311+
":MLIRArithDialect",
53115312
":NVGPUDialect",
53125313
":NVVMDialect",
53135314
":Pass",

0 commit comments

Comments
 (0)