Skip to content

Commit 96da274

Browse files
committed
address @qcolombet comments
1 parent 1d9cfc8 commit 96da274

File tree

4 files changed

+7
-26
lines changed

4 files changed

+7
-26
lines changed

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
731731
def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
732732
let description = [{
733733
The `nvgpu.warpgroup.mma.store` op performs the store of fragmented result
734-
in $matrixD to give memref.
734+
in $matrixD to given memref.
735735

736736
[See the details of register fragment layout for accumulator matrix D]
737737
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -54,26 +54,6 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
5454
return b.create<LLVM::TruncOp>(b.getI32Type(), value);
5555
}
5656

57-
/// Returns warp-size as a value.
58-
static Value getWarpSizeValue(ImplicitLocOpBuilder &b) {
59-
static std::optional<Value> warpSize = std::nullopt;
60-
if (!warpSize.has_value()) {
61-
warpSize = b.create<LLVM::ConstantOp>(IntegerType::get(b.getContext(), 32),
62-
b.getI32IntegerAttr(kWarpSize));
63-
}
64-
return warpSize.value();
65-
}
66-
67-
/// Returns warp-size as a value.
68-
static Value getWarpSizeValue(ImplicitLocOpBuilder &b) {
69-
static std::optional<Value> warpSize = std::nullopt;
70-
if (!warpSize.has_value()) {
71-
warpSize = b.create<LLVM::ConstantOp>(IntegerType::get(b.getContext(), 32),
72-
b.getI32IntegerAttr(kWarpSize));
73-
}
74-
return warpSize.value();
75-
}
76-
7757
/// Returns the type for the intrinsic given the vectorResultType of the
7858
/// `gpu.mma.sync` operation.
7959
static Type inferIntrinsicResultType(Type vectorResultType) {
@@ -1467,7 +1447,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
14671447
/// Here is what each threads (T) holds, each `d` is struct value with a
14681448
/// number.
14691449
///
1470-
/// Threads in warp-group (128 threads) and what they owns in the matriD:
1450+
/// Threads in warp-group (128 threads) and what they owns in the matrixD:
14711451
/// 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
14721452
/// 32-63 Warp-1 -> MatrixD[16:31][0:N]
14731453
/// 64-95 Warp-2 -> MatrixD[32:47][0:N]
@@ -1510,7 +1490,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
15101490
Value c4 = makeConst(4);
15111491
Value c8 = makeConst(8);
15121492
Value c16 = makeConst(16);
1513-
Value warpSize = getWarpSizeValue(b);
1493+
Value warpSize = makeConst(kWarpSize);
15141494

15151495
auto makeMul = [&](Value lhs, Value rhs) -> Value {
15161496
return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ LogicalResult WarpgroupMmaStoreOp::verify() {
539539
.getFragmented();
540540

541541
int64_t totalFirstDimension = 0;
542-
for (auto result : getMatrixD()) {
542+
for (Value result : getMatrixD()) {
543543
VectorType vtype =
544544
result.getType().cast<WarpgroupAccumulatorType>().getFragmented();
545545
if (vtype != firstVtype)

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -866,9 +866,10 @@ func.func @warpgroup_mma_store(
866866
// CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32
867867
// CHECK: %[[S313:.+]] = llvm.mlir.constant(8 : i32) : i32
868868
// CHECK: %[[S316:.+]] = llvm.mlir.constant(16 : i32) : i32
869+
// CHECK: %[[WS2:.+]] = llvm.mlir.constant(32 : i32) : i32
869870
// CHECK: %[[S317:.+]] = nvvm.read.ptx.sreg.tid.x : i32
870-
// CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[WarpSize]] : i32
871-
// CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[WarpSize]] : i32
871+
// CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[WS2]] : i32
872+
// CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[WS2]] : i32
872873
// CHECK: %[[S320:.+]] = llvm.udiv %[[S318]], %[[S311]] : i32
873874
// CHECK: %[[S321:.+]] = llvm.urem %[[S318]], %[[S311]] : i32
874875
// CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S312]] : i32

0 commit comments

Comments
 (0)