-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[ROCDL] Add the global.atomic.fadd intrinsic in ROCDL #94486
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-llvm Author: Giuseppe Rossini (giuseros) ChangesThis PR adds the Full diff: https://github.com/llvm/llvm-project/pull/94486.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 1dabf5d7979b7..c8d4e4c03486e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
let summary = "Vote across thread group";
let description = [{
- Ballot provides a bit mask containing the 1-bit predicate value from each lane.
+ Ballot provides a bit mask containing the 1-bit predicate value from each lane.
The nth bit of the result contains the 1 bit contributed by the nth warp lane.
}];
@@ -516,7 +516,7 @@ def ROCDL_RawBufferAtomicCmpSwap :
}
//===---------------------------------------------------------------------===//
-// MI-100 and MI-200 buffer atomic floating point add intrinsic
+// MI-100, MI-200 and MI-300 global/buffer atomic floating point add intrinsic
def ROCDL_RawBufferAtomicFAddOp :
ROCDL_Op<"raw.buffer.atomic.fadd">,
@@ -534,6 +534,19 @@ def ROCDL_RawBufferAtomicFAddOp :
let hasCustomAssemblyFormat = 1;
}
+def ROCDL_GlobalAtomicFAddOp :
+ ROCDL_Op<"global.atomic.fadd">,
+ Arguments<(ins LLVM_Type:$ptr,
+ LLVM_Type:$vdata)>{
+ string llvmBuilder = [{
+ auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
+ auto ptrType = moduleTranslation.convertType(op.getPtr().getType());
+ createIntrinsicCall(builder,
+ llvm::Intrinsic::amdgcn_global_atomic_fadd, {$ptr, $vdata}, {vdataType, ptrType, vdataType});
+ }];
+ let hasCustomAssemblyFormat = 1;
+}
+
//===---------------------------------------------------------------------===//
// Buffer atomic floating point max intrinsic. GFX9 does not support fp32.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 65b770ae32610..34ebdb2ffd3d0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -157,6 +157,26 @@ void RawBufferAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
p << " " << getOperands() << " : " << getVdata().getType();
}
+// <operation> ::=
+// `llvm.amdgcn.global.atomic.fadd.* %vdata, %ptr
+ParseResult GlobalAtomicFAddOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
+ Type type;
+ if (parser.parseOperandList(ops, 2) || parser.parseColonType(type))
+ return failure();
+
+ auto ptrType = LLVM::LLVMPointerType::get(parser.getContext());
+ if (parser.resolveOperands(ops, {ptrType, type}, parser.getNameLoc(),
+ result.operands))
+ return failure();
+ return success();
+}
+
+void GlobalAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
+ p << " " << getOperands() << " : " << getVdata().getType();
+}
+
// <operation> ::=
// `llvm.amdgcn.raw.buffer.atomic.fmax.* %vdata, %rsrc, %offset,
// %soffset, %aux : result_type`
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index ce6b56d48437a..9d22b80748e14 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -494,6 +494,15 @@ llvm.func @rocdl.raw.buffer.atomic.f32(%rsrc : vector<4xi32>,
llvm.return
}
+// CHECK-LABEL: rocdl.global.atomic
+llvm.func @rocdl.global.atomic(%vdata0 : f32, %vdata1 : vector<2xf16>, %ptr : !llvm.ptr) {
+ // CHECK: call float @llvm.amdgcn.global.atomic.fadd.f32.p0.f32(ptr %{{.*}}, float %{{.*}}
+ rocdl.global.atomic.fadd %ptr, %vdata0: f32
+ // CHECK: call <2 x half> @llvm.amdgcn.global.atomic.fadd.v2f16.p0.v2f16(ptr %{{.*}}, <2 x half> %{{.*}})
+ rocdl.global.atomic.fadd %ptr, %vdata1: vector<2xf16>
+ llvm.return
+}
+
llvm.func @rocdl.raw.buffer.atomic.i32(%rsrc : vector<4xi32>,
%offset : i32, %soffset : i32,
%vdata1 : i32) {
|
@llvm/pr-subscribers-mlir Author: Giuseppe Rossini (giuseros) ChangesThis PR adds the Full diff: https://github.com/llvm/llvm-project/pull/94486.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 1dabf5d7979b7..c8d4e4c03486e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
let summary = "Vote across thread group";
let description = [{
- Ballot provides a bit mask containing the 1-bit predicate value from each lane.
+ Ballot provides a bit mask containing the 1-bit predicate value from each lane.
The nth bit of the result contains the 1 bit contributed by the nth warp lane.
}];
@@ -516,7 +516,7 @@ def ROCDL_RawBufferAtomicCmpSwap :
}
//===---------------------------------------------------------------------===//
-// MI-100 and MI-200 buffer atomic floating point add intrinsic
+// MI-100, MI-200 and MI-300 global/buffer atomic floating point add intrinsic
def ROCDL_RawBufferAtomicFAddOp :
ROCDL_Op<"raw.buffer.atomic.fadd">,
@@ -534,6 +534,19 @@ def ROCDL_RawBufferAtomicFAddOp :
let hasCustomAssemblyFormat = 1;
}
+def ROCDL_GlobalAtomicFAddOp :
+ ROCDL_Op<"global.atomic.fadd">,
+ Arguments<(ins LLVM_Type:$ptr,
+ LLVM_Type:$vdata)>{
+ string llvmBuilder = [{
+ auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
+ auto ptrType = moduleTranslation.convertType(op.getPtr().getType());
+ createIntrinsicCall(builder,
+ llvm::Intrinsic::amdgcn_global_atomic_fadd, {$ptr, $vdata}, {vdataType, ptrType, vdataType});
+ }];
+ let hasCustomAssemblyFormat = 1;
+}
+
//===---------------------------------------------------------------------===//
// Buffer atomic floating point max intrinsic. GFX9 does not support fp32.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 65b770ae32610..34ebdb2ffd3d0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -157,6 +157,26 @@ void RawBufferAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
p << " " << getOperands() << " : " << getVdata().getType();
}
+// <operation> ::=
+// `llvm.amdgcn.global.atomic.fadd.* %vdata, %ptr
+ParseResult GlobalAtomicFAddOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
+ Type type;
+ if (parser.parseOperandList(ops, 2) || parser.parseColonType(type))
+ return failure();
+
+ auto ptrType = LLVM::LLVMPointerType::get(parser.getContext());
+ if (parser.resolveOperands(ops, {ptrType, type}, parser.getNameLoc(),
+ result.operands))
+ return failure();
+ return success();
+}
+
+void GlobalAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
+ p << " " << getOperands() << " : " << getVdata().getType();
+}
+
// <operation> ::=
// `llvm.amdgcn.raw.buffer.atomic.fmax.* %vdata, %rsrc, %offset,
// %soffset, %aux : result_type`
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index ce6b56d48437a..9d22b80748e14 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -494,6 +494,15 @@ llvm.func @rocdl.raw.buffer.atomic.f32(%rsrc : vector<4xi32>,
llvm.return
}
+// CHECK-LABEL: rocdl.global.atomic
+llvm.func @rocdl.global.atomic(%vdata0 : f32, %vdata1 : vector<2xf16>, %ptr : !llvm.ptr) {
+ // CHECK: call float @llvm.amdgcn.global.atomic.fadd.f32.p0.f32(ptr %{{.*}}, float %{{.*}}
+ rocdl.global.atomic.fadd %ptr, %vdata0: f32
+ // CHECK: call <2 x half> @llvm.amdgcn.global.atomic.fadd.v2f16.p0.v2f16(ptr %{{.*}}, <2 x half> %{{.*}})
+ rocdl.global.atomic.fadd %ptr, %vdata1: vector<2xf16>
+ llvm.return
+}
+
llvm.func @rocdl.raw.buffer.atomic.i32(%rsrc : vector<4xi32>,
%offset : i32, %soffset : i32,
%vdata1 : i32) {
|
I noticed that the use case I am dealing with needs also the output from the |
Done |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please just use atomicrmw fadd. I will shortly be pushing to remove the intrinsic
Hi @arsenm , the problem is that |
Or maybe it does? |
@giuseros I wonder if it's that MLIR's wrappers around atomicrmw don't support vectors ... which seems like an extension we could do |
atomicrmw FP operations do since 4cb110a. I still need to implement the AMDGPU codegen changes to start using the vector instructions though (plus eventually the new metadata from #85052 will be needed |
Ok, given I am exactly after that vector instruction, how about we merge this PR and then we enable vector support for |
Hi @arsenm , is it ok for this to merge? |
I guess, though I always prefer to just do whatever is needed move towards the end goal instead of adding new throwaway code |
Ok, after a chat with Matthew, we agree on closing this for now and trying to emit the vectorized |
Part 1 to start supporting the vector selection is in #94845 |
This PR adds the
global.atomic.fadd
intrinsic in ROCDL (which supportsf32
andvector<2xf16>
)