[InstCombine] Fold select of clamped shifts #114797
Conversation
@llvm/pr-subscribers-llvm-transforms

Author: Simon Pilgrim (RKSimon)

Changes

If we are feeding a shift into a select conditioned by an inbounds check for the shift amount, then we can strip any mask/clamp limit that has been put on the shift amount.

Fold (select (icmp_ugt A, BW-1), Y, (shift X, (and A, C))) --> (select (icmp_ugt A, BW-1), Y, (shift X, A))

Alive2: https://alive2.llvm.org/ce/z/xC6FwD

Fixes #109888

Full diff: https://github.com/llvm/llvm-project/pull/114797.diff

2 Files Affected:
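As a concrete illustration, here is a minimal before/after sketch adapted from the first test in the new file (select_ult_shl_clamp_and_i32); the src/tgt names are for exposition only, and both can be checked with opt -passes=instcombine or the Alive2 link above:

; Before: the shift amount is masked even though the select guard already
; proves %a1 < 32 on the only path that uses the shift result.
define i32 @src(i32 %a0, i32 %a1, i32 %a2) {
  %c = icmp ult i32 %a1, 32
  %m = and i32 %a1, 31
  %s = shl i32 %a0, %m
  %r = select i1 %c, i32 %s, i32 %a2
  ret i32 %r
}

; After: the clamp is stripped. When %c is false the (now possibly poison)
; shift result is never selected, so the transform is sound.
define i32 @tgt(i32 %a0, i32 %a1, i32 %a2) {
  %c = icmp ult i32 %a1, 32
  %s = shl i32 %a0, %a1
  %r = select i1 %c, i32 %s, i32 %a2
  ret i32 %r
}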
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 999ad1adff20b8..826a9ec8f0eb98 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2761,6 +2761,67 @@ static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC,
return nullptr;
}
+static Instruction *foldSelectWithClampedShift(SelectInst &SI,
+ InstCombinerImpl &IC,
+ IRBuilderBase &Builder) {
+ Value *CondVal = SI.getCondition();
+ Value *TrueVal = SI.getTrueValue();
+ Value *FalseVal = SI.getFalseValue();
+ Type *SelType = SI.getType();
+ uint64_t BW = SelType->getScalarSizeInBits();
+
+ auto MatchClampedShift = [&](Value *V, Value *Amt) -> BinaryOperator * {
+ Value *X, *Limit;
+
+ // Fold (select (icmp_ugt A, BW-1), TrueVal, (shift X, (umin A, C)))
+ // --> (select (icmp_ugt A, BW-1), TrueVal, (shift X, A))
+ // Fold (select (icmp_ult A, BW), (shift X, (umin A, C)), FalseVal)
+ // --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
+ // iff C >= BW-1
+ if (match(V, m_OneUse(m_Shift(m_Value(X),
+ m_UMin(m_Specific(Amt), m_Value(Limit)))))) {
+ KnownBits KnownLimit = IC.computeKnownBits(Limit, 0, &SI);
+ if (KnownLimit.getMinValue().uge(BW - 1))
+ return cast<BinaryOperator>(V);
+ }
+
+ // Fold (select (icmp_ugt A, BW-1), TrueVal, (shift X, (and A, C)))
+ // --> (select (icmp_ugt A, BW-1), TrueVal, (shift X, A))
+ // Fold (select (icmp_ult A, BW), (shift X, (and A, C)), FalseVal)
+ // --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
+ // iff Pow2 element width and C masks all amt bits.
+ if (isPowerOf2_64(BW) &&
+ match(V, m_OneUse(m_Shift(m_Value(X),
+ m_And(m_Specific(Amt), m_Value(Limit)))))) {
+ KnownBits KnownLimit = IC.computeKnownBits(Limit, 0, &SI);
+ if (KnownLimit.countMinTrailingOnes() >= Log2_64(BW))
+ return cast<BinaryOperator>(V);
+ }
+
+ return nullptr;
+ };
+
+ Value *Amt;
+ if (match(CondVal, m_SpecificICmp(ICmpInst::ICMP_UGT, m_Value(Amt),
+ m_SpecificInt(BW - 1)))) {
+ if (BinaryOperator *ShiftI = MatchClampedShift(FalseVal, Amt))
+ return SelectInst::Create(
+ CondVal, TrueVal,
+ Builder.CreateBinOp(ShiftI->getOpcode(), ShiftI->getOperand(0), Amt));
+ }
+
+ if (match(CondVal, m_SpecificICmp(ICmpInst::ICMP_ULT, m_Value(Amt),
+ m_SpecificInt(BW)))) {
+ if (BinaryOperator *ShiftI = MatchClampedShift(TrueVal, Amt))
+ return SelectInst::Create(
+ CondVal,
+ Builder.CreateBinOp(ShiftI->getOpcode(), ShiftI->getOperand(0), Amt),
+ FalseVal);
+ }
+
+ return nullptr;
+}
+
static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) {
FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition());
if (!FI)
@@ -3817,6 +3878,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Instruction *I = foldSelectExtConst(SI))
return I;
+ if (Instruction *I = foldSelectWithClampedShift(SI, *this, Builder))
+ return I;
+
if (Instruction *I = foldSelectWithSRem(SI, *this, Builder))
return I;
diff --git a/llvm/test/Transforms/InstCombine/select-shift-clamp.ll b/llvm/test/Transforms/InstCombine/select-shift-clamp.ll
new file mode 100644
index 00000000000000..9be6e71b67a351
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/select-shift-clamp.ll
@@ -0,0 +1,227 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+declare void @use_i17(i17)
+declare void @use_i32(i32)
+
+; Fold (select (icmp_ugt A, BW-1), TrueVal, (shift X, (and A, C)))
+; --> (select (icmp_ugt A, BW-1), TrueVal, (shift X, A))
+; Fold (select (icmp_ult A, BW), (shift X, (and A, C)), FalseVal)
+; --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
+; iff Pow2 element width and C masks all amt bits.
+
+define i32 @select_ult_shl_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ult_shl_clamp_and_i32(
+; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
+; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[TMP1]], i32 [[A2:%.*]]
+; CHECK-NEXT: ret i32 [[R]]
+;
+ %c = icmp ult i32 %a1, 32
+ %m = and i32 %a1, 31
+ %s = shl i32 %a0, %m
+ %r = select i1 %c, i32 %s, i32 %a2
+ ret i32 %r
+}
+
+define i32 @select_ule_ashr_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ule_ashr_clamp_and_i32(
+; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
+; CHECK-NEXT: [[TMP1:%.*]] = ashr i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[TMP1]], i32 [[A2:%.*]]
+; CHECK-NEXT: ret i32 [[R]]
+;
+ %c = icmp ule i32 %a1, 31
+ %m = and i32 %a1, 127
+ %s = ashr i32 %a0, %m
+ %r = select i1 %c, i32 %s, i32 %a2
+ ret i32 %r
+}
+
+define i32 @select_ugt_lshr_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ugt_lshr_clamp_and_i32(
+; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[A1:%.*]], 31
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[TMP1]]
+; CHECK-NEXT: ret i32 [[R]]
+;
+ %c = icmp ugt i32 %a1, 31
+ %m = and i32 %a1, 31
+ %s = lshr i32 %a0, %m
+ %r = select i1 %c, i32 %a2, i32 %s
+ ret i32 %r
+}
+
+define i32 @select_uge_shl_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_uge_shl_clamp_and_i32(
+; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[A1:%.*]], 31
+; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[TMP1]]
+; CHECK-NEXT: ret i32 [[R]]
+;
+ %c = icmp uge i32 %a1, 32
+ %m = and i32 %a1, 63
+ %s = shl i32 %a0, %m
+ %r = select i1 %c, i32 %a2, i32 %s
+ ret i32 %r
+}
+
+; negative test - multiuse
+define i32 @select_ule_ashr_clamp_and_i32_multiuse(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ule_ashr_clamp_and_i32_multiuse(
+; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
+; CHECK-NEXT: [[M:%.*]] = and i32 [[A1]], 127
+; CHECK-NEXT: [[S:%.*]] = ashr i32 [[A0:%.*]], [[M]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[S]], i32 [[A2:%.*]]
+; CHECK-NEXT: call void @use_i32(i32 [[S]])
+; CHECK-NEXT: ret i32 [[R]]
+;
+ %c = icmp ule i32 %a1, 31
+ %m = and i32 %a1, 127
+ %s = ashr i32 %a0, %m
+ %r = select i1 %c, i32 %s, i32 %a2
+ call void @use_i32(i32 %s)
+ ret i32 %r
+}
+
+; negative test - mask doesn't cover all legal amount bits
+define i32 @select_ult_shl_clamp_and_i32_badmask(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ult_shl_clamp_and_i32_badmask(
+; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
+; CHECK-NEXT: [[M:%.*]] = and i32 [[A1]], 28
+; CHECK-NEXT: [[S:%.*]] = shl i32 [[A0:%.*]], [[M]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[S]], i32 [[A2:%.*]]
+; CHECK-NEXT: ret i32 [[R]]
+;
+ %c = icmp ult i32 %a1, 32
+ %m = and i32 %a1, 28
+ %s = shl i32 %a0, %m
+ %r = select i1 %c, i32 %s, i32 %a2
+ ret i32 %r
+}
+
+; negative test - non-pow2
+define i17 @select_uge_lshr_clamp_and_i17_nonpow2(i17 %a0, i17 %a1, i17 %a2) {
+; CHECK-LABEL: @select_uge_lshr_clamp_and_i17_nonpow2(
+; CHECK-NEXT: [[C:%.*]] = icmp ugt i17 [[A1:%.*]], 16
+; CHECK-NEXT: [[M:%.*]] = and i17 [[A1]], 255
+; CHECK-NEXT: [[S:%.*]] = lshr i17 [[A0:%.*]], [[M]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i17 [[A2:%.*]], i17 [[S]]
+; CHECK-NEXT: ret i17 [[R]]
+;
+ %c = icmp uge i17 %a1, 17
+ %m = and i17 %a1, 255
+ %s = lshr i17 %a0, %m
+ %r = select i1 %c, i17 %a2, i17 %s
+ ret i17 %r
+}
+
+; Fold (select (icmp_ugt A, BW-1), TrueVal, (shift X, (umin A, C)))
+; --> (select (icmp_ugt A, BW-1), TrueVal, (shift X, A))
+; Fold (select (icmp_ult A, BW), (shift X, (umin A, C)), FalseVal)
+; --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
+; iff C >= BW-1
+
+define i32 @select_ult_shl_clamp_umin_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ult_shl_clamp_umin_i32(
+; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
+; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[TMP1]], i32 [[A2:%.*]]
+; CHECK-NEXT: ret i32 [[R]]
+;
+ %c = icmp ult i32 %a1, 32
+ %m = call i32 @llvm.umin.i32(i32 %a1, i32 31)
+ %s = shl i32 %a0, %m
+ %r = select i1 %c, i32 %s, i32 %a2
+ ret i32 %r
+}
+
+define i17 @select_ule_ashr_clamp_umin_i17(i17 %a0, i17 %a1, i17 %a2) {
+; CHECK-LABEL: @select_ule_ashr_clamp_umin_i17(
+; CHECK-NEXT: [[C:%.*]] = icmp ult i17 [[A1:%.*]], 17
+; CHECK-NEXT: [[TMP1:%.*]] = ashr i17 [[A0:%.*]], [[A1]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i17 [[TMP1]], i17 [[A2:%.*]]
+; CHECK-NEXT: ret i17 [[R]]
+;
+ %c = icmp ule i17 %a1, 16
+ %m = call i17 @llvm.umin.i17(i17 %a1, i17 17)
+ %s = ashr i17 %a0, %m
+ %r = select i1 %c, i17 %s, i17 %a2
+ ret i17 %r
+}
+
+define i32 @select_ugt_shl_clamp_umin_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ugt_shl_clamp_umin_i32(
+; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[A1:%.*]], 31
+; CHECK-NEXT: [[S:%.*]] = shl i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[S]]
+; CHECK-NEXT: ret i32 [[R]]
+;
+ %c = icmp ugt i32 %a1, 31
+ %m = call i32 @llvm.umin.i32(i32 %a1, i32 128)
+ %s = shl i32 %a0, %m
+ %r = select i1 %c, i32 %a2, i32 %s
+ ret i32 %r
+}
+
+define <2 x i32> @select_uge_lshr_clamp_umin_v2i32(<2 x i32> %a0, <2 x i32> %a1, <2 x i32> %a2) {
+; CHECK-LABEL: @select_uge_lshr_clamp_umin_v2i32(
+; CHECK-NEXT: [[C:%.*]] = icmp ugt <2 x i32> [[A1:%.*]], <i32 31, i32 31>
+; CHECK-NEXT: [[S:%.*]] = lshr <2 x i32> [[A0:%.*]], [[A1]]
+; CHECK-NEXT: [[R:%.*]] = select <2 x i1> [[C]], <2 x i32> [[A2:%.*]], <2 x i32> [[S]]
+; CHECK-NEXT: ret <2 x i32> [[R]]
+;
+ %c = icmp uge <2 x i32> %a1, <i32 32, i32 32>
+ %m = call <2 x i32> @llvm.umin.v2i32(<2 x i32> %a1, <2 x i32> <i32 63, i32 31>)
+ %s = lshr <2 x i32> %a0, %m
+ %r = select <2 x i1> %c, <2 x i32> %a2, <2 x i32> %s
+ ret <2 x i32> %r
+}
+
+; negative test - multiuse
+define i32 @select_ugt_shl_clamp_umin_i32_multiuse(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ugt_shl_clamp_umin_i32_multiuse(
+; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[A1:%.*]], 32
+; CHECK-NEXT: [[M:%.*]] = call i32 @llvm.umin.i32(i32 [[A1]], i32 128)
+; CHECK-NEXT: [[S:%.*]] = shl i32 [[A0:%.*]], [[M]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[S]]
+; CHECK-NEXT: call void @use_i32(i32 [[S]])
+; CHECK-NEXT: ret i32 [[R]]
+;
+ %c = icmp ugt i32 %a1, 32
+ %m = call i32 @llvm.umin.i32(i32 %a1, i32 128)
+ %s = shl i32 %a0, %m
+ %r = select i1 %c, i32 %a2, i32 %s
+ call void @use_i32(i32 %s)
+ ret i32 %r
+}
+
+; negative test - umin limit doesn't cover all legal amounts
+define i17 @select_uge_lshr_clamp_umin_i17_badlimit(i17 %a0, i17 %a1, i17 %a2) {
+; CHECK-LABEL: @select_uge_lshr_clamp_umin_i17_badlimit(
+; CHECK-NEXT: [[C:%.*]] = icmp ugt i17 [[A1:%.*]], 15
+; CHECK-NEXT: [[M:%.*]] = call i17 @llvm.umin.i17(i17 [[A1]], i17 12)
+; CHECK-NEXT: [[S:%.*]] = lshr i17 [[A0:%.*]], [[M]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i17 [[A2:%.*]], i17 [[S]]
+; CHECK-NEXT: ret i17 [[R]]
+;
+ %c = icmp uge i17 %a1, 16
+ %m = call i17 @llvm.umin.i17(i17 %a1, i17 12)
+ %s = lshr i17 %a0, %m
+ %r = select i1 %c, i17 %a2, i17 %s
+ ret i17 %r
+}
+
+define range(i64 0, -9223372036854775807) <4 x i64> @PR109888(<4 x i64> %0) {
+; CHECK-LABEL: @PR109888(
+; CHECK-NEXT: [[C:%.*]] = icmp ult <4 x i64> [[TMP0:%.*]], <i64 64, i64 64, i64 64, i64 64>
+; CHECK-NEXT: [[TMP2:%.*]] = shl nuw <4 x i64> <i64 1, i64 1, i64 1, i64 1>, [[TMP0]]
+; CHECK-NEXT: [[R:%.*]] = select <4 x i1> [[C]], <4 x i64> [[TMP2]], <4 x i64> zeroinitializer
+; CHECK-NEXT: ret <4 x i64> [[R]]
+;
+ %c = icmp ult <4 x i64> %0, <i64 64, i64 64, i64 64, i64 64>
+ %m = and <4 x i64> %0, <i64 63, i64 63, i64 63, i64 63>
+ %s = shl nuw <4 x i64> <i64 1, i64 1, i64 1, i64 1>, %m
+ %r = select <4 x i1> %c, <4 x i64> %s, <4 x i64> zeroinitializer
+ ret <4 x i64> %r
+}
Two questions about the alive2 proof:
- Please provide a generalized proof.
- There are some freeze instructions in the proof. Should we add an isGuaranteedNotToBeUndefOrPoison check?
BTW, I think we can generalize this fold with a recursive function like simplifyWithCond. We already handle equality predicates via simplifyWithOpReplaced. Or we could just reuse CondContext.
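For reference, the equality-predicate handling mentioned above covers cases like the following minimal sketch (the function name is hypothetical; today simplifyWithOpReplaced substitutes %n with 0 in the true arm, so the whole select should fold to %x):

define i32 @eq_example(i32 %x, i32 %n) {
  %c = icmp eq i32 %n, 0
  %a = add i32 %x, %n        ; under %c this is add %x, 0, i.e. just %x
  %r = select i1 %c, i32 %a, i32 %x
  ret i32 %r
}

The suggestion is to extend the same condition-driven simplification to range predicates such as the ugt/ult guards in this patch.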
I think that #97289 might cover this?
OK for me to push the test coverage to trunk?
We should call
Maybe it would be best to not actually base it on SimplifyDemandedBits, and just use a separate recursive simplification that only does basic known bits simplification like dropping bitwise ops. It seems like SimplifyDemandedBits can have too many undesirable effects, and as you say, may also not recurse everywhere we want for this purpose.
Apologies for missing some of the newer InstCombine methods - I haven't had to touch this code for a few years. SimplifyDemandedBits would limit us to pow2 cases, but I'm not sure how important non-pow2 really is? We could just use SimplifyMultipleUseDemandedBits - that currently works for the AND case at least.
…vm#109888 Add baseline tests for removing shift amount clamps which are also bounded by a select:
Fold (select (icmp_ugt A, BW-1), Y, (shift X, (and A, C))) --> (select (icmp_ugt A, BW-1), Y, (shift X, A))
Fold (select (icmp_ugt A, BW-1), Y, (shift X, (umin A, C))) --> (select (icmp_ugt A, BW-1), Y, (shift X, A))
If we are feeding a shift into a select conditioned by an inbounds check for the shift amount, then we can strip any mask/clamp limit that has been put on the shift amount.
Fold (select (icmp_ugt A, BW-1), Y, (shift X, (and A, C))) --> (select (icmp_ugt A, BW-1), Y, (shift X, A))
Fold (select (icmp_ugt A, BW-1), Y, (shift X, (umin A, C))) --> (select (icmp_ugt A, BW-1), Y, (shift X, A))
Alive2: https://alive2.llvm.org/ce/z/xC6FwD
Fixes llvm#109888
Force-pushed from 7b58920 to 79b1bc8.
If we are feeding a shift into a select conditioned by an inbounds check for the shift amount, then we can strip any mask/clamp limit that has been put on the shift amount.
Fold (select (icmp_ugt A, BW-1), Y, (shift X, (and A, C))) --> (select (icmp_ugt A, BW-1), Y, (shift X, A))
Fold (select (icmp_ugt A, BW-1), Y, (shift X, (umin A, C))) --> (select (icmp_ugt A, BW-1), Y, (shift X, A))
Alive2: https://alive2.llvm.org/ce/z/xC6FwD
Fixes #109888
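For the umin form, an equivalent minimal sketch (adapted from the select_ult_shl_clamp_umin_i32 test; src/tgt names are for exposition only):

declare i32 @llvm.umin.i32(i32, i32)

; Before: the amount is clamped to 31 with umin, which the guard makes
; redundant on the selected path (C = 31 >= BW-1).
define i32 @src(i32 %a0, i32 %a1, i32 %a2) {
  %c = icmp ult i32 %a1, 32
  %m = call i32 @llvm.umin.i32(i32 %a1, i32 31)
  %s = shl i32 %a0, %m
  %r = select i1 %c, i32 %s, i32 %a2
  ret i32 %r
}

; After: the clamp is removed and the shift uses %a1 directly.
define i32 @tgt(i32 %a0, i32 %a1, i32 %a2) {
  %c = icmp ult i32 %a1, 32
  %s = shl i32 %a0, %a1
  %r = select i1 %c, i32 %s, i32 %a2
  ret i32 %r
}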