-
Notifications
You must be signed in to change notification settings - Fork 13.5k
Lower shuffle to single-result form if possible. #84321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Johannes Reifferscheid (jreiffers) ChangesWe currently always lower shuffle to the struct-returning variant. I saw some cases where this survived all the way through ptx, resulting in increased register usage. The easiest fix is to simply lower to the single-result version when the predicate is unused. Full diff: https://github.com/llvm/llvm-project/pull/84321.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index d6a5d8cd74d5f2..993c6822ac74e4 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -155,8 +155,6 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
auto valueTy = adaptor.getValue().getType();
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
auto predTy = IntegerType::get(rewriter.getContext(), 1);
- auto resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
- {valueTy, predTy});
Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
@@ -176,14 +174,26 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
}
- auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
+ bool predIsUsed = !op->getResult(1).use_empty();
+ UnitAttr returnValueAndIsValidAttr = nullptr;
+ Type resultTy = valueTy;
+ if (predIsUsed) {
+ returnValueAndIsValidAttr = rewriter.getUnitAttr();
+ resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
+ {valueTy, predTy});
+ }
Value shfl = rewriter.create<NVVM::ShflOp>(
loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
- Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
- Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
-
- rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
+ if (predIsUsed) {
+ Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
+ Value isActiveSrcLane =
+ rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
+ rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
+ } else {
+ Value falseCst = rewriter.create<LLVM::ConstantOp>(loc, predTy, 0);
+ rewriter.replaceOp(op, {shfl, falseCst});
+ }
return success();
}
};
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index dd3b6c2080aa21..8877ee083286b4 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -112,7 +112,7 @@ gpu.module @test_module_3 {
gpu.module @test_module_4 {
// CHECK-LABEL: func @gpu_shuffle()
- func.func @gpu_shuffle() -> (f32, f32, f32, f32) {
+ func.func @gpu_shuffle() -> (f32, f32, f32, f32, i1, i1, i1, i1) {
// CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
%arg0 = arith.constant 1.0 : f32
// CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32
@@ -143,11 +143,41 @@ gpu.module @test_module_4 {
// CHECK: nvvm.shfl.sync idx {{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
%shfli, %predi = gpu.shuffle idx %arg0, %arg1, %arg2 : f32
- func.return %shfl, %shflu, %shfld, %shfli : f32, f32,f32, f32
+ func.return %shfl, %shflu, %shfld, %shfli, %pred, %predu, %predd, %predi
+ : f32, f32,f32, f32, i1, i1, i1, i1
}
-}
+ // CHECK-LABEL: func @gpu_shuffle_unused_pred()
+ func.func @gpu_shuffle_unused_pred() -> (f32, f32, f32, f32) {
+ // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+ %arg0 = arith.constant 1.0 : f32
+ // CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32
+ %arg1 = arith.constant 4 : i32
+ // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(23 : i32) : i32
+ %arg2 = arith.constant 23 : i32
+ // CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[#MINUS_ONE:]] = llvm.mlir.constant(-1 : i32) : i32
+ // CHECK: %[[#THIRTY_TWO:]] = llvm.mlir.constant(32 : i32) : i32
+ // CHECK: %[[#NUM_LANES:]] = llvm.sub %[[#THIRTY_TWO]], %[[#WIDTH]] : i32
+ // CHECK: %[[#MASK:]] = llvm.lshr %[[#MINUS_ONE]], %[[#NUM_LANES]] : i32
+ // CHECK: %[[#CLAMP:]] = llvm.sub %[[#WIDTH]], %[[#ONE]] : i32
+ // CHECK: %[[#SHFL:]] = nvvm.shfl.sync bfly %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] : f32 -> f32
+ %shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : f32
+ // CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[#MINUS_ONE:]] = llvm.mlir.constant(-1 : i32) : i32
+ // CHECK: %[[#THIRTY_TWO:]] = llvm.mlir.constant(32 : i32) : i32
+ // CHECK: %[[#NUM_LANES:]] = llvm.sub %[[#THIRTY_TWO]], %[[#WIDTH]] : i32
+ // CHECK: %[[#MASK:]] = llvm.lshr %[[#MINUS_ONE]], %[[#NUM_LANES]] : i32
+ // CHECK: %[[#SHFL:]] = nvvm.shfl.sync up %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#NUM_LANES]] : f32 -> f32
+ %shflu, %predu = gpu.shuffle up %arg0, %arg1, %arg2 : f32
+ // CHECK: nvvm.shfl.sync down {{.*}} : f32 -> f32
+ %shfld, %predd = gpu.shuffle down %arg0, %arg1, %arg2 : f32
+ // CHECK: nvvm.shfl.sync idx {{.*}} : f32 -> f32
+ %shfli, %predi = gpu.shuffle idx %arg0, %arg1, %arg2 : f32
+ func.return %shfl, %shflu, %shfld, %shfli : f32, f32,f32, f32
+ }
+}
gpu.module @test_module_5 {
// CHECK-LABEL: func @gpu_sync()
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, I just have one minor nit.
rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1); | ||
rewriter.replaceOp(op, {shflValue, isActiveSrcLane}); | ||
} else { | ||
Value falseCst = rewriter.create<LLVM::ConstantOp>(loc, predTy, 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be better to not create a predicate value at all?
Instead of rewriter.replaceOp()
, you could do rewrite.replaceAllUsesWith(op.getResult(0), shfl)
plus rewriter.eraseOp(op)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried that, but that caused a bunch of things to start failing ("expected the op to be replaced" or something like that).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I poked at it for a bit and the best I could come up with is
Value isActiveSrcLane = nullptr;
if (predIsUsed) {
isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
shfl = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
}
rewriter.replaceOp(op, {shfl, isActiveSrcLane});
I think something along these lines would be a bit better, because it avoids the unnecessary constant being materialized in the IR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But unfortunately this doesn't actually work:
llvm/include/llvm/Support/Casting.h:572: decltype(auto) llvm::cast(From &) [To = mlir::LLVM::LLVMArrayType, From = mlir::Type]: Assertion `isa(Val) && "cast() argument of incompatible type!"' failed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately not, but the buildkite failed: https://buildkite.com/llvm-project/github-pull-requests/builds/49255#018e5fda-5040-4bf6-9d4d-b976d9799d20
Haven't yet been able to reproduce it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe I just messed up applying the fix? Let me try again.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That reproduced for me FYI:
# | #12 0x00005640f7453cc8 decltype(auto) llvm::cast<mlir::LLVM::LLVMArrayType, mlir::Type>(mlir::Type&) /home/mamini/projects/llvm-project2/llvm/include/llvm/Support/Casting.h:573:37
# | #13 0x00005640f71e5735 getInsertExtractValueElementType(mlir::Type, llvm::ArrayRef<long>) /home/mamini/projects/llvm-project2/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp:1590:18
# | #14 0x00005640f71e55ca mlir::LLVM::ExtractValueOp::build(mlir::OpBuilder&, mlir::OperationState&, mlir::Value, llvm::ArrayRef<long>) /home/mamini/projects/llvm-project2/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp:1645:9
# | #15 0x00005640f99a78e9 mlir::LLVM::ExtractValueOp mlir::OpBuilder::create<mlir::LLVM::ExtractValueOp, mlir::Value&, int>(mlir::Location, mlir::Value&, int&&) /home/mamini/projects/llvm-project2/mlir/include/mlir/IR/Builders.h:511:5
# | #16 0x00005640f9be1283 (anonymous namespace)::GPUShuffleOpLowering::matchAndRewrite(mlir::gpu::ShuffleOp, mlir::gpu::ShuffleOpAdaptor, mlir::ConversionPatternRewriter&) const /home/mamini/projects/llvm-project2/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp:191:34
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually that was on the previous commit that failed buildkite, the latest push seems fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, looks like it was just me. Thanks for checking!
d6cd651
to
05fb6f9
Compare
We currently always lower shuffle to the struct-returning variant. I saw some cases where this survived all the way through ptx, resulting in increased register usage. The easiest fix is to simply lower to the single-result version when the predicate is unused.
We currently always lower shuffle to the struct-returning variant. I saw some cases where this survived all the way through ptx, resulting in increased register usage. The easiest fix is to simply lower to the single-result version when the predicate is unused.
We currently always lower shuffle to the struct-returning variant. I saw some cases where this survived all the way through ptx, resulting in increased register usage. The easiest fix is to simply lower to the single-result version when the predicate is unused.