[MLIR][NVGPU] Introduce nvgpu.wargroup.mma.store Op for Hopper GPUs #65441

Merged: 7 commits, Oct 5, 2023
20 changes: 20 additions & 0 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -728,4 +728,24 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
let hasVerifier = 1;
}

def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
let description = [{
The `nvgpu.warpgroup.mma.store` op stores the fragmented result(s) held
in `$matrixD` to the given memref.

[See the details of register fragment layout for accumulator matrix D]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)

Note that the op must be run by a whole warp group (128 threads).
}];

let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);

let assemblyFormat = [{
`[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
}];
let hasVerifier = 1;
}

#endif // NVGPU
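For reference, here is a minimal sketch of the op's custom assembly as defined by the `assemblyFormat` above (operand names are illustrative; the complete form appears in the test further down):

    nvgpu.warpgroup.mma.store [%d1, %d2], %dstMemref :
        !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
        !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
        to memref<128x128xf32, 3>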
1 change: 1 addition & 0 deletions mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRNVGPUToNVVM
MLIRLLVMDialect
MLIRNVGPUDialect
MLIRNVVMDialect
MLIRArithDialect
MLIRPass
MLIRSCFTransforms
MLIRTransforms
117 changes: 115 additions & 2 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -11,6 +11,7 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -394,8 +395,8 @@ struct ConvertNVGPUToNVVMPass
using Base::Base;

void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect>();
registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
arith::ArithDialect>();
}

void runOnOperation() override {
@@ -436,6 +437,7 @@ struct ConvertNVGPUToNVVMPass
populateNVGPUToNVVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::arith::ArithDialect>();
target.addLegalDialect<::mlir::memref::MemRefDialect>();
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
@@ -1434,6 +1436,116 @@ struct NVGPUWarpgroupMmaOpLowering
}
};

struct NVGPUWarpgroupMmaStoreOpLowering
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
using ConvertOpToLLVMPattern<
nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;

/// This function stores a fragmented register matrix owned by a warp group
/// (128 threads) into a memref. Each thread holds 64 registers, packed into
/// a struct.
/// Here is what each thread (T) holds; each `d` is a struct element with an
/// index.
///
/// Threads in the warp group (128 threads) and what they own in matrix D:
/// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
/// 32-63 Warp-1 -> MatrixD[16:31][0:N]
/// 64-95 Warp-2 -> MatrixD[32:47][0:N]
/// 96-127 Warp-3 -> MatrixD[48:63][0:N]
///
/// Matrix-D:
/// +______________________________________________________________________+
/// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
/// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
/// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
/// ..| .........|.........|.........|.........|........|...........|........|
/// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
/// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
/// ..| .........|.........|.........|.........|........|...........|........|
/// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
/// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
/// ..| .........|.........|.........|.........|........|...........|........|
/// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
/// ..| .........|.........|.........|.........|........|...........|........|
/// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
/// ..| .........|.........|.........|.........|........|...........|........|
/// +______________________________________________________________________+
///
/// \param b: The builder, with the op's location pre-set.
/// \param matrixD: Result of the warp-group MMA operation (fragmented
/// matrix). Each thread holds a piece of it as a struct with 64 elements.
/// \param dstMemref: The memref where the registers will be stored.
/// \param offset: The row offset within the memref where the registers will
/// be stored.
void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
TypedValue<MemRefType> dstMemref,
int offset) const {
Type i32 = b.getI32Type();

auto makeConst = [&](int32_t index) -> Value {
return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
};
Value c1 = makeConst(1);
Value c2 = makeConst(2);
Value c4 = makeConst(4);
Value c8 = makeConst(8);
Value c16 = makeConst(16);
Value warpSize = makeConst(kWarpSize);

auto makeMul = [&](Value lhs, Value rhs) -> Value {
return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
};
auto makeAdd = [&](Value lhs, Value rhs) -> Value {
return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
};

Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);

auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
TypedValue<::mlir::MemRefType> memref) {
Type it = b.getIndexType();
Value idx = b.create<arith::IndexCastOp>(it, x);
Value idy0 = b.create<arith::IndexCastOp>(it, y);
Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
};

Value tj = makeMul(lane4modId, c2);
Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
if (offset)
ti = makeAdd(ti, makeConst(offset));
for (int i = 0; i < 2; ++i) {
Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
for (int j = 0; j < 16; ++j) {
Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
int sIndex = i * 2 + j * 4;
makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
}
}
}

LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
int offset = 0;
ImplicitLocOpBuilder lb(op->getLoc(), rewriter);
for (Value matrixD : adaptor.getMatrixD()) {
auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
storeFragmentedMatrix(lb, matrixD, op.getDstMemref(), offset);
offset += structType.getBody().size();
}
rewriter.eraseOp(op);
return success();
}
};

} // namespace

void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1450,6 +1562,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);
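To make the index arithmetic above concrete, here is a standalone C++ sketch (not part of the patch) that mirrors the computation in `storeFragmentedMatrix` and prints, for an assumed thread id, the destination coordinates of every struct element of one 64x128 f32 accumulator:

    #include <cstdio>

    int main() {
      const int kWarpSize = 32;
      int tid = 0;    // assumed thread id within the warp group [0, 128)
      int offset = 0; // row offset of the current accumulator fragment
      int laneId = tid % kWarpSize;
      int warpId = tid / kWarpSize;
      int lane4Id = laneId / 4;
      int lane4modId = laneId % 4;
      int tj = lane4modId * 2;                 // starting column
      int ti = lane4Id + warpId * 16 + offset; // starting row
      for (int i = 0; i < 2; ++i) {
        int row = ti + i * 8;
        for (int j = 0; j < 16; ++j) {         // 16 column tiles for N = 128
          int col = tj + j * 8;
          int s = i * 2 + j * 4;               // struct element index
          printf("T%d: d%d -> [%d][%d], d%d -> [%d][%d]\n", tid, s, row, col,
                 s + 1, row, col + 1);
        }
      }
      return 0;
    }

The lowering emits this arithmetic as LLVM and arith ops; the test below checks the generated IR after running the NVGPU-to-NVVM conversion through mlir-opt.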
34 changes: 34 additions & 0 deletions mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -12,6 +12,7 @@

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -529,6 +530,39 @@ LogicalResult WarpgroupMmaOp::verify() {
return success();
}

LogicalResult WarpgroupMmaStoreOp::verify() {
MemRefType dstMemrefType = getDstMemref().getType();
VectorType firstVtype = getMatrixD()
.front()
.getType()
.cast<WarpgroupAccumulatorType>()
.getFragmented();

int64_t totalFirstDimension = 0;
for (Value result : getMatrixD()) {
VectorType vtype =
result.getType().cast<WarpgroupAccumulatorType>().getFragmented();
if (vtype != firstVtype)
return emitOpError() << "all fragmented types must be the same";
// Limitation: only f32 accumulators are supported for now.
if (!vtype.getElementType().isF32()) {
return emitOpError()
<< "hit a limitation: only f32 results for the time being";
}
totalFirstDimension += vtype.getDimSize(0);
}
if (totalFirstDimension != dstMemrefType.getDimSize(0) ||
firstVtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
return emitOpError() << "results have shape [" << totalFirstDimension
<< "][" << firstVtype.getDimSize(1)
<< "], but the destination memref has shape ["
<< dstMemrefType.getDimSize(0) << "]["
<< dstMemrefType.getDimSize(1) << "]";
}
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//
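As a concrete instance of the shape rule above: the verifier concatenates the fragments along dimension 0, so storing two vector<64x128xf32> fragments requires a memref<128x128xf32> destination. A hypothetical mismatched destination would be rejected:

    // Invalid (hypothetical): 64 + 64 = 128 result rows, but the memref has 64.
    nvgpu.warpgroup.mma.store [%d1, %d2], %smallMemref :
        !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
        !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
        to memref<64x128xf32, 3>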
129 changes: 129 additions & 0 deletions mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -772,6 +772,135 @@ func.func @warpgroup_mma_128_128_64(
return
}

// CHECK-LABEL: @warpgroup_mma_store(
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
func.func @warpgroup_mma_store(
%result1 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
%result2 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
%matrixD: memref<128x128xf32,3>) {
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: %[[S7:.+]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: %[[WarpSize:.+]] = llvm.mlir.constant(32 : i32) : i32

// ### Store {d0, d1} of each thread ###

// CHECK: %[[S8:.+]] = nvvm.read.ptx.sreg.tid.x : i32
// CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[WarpSize]] : i32
// CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[WarpSize]] : i32
// CHECK: %[[S11:.+]] = llvm.udiv %[[S9]], %[[S2]] : i32
// CHECK: %[[S12:.+]] = llvm.urem %[[S9]], %[[S2]] : i32
// CHECK: %[[S13:.+]] = llvm.mul %[[S12]], %[[S5]] : i32
// CHECK: %[[S14:.+]] = llvm.mul %[[S10]], %[[S7]] : i32
// CHECK: %[[S15:.+]] = llvm.add %[[S11]], %[[S14]] : i32
// CHECK: %[[S16:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[S17:.+]] = llvm.mul %[[S16]], %[[S4]] : i32
// CHECK: %[[S18:.+]] = llvm.add %[[S15]], %[[S17]] : i32
// CHECK: %[[S19:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[S20:.+]] = llvm.mul %[[S19]], %[[S4]] : i32
// CHECK: %[[S21:.+]] = llvm.add %[[S13]], %[[S20]] : i32
// CHECK: %[[S22:.+]] = arith.index_cast %[[S18]] : i32 to index
// CHECK: %[[S23:.+]] = arith.index_cast %[[S21]] : i32 to index
// CHECK: %[[S24:.+]] = llvm.add %[[S21]], %[[S6]] : i32
// CHECK: %[[S25:.+]] = arith.index_cast %[[S24]] : i32 to index
// CHECK: %[[S26:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct
// CHECK: %[[S27:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct
// CHECK: memref.store %[[S26]], %[[arg2]][%[[S22]], %[[S23]]] : memref<128x128xf32, 3>
// CHECK: memref.store %[[S27]], %[[arg2]][%[[S22]], %[[S25]]] : memref<128x128xf32, 3>

// ### Store {d2, d3} of each thread ###

// CHECK: %[[S28:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[S29:.+]] = llvm.mul %[[S28]], %[[S4]] : i32
// CHECK: %[[S30:.+]] = llvm.add %[[S13]], %[[S29]] : i32
// CHECK: %[[S31:.+]] = arith.index_cast %[[S18]] : i32 to index
// CHECK: %[[S32:.+]] = arith.index_cast %[[S30]] : i32 to index
// CHECK: %[[S33:.+]] = llvm.add %[[S30]], %[[S6]] : i32
// CHECK: %[[S34:.+]] = arith.index_cast %[[S33]] : i32 to index
// CHECK: %[[S35:.+]] = llvm.extractvalue %[[S0]][4] : !llvm.struct<
// CHECK: %[[S36:.+]] = llvm.extractvalue %[[S0]][5] : !llvm.struct<
// CHECK: memref.store %[[S35]], %[[arg2]][%[[S31]], %[[S32]]] : memref<128x128xf32, 3>
// CHECK: memref.store %[[S36]], %[[arg2]][%[[S31]], %[[S34]]] : memref<128x128xf32, 3>

// ### Store {d4, d5} of each thread ###

// CHECK: %[[S37:.+]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[S38:.+]] = llvm.mul %[[S37]], %[[S4]] : i32
// CHECK: %[[S39:.+]] = llvm.add %[[S13]], %[[S38]] : i32
// CHECK: %[[S40:.+]] = arith.index_cast %[[S18]] : i32 to index
// CHECK: %[[S41:.+]] = arith.index_cast %[[S39]] : i32 to index
// CHECK: %[[S42:.+]] = llvm.add %[[S39]], %[[S6]] : i32
// CHECK: %[[S43:.+]] = arith.index_cast %[[S42]] : i32 to index
// CHECK: %[[S44:.+]] = llvm.extractvalue %[[S0]][8] : !llvm.struct<
// CHECK: %[[S45:.+]] = llvm.extractvalue %[[S0]][9] : !llvm.struct<
// CHECK: memref.store %[[S44]], %[[arg2]][%[[S40]], %[[S41]]] : memref<128x128xf32, 3>
// CHECK: memref.store %[[S45]], %[[arg2]][%[[S40]], %[[S43]]] : memref<128x128xf32, 3>

// ### Store {d6, d7} of each thread ###

// CHECK: %[[S46:.+]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: %[[S47:.+]] = llvm.mul %[[S46]], %[[S4]] : i32
// CHECK: %[[S48:.+]] = llvm.add %[[S13]], %[[S47]] : i32
// CHECK: %[[S49:.+]] = arith.index_cast %[[S18]] : i32 to index
// CHECK: %[[S50:.+]] = arith.index_cast %[[S48]] : i32 to index
// CHECK: %[[S51:.+]] = llvm.add %[[S48]], %[[S6]] : i32
// CHECK: %[[S52:.+]] = arith.index_cast %[[S51]] : i32 to index
// CHECK: %[[S53:.+]] = llvm.extractvalue %[[S0]][12] : !llvm.struct<
// CHECK: %[[S54:.+]] = llvm.extractvalue %[[S0]][13] : !llvm.struct<
// CHECK: memref.store %[[S53]], %[[arg2]][%[[S49]], %[[S50]]] : memref<128x128xf32, 3>
// CHECK: memref.store %[[S54]], %[[arg2]][%[[S49]], %[[S52]]] : memref<128x128xf32, 3>

// The pattern repeats 28 more times, up to {d62, d63}

// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32

// ### Store {d64, d65} of each thread ###

// CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[S312:.+]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: %[[S313:.+]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: %[[S316:.+]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: %[[WS2:.+]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: %[[S317:.+]] = nvvm.read.ptx.sreg.tid.x : i32
// CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[WS2]] : i32
// CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[WS2]] : i32
// CHECK: %[[S320:.+]] = llvm.udiv %[[S318]], %[[S311]] : i32
// CHECK: %[[S321:.+]] = llvm.urem %[[S318]], %[[S311]] : i32
// CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S312]] : i32
// CHECK: %[[S323:.+]] = llvm.mul %[[S319]], %[[S316]] : i32
// CHECK: %[[S324:.+]] = llvm.add %[[S320]], %[[S323]] : i32
// CHECK: %[[S325:.+]] = llvm.mlir.constant(64 : i32) : i32
// CHECK: %[[S326:.+]] = llvm.add %[[S324]], %[[S325]] : i32
// CHECK: %[[S327:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[S328:.+]] = llvm.mul %[[S327]], %[[S313]] : i32
// CHECK: %[[S329:.+]] = llvm.add %[[S326]], %[[S328]] : i32
// CHECK: %[[S330:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[S331:.+]] = llvm.mul %[[S330]], %[[S313]] : i32
// CHECK: %[[S332:.+]] = llvm.add %[[S322]], %[[S331]] : i32
// CHECK: %[[S333:.+]] = arith.index_cast %[[S329]] : i32 to index
// CHECK: %[[S334:.+]] = arith.index_cast %[[S332]] : i32 to index
// CHECK: %[[S335:.+]] = llvm.add %[[S332]], %[[S315]] : i32
// CHECK: %[[S336:.+]] = arith.index_cast %[[S335]] : i32 to index
// CHECK: %[[S337:.+]] = llvm.extractvalue %[[S1]][0]
// CHECK: %[[S338:.+]] = llvm.extractvalue %[[S1]][1]
// CHECK: memref.store %[[S337]], %[[arg2]][%[[S333]], %[[S334]]] : memref<128x128xf32, 3>
// CHECK: memref.store %[[S338]], %[[arg2]][%[[S333]], %[[S336]]] : memref<128x128xf32, 3>

// The pattern repeats 31 more times, up to {d126, d127}

nvgpu.warpgroup.mma.store [%result1, %result2], %matrixD :
!nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
to memref<128x128xf32,3>
return
}

transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1
1 change: 1 addition & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5308,6 +5308,7 @@ cc_library(
":LLVMCommonConversion",
":LLVMDialect",
":MemRefDialect",
":MLIRArithDialect",
":NVGPUDialect",
":NVVMDialect",
":Pass",