Skip to content

Lower shuffle to single-result form if possible. #84321

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 21, 2024
Merged

Conversation

jreiffers
Copy link
Member

We currently always lower shuffle to the struct-returning variant. I saw some cases where this survived all the way through ptx, resulting in increased register usage. The easiest fix is to simply lower to the single-result version when the predicate is unused.

@llvmbot
Copy link
Member

llvmbot commented Mar 7, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Johannes Reifferscheid (jreiffers)

Changes

We currently always lower shuffle to the struct-returning variant. I saw some cases where this survived all the way through ptx, resulting in increased register usage. The easiest fix is to simply lower to the single-result version when the predicate is unused.


Full diff: https://github.com/llvm/llvm-project/pull/84321.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (+17-7)
  • (modified) mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir (+33-3)
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index d6a5d8cd74d5f2..993c6822ac74e4 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -155,8 +155,6 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
     auto valueTy = adaptor.getValue().getType();
     auto int32Type = IntegerType::get(rewriter.getContext(), 32);
     auto predTy = IntegerType::get(rewriter.getContext(), 1);
-    auto resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
-                                                     {valueTy, predTy});
 
     Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
     Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
@@ -176,14 +174,26 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
           rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
     }
 
-    auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
+    bool predIsUsed = !op->getResult(1).use_empty();
+    UnitAttr returnValueAndIsValidAttr = nullptr;
+    Type resultTy = valueTy;
+    if (predIsUsed) {
+      returnValueAndIsValidAttr = rewriter.getUnitAttr();
+      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
+                                                  {valueTy, predTy});
+    }
     Value shfl = rewriter.create<NVVM::ShflOp>(
         loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
         maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
-    Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
-    Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
-
-    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
+    if (predIsUsed) {
+      Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
+      Value isActiveSrcLane =
+          rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
+      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
+    } else {
+      Value falseCst = rewriter.create<LLVM::ConstantOp>(loc, predTy, 0);
+      rewriter.replaceOp(op, {shfl, falseCst});
+    }
     return success();
   }
 };
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index dd3b6c2080aa21..8877ee083286b4 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -112,7 +112,7 @@ gpu.module @test_module_3 {
 
 gpu.module @test_module_4 {
   // CHECK-LABEL: func @gpu_shuffle()
-  func.func @gpu_shuffle() -> (f32, f32, f32, f32) {
+  func.func @gpu_shuffle() -> (f32, f32, f32, f32, i1, i1, i1, i1) {
     // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
     %arg0 = arith.constant 1.0 : f32
     // CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32
@@ -143,11 +143,41 @@ gpu.module @test_module_4 {
     // CHECK: nvvm.shfl.sync idx {{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
     %shfli, %predi = gpu.shuffle idx %arg0, %arg1, %arg2 : f32
 
-    func.return %shfl, %shflu, %shfld, %shfli : f32, f32,f32, f32
+    func.return %shfl, %shflu, %shfld, %shfli, %pred, %predu, %predd, %predi
+      : f32, f32,f32, f32, i1, i1, i1, i1
   }
-}
 
+  // CHECK-LABEL: func @gpu_shuffle_unused_pred()
+  func.func @gpu_shuffle_unused_pred() -> (f32, f32, f32, f32) {
+    // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+    %arg0 = arith.constant 1.0 : f32
+    // CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32
+    %arg1 = arith.constant 4 : i32
+    // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(23 : i32) : i32
+    %arg2 = arith.constant 23 : i32
+    // CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: %[[#MINUS_ONE:]] = llvm.mlir.constant(-1 : i32) : i32
+    // CHECK: %[[#THIRTY_TWO:]] = llvm.mlir.constant(32 : i32) : i32
+    // CHECK: %[[#NUM_LANES:]] = llvm.sub %[[#THIRTY_TWO]], %[[#WIDTH]] : i32
+    // CHECK: %[[#MASK:]] = llvm.lshr %[[#MINUS_ONE]], %[[#NUM_LANES]] : i32
+    // CHECK: %[[#CLAMP:]] = llvm.sub %[[#WIDTH]], %[[#ONE]] : i32
+    // CHECK: %[[#SHFL:]] = nvvm.shfl.sync bfly %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] : f32 -> f32
+    %shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : f32
+    // CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: %[[#MINUS_ONE:]] = llvm.mlir.constant(-1 : i32) : i32
+    // CHECK: %[[#THIRTY_TWO:]] = llvm.mlir.constant(32 : i32) : i32
+    // CHECK: %[[#NUM_LANES:]] = llvm.sub %[[#THIRTY_TWO]], %[[#WIDTH]] : i32
+    // CHECK: %[[#MASK:]] = llvm.lshr %[[#MINUS_ONE]], %[[#NUM_LANES]] : i32
+    // CHECK: %[[#SHFL:]] = nvvm.shfl.sync up %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#NUM_LANES]] : f32 -> f32
+    %shflu, %predu = gpu.shuffle up %arg0, %arg1, %arg2 : f32
+    // CHECK: nvvm.shfl.sync down {{.*}} : f32 -> f32
+    %shfld, %predd = gpu.shuffle down %arg0, %arg1, %arg2 : f32
+    // CHECK: nvvm.shfl.sync idx {{.*}} : f32 -> f32
+    %shfli, %predi = gpu.shuffle idx %arg0, %arg1, %arg2 : f32
 
+    func.return %shfl, %shflu, %shfld, %shfli : f32, f32,f32, f32
+  }
+}
 
 gpu.module @test_module_5 {
   // CHECK-LABEL: func @gpu_sync()

Copy link
Contributor

@chsigg chsigg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, I just have one minor nit.

rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
} else {
Value falseCst = rewriter.create<LLVM::ConstantOp>(loc, predTy, 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to not create a predicate value at all?
Instead of rewriter.replaceOp(), you could do rewrite.replaceAllUsesWith(op.getResult(0), shfl) plus rewriter.eraseOp(op).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried that, but that caused a bunch of things to start failing ("expected the op to be replaced" or something like that).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I poked at it for a bit and the best I could come up with is

    Value isActiveSrcLane = nullptr;
    if (predIsUsed) {
      isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
      shfl = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
    }
    rewriter.replaceOp(op, {shfl, isActiveSrcLane});

I think something along these lines would be a bit better, because it avoids the unnecessary constant being materialized in the IR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, done.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But unfortunately this doesn't actually work:

llvm/include/llvm/Support/Casting.h:572: decltype(auto) llvm::cast(From &) [To = mlir::LLVM::LLVMArrayType, From = mlir::Type]: Assertion `isa(Val) && "cast() argument of incompatible type!"' failed.
 

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately not, but the buildkite failed: https://buildkite.com/llvm-project/github-pull-requests/builds/49255#018e5fda-5040-4bf6-9d4d-b976d9799d20

Haven't yet been able to reproduce it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I just messed up applying the fix? Let me try again.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That reproduced for me FYI:

# | #12 0x00005640f7453cc8 decltype(auto) llvm::cast<mlir::LLVM::LLVMArrayType, mlir::Type>(mlir::Type&) /home/mamini/projects/llvm-project2/llvm/include/llvm/Support/Casting.h:573:37
# | #13 0x00005640f71e5735 getInsertExtractValueElementType(mlir::Type, llvm::ArrayRef<long>) /home/mamini/projects/llvm-project2/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp:1590:18
# | #14 0x00005640f71e55ca mlir::LLVM::ExtractValueOp::build(mlir::OpBuilder&, mlir::OperationState&, mlir::Value, llvm::ArrayRef<long>) /home/mamini/projects/llvm-project2/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp:1645:9
# | #15 0x00005640f99a78e9 mlir::LLVM::ExtractValueOp mlir::OpBuilder::create<mlir::LLVM::ExtractValueOp, mlir::Value&, int>(mlir::Location, mlir::Value&, int&&) /home/mamini/projects/llvm-project2/mlir/include/mlir/IR/Builders.h:511:5
# | #16 0x00005640f9be1283 (anonymous namespace)::GPUShuffleOpLowering::matchAndRewrite(mlir::gpu::ShuffleOp, mlir::gpu::ShuffleOpAdaptor, mlir::ConversionPatternRewriter&) const /home/mamini/projects/llvm-project2/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp:191:34

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually that was on the previous commit that failed buildkite, the latest push seems fine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, looks like it was just me. Thanks for checking!

We currently always lower shuffle to the struct-returning variant. I saw
some cases where this survived all the way through ptx, resulting in
increased register usage. The easiest fix is to simply lower to the
single-result version when the predicate is unused.
@jreiffers jreiffers merged commit a6a9215 into llvm:main Mar 21, 2024
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024
We currently always lower shuffle to the struct-returning variant. I saw
some cases where this survived all the way through ptx, resulting in
increased register usage. The easiest fix is to simply lower to the
single-result version when the predicate is unused.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants