Skip to content

[mlir][AMDGPU] Enable emulating vector buffer_atomic_fadd on gfx11 #108312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def AMDGPU_RawBufferAtomicCmpswapOp :
AttrSizedOperandSegments,
AllTypesMatch<["src", "cmp", "value"]>,
AllElementTypesMatch<["value", "memref"]>]>,
Arguments<(ins AnyTypeOf<[I32, I64, F32, F64]>:$src,
Arguments<(ins AnyType:$src,
AnyType:$cmp,
Arg<AnyMemRef, "buffer to operate on", [MemRead, MemWrite]>:$memref,
Variadic<I32>:$indices,
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
let dependentDialects = [
"cf::ControlFlowDialect",
"arith::ArithDialect",
"vector::VectorDialect"
];
let options = [Option<"chipset", "chipset", "std::string",
/*default=*/"\"gfx000\"",
Expand Down
2 changes: 0 additions & 2 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
if (wantedVecType.getElementType().isBF16())
llvmBufferValType = wantedVecType.clone(rewriter.getI16Type());
if (atomicCmpData) {
if (isa<VectorType>(wantedDataType))
return gpuOp.emitOpError("vector compare-and-swap does not exist");
if (auto floatType = dyn_cast<FloatType>(wantedDataType))
llvmBufferValType = this->getTypeConverter()->convertType(
rewriter.getIntegerType(floatType.getWidth()));
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
MLIRAMDGPUDialect
MLIRAMDGPUUtils
MLIRArithDialect
MLIRVectorDialect
MLIRControlFlowDialect
MLIRFuncDialect
MLIRIR
Expand Down
34 changes: 31 additions & 3 deletions mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::amdgpu {
Expand Down Expand Up @@ -86,6 +88,23 @@ static void patchOperandSegmentSizes(ArrayRef<NamedAttribute> attrs,
}
}

// A helper function to flatten a vector value to a scalar containing its bits,
// returning the value itself if othetwise.
static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
Value val) {
auto vectorType = dyn_cast<VectorType>(val.getType());
if (!vectorType)
return val;

int64_t bitwidth =
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
Type allBitsType = rewriter.getIntegerType(bitwidth);
auto allBitsVecType = VectorType::get({1}, allBitsType);
Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val);
Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0);
return scalar;
}

template <typename AtomicOp, typename ArithOp>
LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
AtomicOp atomicOp, Adaptor adaptor,
Expand Down Expand Up @@ -113,6 +132,7 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
rewriter.setInsertionPointToEnd(loopBlock);
Value prevLoad = loopBlock->getArgument(0);
Value operated = rewriter.create<ArithOp>(loc, data, prevLoad);
dataType = operated.getType();

SmallVector<NamedAttribute> cmpswapAttrs;
patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
Expand All @@ -126,8 +146,8 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
// an int->float bitcast is introduced to account for the fact that cmpswap
// only takes integer arguments.

Value prevLoadForCompare = prevLoad;
Value atomicResForCompare = atomicRes;
Value prevLoadForCompare = flattenVecToBits(rewriter, loc, prevLoad);
Value atomicResForCompare = flattenVecToBits(rewriter, loc, atomicRes);
if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
prevLoadForCompare =
Expand All @@ -146,9 +166,17 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) {
// gfx10 has no atomic adds.
if (chipset >= Chipset(10, 0, 0) || chipset < Chipset(9, 0, 8)) {
if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) {
target.addIllegalOp<RawBufferAtomicFaddOp>();
}
// gfx11 has no fp16 atomics
if (chipset.majorVersion == 11) {
target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
[](RawBufferAtomicFaddOp op) -> bool {
Type elemType = getElementTypeOrSelf(op.getValue().getType());
return !isa<Float16Type, BFloat16Type>(elemType);
});
}
// gfx9 has no to a very limited support for floating-point min and max.
if (chipset.majorVersion == 9) {
if (chipset >= Chipset(9, 0, 0xa) && chipset != Chipset(9, 4, 1)) {
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,18 @@ func.func @amdgpu_raw_buffer_atomic_cmpswap_i64(%src : i64, %cmp : i64, %buf : m
func.return %dst : i64
}

// CHECK-LABEL: func @amdgpu_raw_buffer_atomic_cmpswap_v2f16
// CHECK-SAME: (%[[src:.*]]: vector<2xf16>, %[[cmp:.*]]: vector<2xf16>, {{.*}})
func.func @amdgpu_raw_buffer_atomic_cmpswap_v2f16(%src : vector<2xf16>, %cmp : vector<2xf16>, %buf : memref<64xf16>, %idx: i32) -> vector<2xf16> {
// CHECK-DAG: %[[srcBits:.+]] = llvm.bitcast %[[src]] : vector<2xf16> to i32
// CHECK-DAG: %[[cmpBits:.+]] = llvm.bitcast %[[cmp]] : vector<2xf16> to i32
// CHECK: %[[dstBits:.+]] = rocdl.raw.ptr.buffer.atomic.cmpswap %[[srcBits]], %[[cmpBits]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32
// CHECK: %[[dst:.+]] = llvm.bitcast %[[dstBits]] : i32 to vector<2xf16>
// CHECK: return %[[dst]]
%dst = amdgpu.raw_buffer_atomic_cmpswap {boundsCheck = true} %src, %cmp -> %buf[%idx] : vector<2xf16> -> memref<64xf16>, i32
func.return %dst : vector<2xf16>
}

// CHECK-LABEL: func @lds_barrier
func.func @lds_barrier() {
// GFX908: llvm.inline_asm has_side_effects asm_dialect = att
Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Dialect/AMDGPU/amdgpu-emulate-atomics.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx90a %s | FileCheck %s --check-prefixes=CHECK,GFX9
// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx1030 %s | FileCheck %s --check-prefixes=CHECK,GFX10
// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx1100 %s | FileCheck %s --check-prefixes=CHECK,GFX11

// -----

Expand All @@ -8,6 +9,7 @@ func.func @atomic_fmax(%val: f32, %buffer: memref<?xf32>, %idx: i32) {
// CHECK-SAME: ([[val:%.+]]: f32, [[buffer:%.+]]: memref<?xf32>, [[idx:%.+]]: i32)
// CHECK: gpu.printf "Begin\0A"
// GFX10: amdgpu.raw_buffer_atomic_fmax {foo, indexOffset = 4 : i32} [[val]] -> [[buffer]][[[idx]]]
// GFX11: amdgpu.raw_buffer_atomic_fmax {foo, indexOffset = 4 : i32} [[val]] -> [[buffer]][[[idx]]]
// GFX9: [[ld:%.+]] = amdgpu.raw_buffer_load {foo, indexOffset = 4 : i32} [[buffer]][[[idx]]]
// GFX9: cf.br [[loop:\^.+]]([[ld]] : f32)
// GFX9: [[loop]]([[arg:%.+]]: f32):
Expand All @@ -33,6 +35,7 @@ func.func @atomic_fmax_f64(%val: f64, %buffer: memref<?xf64>, %idx: i32) {
// CHECK: gpu.printf "Begin\0A"
// GFX9: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
// GFX10: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
// GFX11: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
// CHECK-NEXT: gpu.printf "End\0A"
gpu.printf "Begin\n"
amdgpu.raw_buffer_atomic_fmax %val -> %buffer[%idx] : f64 -> memref<?xf64>, i32
Expand All @@ -47,6 +50,25 @@ func.func @atomic_fadd(%val: f32, %buffer: memref<?xf32>, %idx: i32) {
// GFX9: amdgpu.raw_buffer_atomic_fadd
// GFX10: amdgpu.raw_buffer_load
// GFX10: amdgpu.raw_buffer_atomic_cmpswap
// GFX11: amdgpu.raw_buffer_atomic_fadd
amdgpu.raw_buffer_atomic_fadd %val -> %buffer[%idx] : f32 -> memref<?xf32>, i32
func.return
}

// CHECK: func @atomic_fadd_v2f16
func.func @atomic_fadd_v2f16(%val: vector<2xf16>, %buffer: memref<?xf16>, %idx: i32) {
// GFX9: amdgpu.raw_buffer_atomic_fadd
// GFX10: amdgpu.raw_buffer_load
// GFX10: amdgpu.raw_buffer_atomic_cmpswap
// Note: the atomic operation itself will be done over i32, and then we use bitcasts
// to scalars in order to test for exact bitwise equality instead of float
// equality.
// GFX11: %[[old:.+]] = amdgpu.raw_buffer_atomic_cmpswap
// GFX11: %[[vecCastExpected:.+]] = vector.bitcast %{{.*}} : vector<2xf16> to vector<1xi32>
// GFX11: %[[scalarExpected:.+]] = vector.extract %[[vecCastExpected]][0]
// GFX11: %[[vecCastOld:.+]] = vector.bitcast %[[old]] : vector<2xf16> to vector<1xi32>
// GFX11: %[[scalarOld:.+]] = vector.extract %[[vecCastOld]][0]
// GFX11: arith.cmpi eq, %[[scalarOld]], %[[scalarExpected]]
amdgpu.raw_buffer_atomic_fadd %val -> %buffer[%idx] : vector<2xf16> -> memref<?xf16>, i32
func.return
}
Loading