Skip to content

Commit a6a9215

Browse files
authored
Lower shuffle to single-result form if possible. (#84321)
We currently always lower shuffle to the struct-returning variant. I saw some cases where this survived all the way through ptx, resulting in increased register usage. The easiest fix is to simply lower to the single-result version when the predicate is unused.
1 parent ee5e027 commit a6a9215

File tree

2 files changed

+49
-10
lines changed

2 files changed

+49
-10
lines changed

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,6 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
155155
auto valueTy = adaptor.getValue().getType();
156156
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
157157
auto predTy = IntegerType::get(rewriter.getContext(), 1);
158-
auto resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
159-
{valueTy, predTy});
160158

161159
Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
162160
Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
@@ -176,14 +174,25 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
176174
rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
177175
}
178176

179-
auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
177+
bool predIsUsed = !op->getResult(1).use_empty();
178+
UnitAttr returnValueAndIsValidAttr = nullptr;
179+
Type resultTy = valueTy;
180+
if (predIsUsed) {
181+
returnValueAndIsValidAttr = rewriter.getUnitAttr();
182+
resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
183+
{valueTy, predTy});
184+
}
180185
Value shfl = rewriter.create<NVVM::ShflOp>(
181186
loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
182187
maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
183-
Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
184-
Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
185-
186-
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
188+
if (predIsUsed) {
189+
Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
190+
Value isActiveSrcLane =
191+
rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
192+
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
193+
} else {
194+
rewriter.replaceOp(op, {shfl, nullptr});
195+
}
187196
return success();
188197
}
189198
};

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ gpu.module @test_module_3 {
112112

113113
gpu.module @test_module_4 {
114114
// CHECK-LABEL: func @gpu_shuffle()
115-
func.func @gpu_shuffle() -> (f32, f32, f32, f32) {
115+
func.func @gpu_shuffle() -> (f32, f32, f32, f32, i1, i1, i1, i1) {
116116
// CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
117117
%arg0 = arith.constant 1.0 : f32
118118
// CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32
@@ -143,11 +143,41 @@ gpu.module @test_module_4 {
143143
// CHECK: nvvm.shfl.sync idx {{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
144144
%shfli, %predi = gpu.shuffle idx %arg0, %arg1, %arg2 : f32
145145

146-
func.return %shfl, %shflu, %shfld, %shfli : f32, f32,f32, f32
146+
func.return %shfl, %shflu, %shfld, %shfli, %pred, %predu, %predd, %predi
147+
: f32, f32,f32, f32, i1, i1, i1, i1
147148
}
148-
}
149149

150+
// CHECK-LABEL: func @gpu_shuffle_unused_pred()
151+
func.func @gpu_shuffle_unused_pred() -> (f32, f32, f32, f32) {
152+
// CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
153+
%arg0 = arith.constant 1.0 : f32
154+
// CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32
155+
%arg1 = arith.constant 4 : i32
156+
// CHECK: %[[#WIDTH:]] = llvm.mlir.constant(23 : i32) : i32
157+
%arg2 = arith.constant 23 : i32
158+
// CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : i32) : i32
159+
// CHECK: %[[#MINUS_ONE:]] = llvm.mlir.constant(-1 : i32) : i32
160+
// CHECK: %[[#THIRTY_TWO:]] = llvm.mlir.constant(32 : i32) : i32
161+
// CHECK: %[[#NUM_LANES:]] = llvm.sub %[[#THIRTY_TWO]], %[[#WIDTH]] : i32
162+
// CHECK: %[[#MASK:]] = llvm.lshr %[[#MINUS_ONE]], %[[#NUM_LANES]] : i32
163+
// CHECK: %[[#CLAMP:]] = llvm.sub %[[#WIDTH]], %[[#ONE]] : i32
164+
// CHECK: %[[#SHFL:]] = nvvm.shfl.sync bfly %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] : f32 -> f32
165+
%shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : f32
166+
// CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : i32) : i32
167+
// CHECK: %[[#MINUS_ONE:]] = llvm.mlir.constant(-1 : i32) : i32
168+
// CHECK: %[[#THIRTY_TWO:]] = llvm.mlir.constant(32 : i32) : i32
169+
// CHECK: %[[#NUM_LANES:]] = llvm.sub %[[#THIRTY_TWO]], %[[#WIDTH]] : i32
170+
// CHECK: %[[#MASK:]] = llvm.lshr %[[#MINUS_ONE]], %[[#NUM_LANES]] : i32
171+
// CHECK: %[[#SHFL:]] = nvvm.shfl.sync up %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#NUM_LANES]] : f32 -> f32
172+
%shflu, %predu = gpu.shuffle up %arg0, %arg1, %arg2 : f32
173+
// CHECK: nvvm.shfl.sync down {{.*}} : f32 -> f32
174+
%shfld, %predd = gpu.shuffle down %arg0, %arg1, %arg2 : f32
175+
// CHECK: nvvm.shfl.sync idx {{.*}} : f32 -> f32
176+
%shfli, %predi = gpu.shuffle idx %arg0, %arg1, %arg2 : f32
150177

178+
func.return %shfl, %shflu, %shfld, %shfli : f32, f32,f32, f32
179+
}
180+
}
151181

152182
gpu.module @test_module_5 {
153183
// CHECK-LABEL: func @gpu_sync()

0 commit comments

Comments
 (0)