Skip to content

Commit 08756e3

Browse files
committed
[InstCombine] Fold select of clamped shifts
If we are feeding a shift into a select conditioned by an inbounds check for the shift amount, then we can strip any mask/clamp limit that has been put on the shift amount. Fold (select (icmp_ult A, BW), (shift X, (and A, C)), Y) --> (select (icmp_ult A, BW), (shift X, A), Y). Fold (select (icmp_ugt A, BW-1), Y, (shift X, (umin A, C))) --> (select (icmp_ugt A, BW-1), Y, (shift X, A)). (Alive2: https://alive2.llvm.org/ce/z/xC6FwD) Fixes #109888.
1 parent 2b30a76 commit 08756e3

File tree

2 files changed

+73
-18
lines changed

2 files changed

+73
-18
lines changed

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2761,6 +2761,67 @@ static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC,
27612761
return nullptr;
27622762
}
27632763

2764+
/// Strip a redundant clamp (umin) or mask (and) from the amount of a one-use
/// shift that is one arm of a select, when the select condition is an unsigned
/// bounds check of that same amount against the scalar bit width of the
/// select's type.
///
/// \param SI the select being visited.
/// \param IC used to query known bits of the clamp/mask limit.
/// \param Builder used to create the replacement shift on the raw amount.
/// \returns a newly created select that uses the unclamped shift, or nullptr
/// when neither pattern matches.
static Instruction *foldSelectWithClampedShift(SelectInst &SI,
2765+
InstCombinerImpl &IC,
2766+
IRBuilderBase &Builder) {
2767+
Value *CondVal = SI.getCondition();
2768+
Value *TrueVal = SI.getTrueValue();
2769+
Value *FalseVal = SI.getFalseValue();
2770+
Type *SelType = SI.getType();
2771+
// Scalar (per-element, for vectors) bit width of the select's type.
uint64_t BW = SelType->getScalarSizeInBits();
2772+
2773+
// Returns V as a BinaryOperator when V is a one-use shift whose amount is
// Amt wrapped in a umin or an and with a limit that leaves every amount
// below BW unchanged; returns nullptr otherwise.
auto MatchClampedShift = [&](Value *V, Value *Amt) -> BinaryOperator * {
2774+
Value *X, *Limit;
2775+
2776+
// Fold (select (icmp_ugt A, BW-1), TrueVal, (shift X, (umin A, C)))
2777+
// --> (select (icmp_ugt A, BW-1), TrueVal, (shift X, A))
2778+
// Fold (select (icmp_ult A, BW), (shift X, (umin A, C)), FalseVal)
2779+
// --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
2780+
// iff C >= BW-1
2781+
if (match(V, m_OneUse(m_Shift(m_Value(X),
2782+
m_UMin(m_Specific(Amt), m_Value(Limit)))))) {
2783+
KnownBits KnownLimit = IC.computeKnownBits(Limit, 0, &SI);
2784+
// The smallest value the limit can take must still be >= BW-1.
if (KnownLimit.getMinValue().uge(BW - 1))
2785+
return cast<BinaryOperator>(V);
2786+
}
2787+
2788+
// Fold (select (icmp_ugt A, BW-1), TrueVal, (shift X, (and A, C)))
2789+
// --> (select (icmp_ugt A, BW-1), TrueVal, (shift X, A))
2790+
// Fold (select (icmp_ult A, BW), (shift X, (and A, C)), FalseVal)
2791+
// --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
2792+
// iff Pow2 element width and C masks all amt bits.
2793+
if (isPowerOf2_64(BW) &&
2794+
match(V, m_OneUse(m_Shift(m_Value(X),
2795+
m_And(m_Specific(Amt), m_Value(Limit)))))) {
2796+
KnownBits KnownLimit = IC.computeKnownBits(Limit, 0, &SI);
2797+
// At least log2(BW) low bits of the mask must be known to be one.
if (KnownLimit.countMinTrailingOnes() >= Log2_64(BW))
2798+
return cast<BinaryOperator>(V);
2799+
}
2800+
2801+
return nullptr;
2802+
};
2803+
2804+
Value *Amt;
2805+
// Out-of-range check (A u> BW-1): the shift is the false arm.
if (match(CondVal, m_SpecificICmp(ICmpInst::ICMP_UGT, m_Value(Amt),
2806+
m_SpecificInt(BW - 1)))) {
2807+
if (BinaryOperator *ShiftI = MatchClampedShift(FalseVal, Amt))
2808+
return SelectInst::Create(
2809+
CondVal, TrueVal,
2810+
Builder.CreateBinOp(ShiftI->getOpcode(), ShiftI->getOperand(0), Amt));
2811+
}
2812+
2813+
// In-range check (A u< BW): the shift is the true arm.
if (match(CondVal, m_SpecificICmp(ICmpInst::ICMP_ULT, m_Value(Amt),
2814+
m_SpecificInt(BW)))) {
2815+
if (BinaryOperator *ShiftI = MatchClampedShift(TrueVal, Amt))
2816+
return SelectInst::Create(
2817+
CondVal,
2818+
Builder.CreateBinOp(ShiftI->getOpcode(), ShiftI->getOperand(0), Amt),
2819+
FalseVal);
2820+
}
2821+
2822+
return nullptr;
2823+
}
2824+
27642825
static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) {
27652826
FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition());
27662827
if (!FI)
@@ -3871,6 +3932,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
38713932
if (Instruction *I = foldSelectExtConst(SI))
38723933
return I;
38733934

