Skip to content

Commit 9596e83

Browse files
krzysz00 and manupak authored
[mlir][AMDGPU] Enable emulating vector buffer_atomic_fadd on gfx11 (#108312)
* Fix a bug introduced by the Chipset refactoring in #107720 where atomics emulation for adds was mistakenly applied to gfx11+
* Add the case needed for gfx11+ atomic emulation, namely that gfx11 doesn't support atomically adding a v2f16 or v2bf16, thus requiring MLIR-level legalization for buffer intrinsics that attempt to do such an addition
* Add tests, including tests for gfx11 atomic emulation

Co-authored-by: Manupa Karunaratne <[email protected]>
1 parent 90a0be9 commit 9596e83

File tree

7 files changed

+68
-6
lines changed

7 files changed

+68
-6
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def AMDGPU_RawBufferAtomicCmpswapOp :
214214
AttrSizedOperandSegments,
215215
AllTypesMatch<["src", "cmp", "value"]>,
216216
AllElementTypesMatch<["value", "memref"]>]>,
217-
Arguments<(ins AnyTypeOf<[I32, I64, F32, F64]>:$src,
217+
Arguments<(ins AnyType:$src,
218218
AnyType:$cmp,
219219
Arg<AnyMemRef, "buffer to operate on", [MemRead, MemWrite]>:$memref,
220220
Variadic<I32>:$indices,

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
2424
let dependentDialects = [
2525
"cf::ControlFlowDialect",
2626
"arith::ArithDialect",
27+
"vector::VectorDialect"
2728
];
2829
let options = [Option<"chipset", "chipset", "std::string",
2930
/*default=*/"\"gfx000\"",

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,6 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
108108
if (wantedVecType.getElementType().isBF16())
109109
llvmBufferValType = wantedVecType.clone(rewriter.getI16Type());
110110
if (atomicCmpData) {
111-
if (isa<VectorType>(wantedDataType))
112-
return gpuOp.emitOpError("vector compare-and-swap does not exist");
113111
if (auto floatType = dyn_cast<FloatType>(wantedDataType))
114112
llvmBufferValType = this->getTypeConverter()->convertType(
115113
rewriter.getIntegerType(floatType.getWidth()));

mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
1111
MLIRAMDGPUDialect
1212
MLIRAMDGPUUtils
1313
MLIRArithDialect
14+
MLIRVectorDialect
1415
MLIRControlFlowDialect
1516
MLIRFuncDialect
1617
MLIRIR

mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
#include "mlir/Dialect/Arith/IR/Arith.h"
1414
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
1515
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1617
#include "mlir/IR/BuiltinAttributes.h"
18+
#include "mlir/IR/TypeUtilities.h"
1719
#include "mlir/Transforms/DialectConversion.h"
1820

1921
namespace mlir::amdgpu {
@@ -86,6 +88,23 @@ static void patchOperandSegmentSizes(ArrayRef<NamedAttribute> attrs,
8688
}
8789
}
8890

91+
// A helper function to flatten a vector value to a scalar containing its bits,
92+
// returning the value itself otherwise.
93+
static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
94+
Value val) {
95+
auto vectorType = dyn_cast<VectorType>(val.getType());
96+
if (!vectorType)
97+
return val;
98+
99+
int64_t bitwidth =
100+
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
101+
Type allBitsType = rewriter.getIntegerType(bitwidth);
102+
auto allBitsVecType = VectorType::get({1}, allBitsType);
103+
Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val);
104+
Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0);
105+
return scalar;
106+
}
107+
89108
template <typename AtomicOp, typename ArithOp>
90109
LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
91110
AtomicOp atomicOp, Adaptor adaptor,
@@ -113,6 +132,7 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
113132
rewriter.setInsertionPointToEnd(loopBlock);
114133
Value prevLoad = loopBlock->getArgument(0);
115134
Value operated = rewriter.create<ArithOp>(loc, data, prevLoad);
135+
dataType = operated.getType();
116136

117137
SmallVector<NamedAttribute> cmpswapAttrs;
118138
patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
@@ -126,8 +146,8 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
126146
// an int->float bitcast is introduced to account for the fact that cmpswap
127147
// only takes integer arguments.
128148

129-
Value prevLoadForCompare = prevLoad;
130-
Value atomicResForCompare = atomicRes;
149+
Value prevLoadForCompare = flattenVecToBits(rewriter, loc, prevLoad);
150+
Value atomicResForCompare = flattenVecToBits(rewriter, loc, atomicRes);
131151
if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
132152
Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
133153
prevLoadForCompare =
@@ -146,9 +166,17 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
146166
void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
147167
ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) {
148168
// gfx10 has no atomic adds.
149-
if (chipset >= Chipset(10, 0, 0) || chipset < Chipset(9, 0, 8)) {
169+
if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) {
150170
target.addIllegalOp<RawBufferAtomicFaddOp>();
151171
}
172+
// gfx11 has no f16 or bf16 atomic adds
173+
if (chipset.majorVersion == 11) {
174+
target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
175+
[](RawBufferAtomicFaddOp op) -> bool {
176+
Type elemType = getElementTypeOrSelf(op.getValue().getType());
177+
return !isa<Float16Type, BFloat16Type>(elemType);
178+
});
179+
}
152180
// gfx9 has no, or only very limited, support for floating-point min and max.
153181
if (chipset.majorVersion == 9) {
154182
if (chipset >= Chipset(9, 0, 0xa) && chipset != Chipset(9, 4, 1)) {

mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,18 @@ func.func @amdgpu_raw_buffer_atomic_cmpswap_i64(%src : i64, %cmp : i64, %buf : m
224224
func.return %dst : i64
225225
}
226226

227+
// CHECK-LABEL: func @amdgpu_raw_buffer_atomic_cmpswap_v2f16
228+
// CHECK-SAME: (%[[src:.*]]: vector<2xf16>, %[[cmp:.*]]: vector<2xf16>, {{.*}})
229+
func.func @amdgpu_raw_buffer_atomic_cmpswap_v2f16(%src : vector<2xf16>, %cmp : vector<2xf16>, %buf : memref<64xf16>, %idx: i32) -> vector<2xf16> {
230+
// CHECK-DAG: %[[srcBits:.+]] = llvm.bitcast %[[src]] : vector<2xf16> to i32
231+
// CHECK-DAG: %[[cmpBits:.+]] = llvm.bitcast %[[cmp]] : vector<2xf16> to i32
232+
// CHECK: %[[dstBits:.+]] = rocdl.raw.ptr.buffer.atomic.cmpswap %[[srcBits]], %[[cmpBits]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32
233+
// CHECK: %[[dst:.+]] = llvm.bitcast %[[dstBits]] : i32 to vector<2xf16>
234+
// CHECK: return %[[dst]]
235+
%dst = amdgpu.raw_buffer_atomic_cmpswap {boundsCheck = true} %src, %cmp -> %buf[%idx] : vector<2xf16> -> memref<64xf16>, i32
236+
func.return %dst : vector<2xf16>
237+
}
238+
227239
// CHECK-LABEL: func @lds_barrier
228240
func.func @lds_barrier() {
229241
// GFX908: llvm.inline_asm has_side_effects asm_dialect = att

mlir/test/Dialect/AMDGPU/amdgpu-emulate-atomics.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx90a %s | FileCheck %s --check-prefixes=CHECK,GFX9
22
// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx1030 %s | FileCheck %s --check-prefixes=CHECK,GFX10
3+
// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx1100 %s | FileCheck %s --check-prefixes=CHECK,GFX11
34

45
// -----
56

@@ -8,6 +9,7 @@ func.func @atomic_fmax(%val: f32, %buffer: memref<?xf32>, %idx: i32) {
89
// CHECK-SAME: ([[val:%.+]]: f32, [[buffer:%.+]]: memref<?xf32>, [[idx:%.+]]: i32)
910
// CHECK: gpu.printf "Begin\0A"
1011
// GFX10: amdgpu.raw_buffer_atomic_fmax {foo, indexOffset = 4 : i32} [[val]] -> [[buffer]][[[idx]]]
12+
// GFX11: amdgpu.raw_buffer_atomic_fmax {foo, indexOffset = 4 : i32} [[val]] -> [[buffer]][[[idx]]]
1113
// GFX9: [[ld:%.+]] = amdgpu.raw_buffer_load {foo, indexOffset = 4 : i32} [[buffer]][[[idx]]]
1214
// GFX9: cf.br [[loop:\^.+]]([[ld]] : f32)
1315
// GFX9: [[loop]]([[arg:%.+]]: f32):
@@ -33,6 +35,7 @@ func.func @atomic_fmax_f64(%val: f64, %buffer: memref<?xf64>, %idx: i32) {
3335
// CHECK: gpu.printf "Begin\0A"
3436
// GFX9: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
3537
// GFX10: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
38+
// GFX11: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
3639
// CHECK-NEXT: gpu.printf "End\0A"
3740
gpu.printf "Begin\n"
3841
amdgpu.raw_buffer_atomic_fmax %val -> %buffer[%idx] : f64 -> memref<?xf64>, i32
@@ -47,6 +50,25 @@ func.func @atomic_fadd(%val: f32, %buffer: memref<?xf32>, %idx: i32) {
4750
// GFX9: amdgpu.raw_buffer_atomic_fadd
4851
// GFX10: amdgpu.raw_buffer_load
4952
// GFX10: amdgpu.raw_buffer_atomic_cmpswap
53+
// GFX11: amdgpu.raw_buffer_atomic_fadd
5054
amdgpu.raw_buffer_atomic_fadd %val -> %buffer[%idx] : f32 -> memref<?xf32>, i32
5155
func.return
5256
}
57+
58+
// CHECK: func @atomic_fadd_v2f16
59+
func.func @atomic_fadd_v2f16(%val: vector<2xf16>, %buffer: memref<?xf16>, %idx: i32) {
60+
// GFX9: amdgpu.raw_buffer_atomic_fadd
61+
// GFX10: amdgpu.raw_buffer_load
62+
// GFX10: amdgpu.raw_buffer_atomic_cmpswap
63+
// Note: the atomic operation itself will be done over i32, and then we use bitcasts
64+
// to scalars in order to test for exact bitwise equality instead of float
65+
// equality.
66+
// GFX11: %[[old:.+]] = amdgpu.raw_buffer_atomic_cmpswap
67+
// GFX11: %[[vecCastExpected:.+]] = vector.bitcast %{{.*}} : vector<2xf16> to vector<1xi32>
68+
// GFX11: %[[scalarExpected:.+]] = vector.extract %[[vecCastExpected]][0]
69+
// GFX11: %[[vecCastOld:.+]] = vector.bitcast %[[old]] : vector<2xf16> to vector<1xi32>
70+
// GFX11: %[[scalarOld:.+]] = vector.extract %[[vecCastOld]][0]
71+
// GFX11: arith.cmpi eq, %[[scalarOld]], %[[scalarExpected]]
72+
amdgpu.raw_buffer_atomic_fadd %val -> %buffer[%idx] : vector<2xf16> -> memref<?xf16>, i32
73+
func.return
74+
}

0 commit comments

Comments
 (0)