
[InstCombine] Fold select of clamped shifts #114797


Draft: wants to merge 4 commits into main from the instcombine-select-clamped-shifts branch

Conversation

RKSimon (Collaborator) commented Nov 4, 2024

If we are feeding a shift into a select conditioned by an in-bounds check on the shift amount, then we can strip any mask/clamp limit that has been applied to the shift amount:

Fold (select (icmp_ult A, BW), (shift X, (and A, C)), Y) --> (select (icmp_ult A, BW), (shift X, A), Y)
Fold (select (icmp_ugt A, BW-1), Y, (shift X, (umin A, C))) --> (select (icmp_ugt A, BW-1), Y, (shift X, A))

Alive2: https://alive2.llvm.org/ce/z/xC6FwD

Fixes #109888
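
For illustration, here is the first pattern in IR, in the same shape as one of the tests added below (select_ult_shl_clamp_and_i32); the value names are taken from that test:

  ; before: the shift amount is masked even though the condition already
  ; guarantees it is in bounds on the arm that uses the shift
  %c = icmp ult i32 %a1, 32
  %m = and i32 %a1, 31
  %s = shl i32 %a0, %m
  %r = select i1 %c, i32 %s, i32 %a2

  ; after: the mask on the shift amount is dropped
  %c = icmp ult i32 %a1, 32
  %s = shl i32 %a0, %a1
  %r = select i1 %c, i32 %s, i32 %a2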

llvmbot (Member) commented Nov 4, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Simon Pilgrim (RKSimon)

Changes

If we are feeding a shift into a select conditioned by an in-bounds check on the shift amount, then we can strip any mask/clamp limit that has been applied to the shift amount:

Fold (select (icmp_ult A, BW), (shift X, (and A, C)), Y) --> (select (icmp_ult A, BW), (shift X, A), Y)
Fold (select (icmp_ugt A, BW-1), Y, (shift X, (umin A, C))) --> (select (icmp_ugt A, BW-1), Y, (shift X, A))

Alive2: https://alive2.llvm.org/ce/z/xC6FwD

Fixes #109888


Full diff: https://github.com/llvm/llvm-project/pull/114797.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+64)
  • (added) llvm/test/Transforms/InstCombine/select-shift-clamp.ll (+227)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 999ad1adff20b8..826a9ec8f0eb98 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2761,6 +2761,67 @@ static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC,
   return nullptr;
 }
 
+static Instruction *foldSelectWithClampedShift(SelectInst &SI,
+                                               InstCombinerImpl &IC,
+                                               IRBuilderBase &Builder) {
+  Value *CondVal = SI.getCondition();
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+  Type *SelType = SI.getType();
+  uint64_t BW = SelType->getScalarSizeInBits();
+
+  auto MatchClampedShift = [&](Value *V, Value *Amt) -> BinaryOperator * {
+    Value *X, *Limit;
+
+    // Fold (select (icmp_ugt A, BW-1), TrueVal, (shift X, (umin A, C)))
+    //  --> (select (icmp_ugt A, BW-1), TrueVal, (shift X, A))
+    // Fold (select (icmp_ult A, BW), (shift X, (umin A, C)), FalseVal)
+    //  --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
+    // iff C >= BW-1
+    if (match(V, m_OneUse(m_Shift(m_Value(X),
+                                  m_UMin(m_Specific(Amt), m_Value(Limit)))))) {
+      KnownBits KnownLimit = IC.computeKnownBits(Limit, 0, &SI);
+      if (KnownLimit.getMinValue().uge(BW - 1))
+        return cast<BinaryOperator>(V);
+    }
+
+    // Fold (select (icmp_ugt A, BW-1), TrueVal, (shift X, (and A, C)))
+    //  --> (select (icmp_ugt A, BW-1), TrueVal, (shift X, A))
+    // Fold (select (icmp_ult A, BW), (shift X, (and A, C)), FalseVal)
+    //  --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
+    // iff Pow2 element width and C masks all amt bits.
+    if (isPowerOf2_64(BW) &&
+        match(V, m_OneUse(m_Shift(m_Value(X),
+                                  m_And(m_Specific(Amt), m_Value(Limit)))))) {
+      KnownBits KnownLimit = IC.computeKnownBits(Limit, 0, &SI);
+      if (KnownLimit.countMinTrailingOnes() >= Log2_64(BW))
+        return cast<BinaryOperator>(V);
+    }
+
+    return nullptr;
+  };
+
+  Value *Amt;
+  if (match(CondVal, m_SpecificICmp(ICmpInst::ICMP_UGT, m_Value(Amt),
+                                    m_SpecificInt(BW - 1)))) {
+    if (BinaryOperator *ShiftI = MatchClampedShift(FalseVal, Amt))
+      return SelectInst::Create(
+          CondVal, TrueVal,
+          Builder.CreateBinOp(ShiftI->getOpcode(), ShiftI->getOperand(0), Amt));
+  }
+
+  if (match(CondVal, m_SpecificICmp(ICmpInst::ICMP_ULT, m_Value(Amt),
+                                    m_SpecificInt(BW)))) {
+    if (BinaryOperator *ShiftI = MatchClampedShift(TrueVal, Amt))
+      return SelectInst::Create(
+          CondVal,
+          Builder.CreateBinOp(ShiftI->getOpcode(), ShiftI->getOperand(0), Amt),
+          FalseVal);
+  }
+
+  return nullptr;
+}
+
 static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) {
   FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition());
   if (!FI)
@@ -3817,6 +3878,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   if (Instruction *I = foldSelectExtConst(SI))
     return I;
 
+  if (Instruction *I = foldSelectWithClampedShift(SI, *this, Builder))
+    return I;
+
   if (Instruction *I = foldSelectWithSRem(SI, *this, Builder))
     return I;
 
diff --git a/llvm/test/Transforms/InstCombine/select-shift-clamp.ll b/llvm/test/Transforms/InstCombine/select-shift-clamp.ll
new file mode 100644
index 00000000000000..9be6e71b67a351
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/select-shift-clamp.ll
@@ -0,0 +1,227 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+declare void @use_i17(i17)
+declare void @use_i32(i32)
+
+; Fold (select (icmp_ugt A, BW-1), TrueVal, (shift X, (and A, C)))
+;  --> (select (icmp_ugt A, BW-1), TrueVal, (shift X, A))
+; Fold (select (icmp_ult A, BW), (shift X, (and A, C)), FalseVal)
+;  --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
+; iff Pow2 element width and C masks all amt bits.
+
+define i32 @select_ult_shl_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ult_shl_clamp_and_i32(
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[TMP1]], i32 [[A2:%.*]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ult i32 %a1, 32
+  %m = and i32 %a1, 31
+  %s = shl i32 %a0, %m
+  %r = select i1 %c, i32 %s, i32 %a2
+  ret i32 %r
+}
+
+define i32 @select_ule_ashr_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ule_ashr_clamp_and_i32(
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[TMP1]], i32 [[A2:%.*]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ule i32 %a1, 31
+  %m = and i32 %a1, 127
+  %s = ashr i32 %a0, %m
+  %r = select i1 %c, i32 %s, i32 %a2
+  ret i32 %r
+}
+
+define i32 @select_ugt_lshr_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ugt_lshr_clamp_and_i32(
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[A1:%.*]], 31
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[TMP1]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ugt i32 %a1, 31
+  %m = and i32 %a1, 31
+  %s = lshr i32 %a0, %m
+  %r = select i1 %c, i32 %a2, i32 %s
+  ret i32 %r
+}
+
+define i32 @select_uge_shl_clamp_and_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_uge_shl_clamp_and_i32(
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[A1:%.*]], 31
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[TMP1]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp uge i32 %a1, 32
+  %m = and i32 %a1, 63
+  %s = shl i32 %a0, %m
+  %r = select i1 %c, i32 %a2, i32 %s
+  ret i32 %r
+}
+
+; negative test - multiuse
+define i32 @select_ule_ashr_clamp_and_i32_multiuse(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ule_ashr_clamp_and_i32_multiuse(
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
+; CHECK-NEXT:    [[M:%.*]] = and i32 [[A1]], 127
+; CHECK-NEXT:    [[S:%.*]] = ashr i32 [[A0:%.*]], [[M]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[S]], i32 [[A2:%.*]]
+; CHECK-NEXT:    call void @use_i32(i32 [[S]])
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ule i32 %a1, 31
+  %m = and i32 %a1, 127
+  %s = ashr i32 %a0, %m
+  %r = select i1 %c, i32 %s, i32 %a2
+  call void @use_i32(i32 %s)
+  ret i32 %r
+}
+
+; negative test - mask doesn't cover all legal amount bits
+define i32 @select_ult_shl_clamp_and_i32_badmask(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ult_shl_clamp_and_i32_badmask(
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
+; CHECK-NEXT:    [[M:%.*]] = and i32 [[A1]], 28
+; CHECK-NEXT:    [[S:%.*]] = shl i32 [[A0:%.*]], [[M]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[S]], i32 [[A2:%.*]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ult i32 %a1, 32
+  %m = and i32 %a1, 28
+  %s = shl i32 %a0, %m
+  %r = select i1 %c, i32 %s, i32 %a2
+  ret i32 %r
+}
+
+; negative test - non-pow2
+define i17 @select_uge_lshr_clamp_and_i17_nonpow2(i17 %a0, i17 %a1, i17 %a2) {
+; CHECK-LABEL: @select_uge_lshr_clamp_and_i17_nonpow2(
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i17 [[A1:%.*]], 16
+; CHECK-NEXT:    [[M:%.*]] = and i17 [[A1]], 255
+; CHECK-NEXT:    [[S:%.*]] = lshr i17 [[A0:%.*]], [[M]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i17 [[A2:%.*]], i17 [[S]]
+; CHECK-NEXT:    ret i17 [[R]]
+;
+  %c = icmp uge i17 %a1, 17
+  %m = and i17 %a1, 255
+  %s = lshr i17 %a0, %m
+  %r = select i1 %c, i17 %a2, i17 %s
+  ret i17 %r
+}
+
+; Fold (select (icmp_ugt A, BW-1), TrueVal, (shift X, (umin A, C)))
+;  --> (select (icmp_ugt A, BW-1), TrueVal, (shift X, A))
+; Fold (select (icmp_ult A, BW), (shift X, (umin A, C)), FalseVal)
+;  --> (select (icmp_ult A, BW), (shift X, A), FalseVal)
+; iff C >= BW-1
+
+define i32 @select_ult_shl_clamp_umin_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ult_shl_clamp_umin_i32(
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[A1:%.*]], 32
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[TMP1]], i32 [[A2:%.*]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ult i32 %a1, 32
+  %m = call i32 @llvm.umin.i32(i32 %a1, i32 31)
+  %s = shl i32 %a0, %m
+  %r = select i1 %c, i32 %s, i32 %a2
+  ret i32 %r
+}
+
+define i17 @select_ule_ashr_clamp_umin_i17(i17 %a0, i17 %a1, i17 %a2) {
+; CHECK-LABEL: @select_ule_ashr_clamp_umin_i17(
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i17 [[A1:%.*]], 17
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr i17 [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i17 [[TMP1]], i17 [[A2:%.*]]
+; CHECK-NEXT:    ret i17 [[R]]
+;
+  %c = icmp ule i17 %a1, 16
+  %m = call i17 @llvm.umin.i17(i17 %a1, i17 17)
+  %s = ashr i17 %a0, %m
+  %r = select i1 %c, i17 %s, i17 %a2
+  ret i17 %r
+}
+
+define i32 @select_ugt_shl_clamp_umin_i32(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ugt_shl_clamp_umin_i32(
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[A1:%.*]], 31
+; CHECK-NEXT:    [[S:%.*]] = shl i32 [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[S]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ugt i32 %a1, 31
+  %m = call i32 @llvm.umin.i32(i32 %a1, i32 128)
+  %s = shl i32 %a0, %m
+  %r = select i1 %c, i32 %a2, i32 %s
+  ret i32 %r
+}
+
+define <2 x i32> @select_uge_lshr_clamp_umin_v2i32(<2 x i32> %a0, <2 x i32> %a1, <2 x i32> %a2) {
+; CHECK-LABEL: @select_uge_lshr_clamp_umin_v2i32(
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt <2 x i32> [[A1:%.*]], <i32 31, i32 31>
+; CHECK-NEXT:    [[S:%.*]] = lshr <2 x i32> [[A0:%.*]], [[A1]]
+; CHECK-NEXT:    [[R:%.*]] = select <2 x i1> [[C]], <2 x i32> [[A2:%.*]], <2 x i32> [[S]]
+; CHECK-NEXT:    ret <2 x i32> [[R]]
+;
+  %c = icmp uge <2 x i32> %a1, <i32 32, i32 32>
+  %m = call <2 x i32> @llvm.umin.v2i32(<2 x i32> %a1, <2 x i32> <i32 63, i32 31>)
+  %s = lshr <2 x i32> %a0, %m
+  %r = select <2 x i1> %c, <2 x i32> %a2, <2 x i32> %s
+  ret <2 x i32> %r
+}
+
+; negative test - multiuse
+define i32 @select_ugt_shl_clamp_umin_i32_multiuse(i32 %a0, i32 %a1, i32 %a2) {
+; CHECK-LABEL: @select_ugt_shl_clamp_umin_i32_multiuse(
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[A1:%.*]], 32
+; CHECK-NEXT:    [[M:%.*]] = call i32 @llvm.umin.i32(i32 [[A1]], i32 128)
+; CHECK-NEXT:    [[S:%.*]] = shl i32 [[A0:%.*]], [[M]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 [[A2:%.*]], i32 [[S]]
+; CHECK-NEXT:    call void @use_i32(i32 [[S]])
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ugt i32 %a1, 32
+  %m = call i32 @llvm.umin.i32(i32 %a1, i32 128)
+  %s = shl i32 %a0, %m
+  %r = select i1 %c, i32 %a2, i32 %s
+  call void @use_i32(i32 %s)
+  ret i32 %r
+}
+
+; negative test - umin limit doesn't cover all legal amounts
+define i17 @select_uge_lshr_clamp_umin_i17_badlimit(i17 %a0, i17 %a1, i17 %a2) {
+; CHECK-LABEL: @select_uge_lshr_clamp_umin_i17_badlimit(
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i17 [[A1:%.*]], 15
+; CHECK-NEXT:    [[M:%.*]] = call i17 @llvm.umin.i17(i17 [[A1]], i17 12)
+; CHECK-NEXT:    [[S:%.*]] = lshr i17 [[A0:%.*]], [[M]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i17 [[A2:%.*]], i17 [[S]]
+; CHECK-NEXT:    ret i17 [[R]]
+;
+  %c = icmp uge i17 %a1, 16
+  %m = call i17 @llvm.umin.i17(i17 %a1, i17 12)
+  %s = lshr i17 %a0, %m
+  %r = select i1 %c, i17 %a2, i17 %s
+  ret i17 %r
+}
+
+define range(i64 0, -9223372036854775807) <4 x i64> @PR109888(<4 x i64> %0) {
+; CHECK-LABEL: @PR109888(
+; CHECK-NEXT:    [[C:%.*]] = icmp ult <4 x i64> [[TMP0:%.*]], <i64 64, i64 64, i64 64, i64 64>
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw <4 x i64> <i64 1, i64 1, i64 1, i64 1>, [[TMP0]]
+; CHECK-NEXT:    [[R:%.*]] = select <4 x i1> [[C]], <4 x i64> [[TMP2]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    ret <4 x i64> [[R]]
+;
+  %c = icmp ult <4 x i64> %0, <i64 64, i64 64, i64 64, i64 64>
+  %m = and <4 x i64> %0, <i64 63, i64 63, i64 63, i64 63>
+  %s = shl nuw <4 x i64> <i64 1, i64 1, i64 1, i64 1>, %m
+  %r = select <4 x i1> %c, <4 x i64> %s, <4 x i64> zeroinitializer
+  ret <4 x i64> %r
+}

dtcxzyw (Member) left a comment

Two questions about the Alive2 proof:

  1. Please provide a generalized proof.
  2. There are some freeze instructions in the proof. Should we add an isGuaranteedNotToBeUndefOrPoison check?

BTW, I think we can generalize this fold with a recursive function like simplifyWithCond. We already handle equality predicates via simplifyWithOpReplaced.
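
For reference, a minimal sketch of the guard being asked about (not part of this patch, and whether it is actually required depends on the generalized proof); it would sit in foldSelectWithClampedShift after Amt has been matched from the condition:

    // Sketch only: conservatively require the raw shift amount to be neither
    // undef nor poison before the clamped amount is replaced with it.
    if (!isGuaranteedNotToBeUndefOrPoison(Amt, &IC.getAssumptionCache(), &SI,
                                          &IC.getDominatorTree()))
      return nullptr;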

dtcxzyw (Member) commented Nov 4, 2024

> BTW, I think we can generalize this fold with a recursive function like simplifyWithCond.

Or just reusing CondContext.

RKSimon marked this pull request as draft on November 4, 2024 at 15:17.
nikic (Contributor) commented Nov 4, 2024

I think that #97289 might cover this?

RKSimon (Collaborator, Author) commented Nov 4, 2024

> I think that #97289 might cover this?

OK for me to push the test coverage to trunk?

dtcxzyw (Member) commented Nov 4, 2024

> I think that #97289 might cover this?

We should call SimplifyDemandedBits on the RHS of shift operators.

nikic (Contributor) commented Nov 4, 2024

> I think that #97289 might cover this?
>
> We should call SimplifyDemandedBits on the RHS of shift operators.

Maybe it would be best to not actually base it on SimplifyDemandedBits, and just use a separate recursive simplification that only does basic known bits simplification like dropping bitwise ops. It seems like SimplifyDemandedBits can have too many undesirable effects, and as you say, may also not recurse everywhere we want for this purpose.

RKSimon (Collaborator, Author) commented Nov 4, 2024

Apologies for missing some of the newer InstCombine methods - I haven't had to touch this code for a few years.

SimplifyDemandedBits would limit us to the pow2 cases, but I'm not sure how important the non-pow2 cases really are.

We could just use SimplifyMultipleUseDemandedBits - that currently works for the AND case at least.

…vm#109888

Add baseline tests for removing shift amount clamps which are also bound by a select:

Fold (select (icmp_ult A, BW), (shift X, (and A, C)), Y) --> (select (icmp_ult A, BW), (shift X, A), Y)
Fold (select (icmp_ugt A, BW-1), Y, (shift X, (umin A, C))) --> (select (icmp_ugt A, BW-1), Y, (shift X, A))

If we are feeding a shift into a select conditioned by an in-bounds check on the shift amount, then we can strip any mask/clamp limit that has been applied to the shift amount:

Fold (select (icmp_ult A, BW), (shift X, (and A, C)), Y) --> (select (icmp_ult A, BW), (shift X, A), Y)
Fold (select (icmp_ugt A, BW-1), Y, (shift X, (umin A, C))) --> (select (icmp_ugt A, BW-1), Y, (shift X, A))

Alive2: https://alive2.llvm.org/ce/z/xC6FwD

Fixes llvm#109888
RKSimon force-pushed the instcombine-select-clamped-shifts branch from 7b58920 to 79b1bc8 on November 14, 2024 at 14:56.
Successfully merging this pull request may close the following issue:

  • [AVX2] vpsllvq builtin-semantics are not recognized by LLVM vectors

4 participants