3935+
if (Instruction *I = foldSelectWithClampedShift(SI, *this, Builder))
3936+
return I;
3937+
38743938
if (Instruction *I = foldSelectWithSRem(SI, *this, Builder))
38753939
return I;
38763940

llvm/test/Transforms/InstCombine/select-shift-clamp.ll

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ declare void @use_i32(i32)
1313
define i32 @select_ult_shl_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
1414
; CHECK-LABEL: @select_ult_shl_clamp_and_i32(
1515
; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
16-
; CHECK-NEXT: [[M:%.*]] = and i32 [[A1]], 31
17-
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[M]]
16+
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[A1]]
1817
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[TMP1]], i32 [[A2:%.*]]
1918
; CHECK-NEXT: ret i32 [[R]]
2019
;
@@ -28,8 +27,7 @@ define i32 @select_ult_shl_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
2827
define i32 @select_ule_ashr_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
2928
; CHECK-LABEL: @select_ule_ashr_clamp_and_i32(
3029
; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
31-
; CHECK-NEXT: [[M:%.*]] = and i32 [[A1]], 127
32-
; CHECK-NEXT: [[TMP1:%.*]] = ashr i32 [[A0:%.*]], [[M]]
30+
; CHECK-NEXT: [[TMP1:%.*]] = ashr i32 [[A0:%.*]], [[A1]]
3331
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[TMP1]], i32 [[A2:%.*]]
3432
; CHECK-NEXT: ret i32 [[R]]
3533
;
@@ -43,8 +41,7 @@ define i32 @select_ule_ashr_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
4341
define i32 @select_ugt_lshr_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
4442
; CHECK-LABEL: @select_ugt_lshr_clamp_and_i32(
4543
; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[A1:%.*]], 31
46-
; CHECK-NEXT: [[M:%.*]] = and i32 [[A1]], 31
47-
; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[A0:%.*]], [[M]]
44+
; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[A0:%.*]], [[A1]]
4845
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[TMP1]]
4946
; CHECK-NEXT: ret i32 [[R]]
5047
;
@@ -58,8 +55,7 @@ define i32 @select_ugt_lshr_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
5855
define i32 @select_uge_shl_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
5956
; CHECK-LABEL: @select_uge_shl_clamp_and_i32(
6057
; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[A1:%.*]], 31
61-
; CHECK-NEXT: [[M:%.*]] = and i32 [[A1]], 63
62-
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[M]]
58+
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[A1]]
6359
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[TMP1]]
6460
; CHECK-NEXT: ret i32 [[R]]
6561
;
@@ -129,8 +125,7 @@ define i17 @select_uge_lshr_clamp_and_i17_nonpow2(i17 %a0, i17 %a1, i17 %a2) {
129125
define i32 @select_ult_shl_clamp_umin_i32(i32 %a0, i32 %a1, i32 %a2) {
130126
; CHECK-LABEL: @select_ult_shl_clamp_umin_i32(
131127
; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
132-
; CHECK-NEXT: [[M:%.*]] = call i32 @llvm.umin.i32(i32 [[A1]], i32 31)
133-
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[M]]
128+
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[A1]]
134129
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[TMP1]], i32 [[A2:%.*]]
135130
; CHECK-NEXT: ret i32 [[R]]
136131
;
@@ -144,8 +139,7 @@ define i32 @select_ult_shl_clamp_umin_i32(i32 %a0, i32 %a1, i32 %a2) {
144139
define i17 @select_ule_ashr_clamp_umin_i17(i17 %a0, i17 %a1, i17 %a2) {
145140
; CHECK-LABEL: @select_ule_ashr_clamp_umin_i17(
146141
; CHECK-NEXT: [[C:%.*]] = icmp ult i17 [[A1:%.*]], 17
147-
; CHECK-NEXT: [[M:%.*]] = call i17 @llvm.umin.i17(i17 [[A1]], i17 17)
148-
; CHECK-NEXT: [[TMP1:%.*]] = ashr i17 [[A0:%.*]], [[M]]
142+
; CHECK-NEXT: [[TMP1:%.*]] = ashr i17 [[A0:%.*]], [[A1]]
149143
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i17 [[TMP1]], i17 [[A2:%.*]]
150144
; CHECK-NEXT: ret i17 [[R]]
151145
;
@@ -159,8 +153,7 @@ define i17 @select_ule_ashr_clamp_umin_i17(i17 %a0, i17 %a1, i17 %a2) {
159153
define i32 @select_ugt_shl_clamp_umin_i32(i32 %a0, i32 %a1, i32 %a2) {
160154
; CHECK-LABEL: @select_ugt_shl_clamp_umin_i32(
161155
; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[A1:%.*]], 31
162-
; CHECK-NEXT: [[M:%.*]] = call i32 @llvm.umin.i32(i32 [[A1]], i32 128)
163-
; CHECK-NEXT: [[S:%.*]] = shl i32 [[A0:%.*]], [[M]]
156+
; CHECK-NEXT: [[S:%.*]] = shl i32 [[A0:%.*]], [[A1]]
164157
; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[S]]
165158
; CHECK-NEXT: ret i32 [[R]]
166159
;
@@ -174,8 +167,7 @@ define i32 @select_ugt_shl_clamp_umin_i32(i32 %a0, i32 %a1, i32 %a2) {
174167
define <2 x i32> @select_uge_lshr_clamp_umin_v2i32(<2 x i32> %a0, <2 x i32> %a1, <2 x i32> %a2) {
175168
; CHECK-LABEL: @select_uge_lshr_clamp_umin_v2i32(
176169
; CHECK-NEXT: [[C:%.*]] = icmp ugt <2 x i32> [[A1:%.*]], <i32 31, i32 31>
177-
; CHECK-NEXT: [[M:%.*]] = call <2 x i32> @llvm.umin.v2i32(<2 x i32> [[A1]], <2 x i32> <i32 63, i32 31>)
178-
; CHECK-NEXT: [[S:%.*]] = lshr <2 x i32> [[A0:%.*]], [[M]]
170+
; CHECK-NEXT: [[S:%.*]] = lshr <2 x i32> [[A0:%.*]], [[A1]]
179171
; CHECK-NEXT: [[R:%.*]] = select <2 x i1> [[C]], <2 x i32> [[A2:%.*]], <2 x i32> [[S]]
180172
; CHECK-NEXT: ret <2 x i32> [[R]]
181173
;
@@ -223,8 +215,7 @@ define i17 @select_uge_lshr_clamp_umin_i17_badlimit(i17 %a0, i17 %a1, i17 %a2) {
223215
define range(i64 0, -9223372036854775807) <4 x i64> @PR109888(<4 x i64> %0) {
224216
; CHECK-LABEL: @PR109888(
225217
; CHECK-NEXT: [[C:%.*]] = icmp ult <4 x i64> [[TMP0:%.*]], <i64 64, i64 64, i64 64, i64 64>
226-
; CHECK-NEXT: [[M:%.*]] = and <4 x i64> [[TMP0]], <i64 63, i64 63, i64 63, i64 63>
227-
; CHECK-NEXT: [[TMP2:%.*]] = shl nuw <4 x i64> <i64 1, i64 1, i64 1, i64 1>, [[M]]
218+
; CHECK-NEXT: [[TMP2:%.*]] = shl nuw <4 x i64> <i64 1, i64 1, i64 1, i64 1>, [[TMP0]]
228219
; CHECK-NEXT: [[R:%.*]] = select <4 x i1> [[C]], <4 x i64> [[TMP2]], <4 x i64> zeroinitializer
229220
; CHECK-NEXT: ret <4 x i64> [[R]]
230221
;

0 commit comments

Comments
 (0)