[InstCombine] Missed optimization for select a%2==0, (a/2*2)*(a/2*2), 0 #92658

Closed
jf-botto wants to merge 2 commits

@jf-botto jf-botto commented May 18, 2024

This is my first PR contributing to LLVM.
I'm more than happy to take any advice on board so that I can improve my future contributions.

Fixes #71533
Proof: https://alive2.llvm.org/ce/z/yDXwdM
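
The fold rests on one observation: when the low bit of `a` is clear, masking with -2 (the canonical form of `a/2*2`) leaves `a` unchanged, so the true arm of the select can use `a` directly. As an informal sanity check, separate from the Alive2 proof, a brute-force sketch over all 8-bit values might look like this. `foldHoldsForAllI8` is a hypothetical helper name, not part of the patch, and the sketch does not model undef, which is why the patch additionally requires `isGuaranteedNotToBeUndef(A)`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of the patch): exhaustively checks, for every
// i8 value, that the select's true arm is unchanged by the fold.
bool foldHoldsForAllI8() {
  for (int v = 0; v < 256; ++v) {
    uint8_t a = static_cast<uint8_t>(v);
    if ((a & 1) != 0)
      continue; // Condition is false: the select yields the other arm.
    uint8_t masked = static_cast<uint8_t>(a & 0xFE); // and i8 %a, -2
    if (masked != a)
      return false;
    // mul i8 wraps modulo 256, so compare the truncated products.
    uint8_t before = static_cast<uint8_t>(masked * masked);
    uint8_t after = static_cast<uint8_t>(a * a);
    if (before != after)
      return false;
  }
  return true;
}
```

Calling `foldHoldsForAllI8()` should return `true`, since every even value is a fixed point of the `-2` mask.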

@jf-botto jf-botto requested a review from nikic as a code owner May 18, 2024 16:44
Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

llvmbot commented May 18, 2024

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: Jorge Botto (jf-botto)

Changes

This is my first PR contributing to LLVM.
I'm more than happy to take any advice on board so that I can improve my future contributions.

Fixes #71533
Proof: https://alive2.llvm.org/ce/z/PCkT2L


Full diff: https://github.com/llvm/llvm-project/pull/92658.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+39)
  • (modified) llvm/test/Transforms/InstCombine/select.ll (+72)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index a3ddb402bf662..e0cb7d951fb5f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1073,6 +1073,41 @@ static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
   return nullptr;
 }
 
+// (A % 2 == 0) ? (A/2*2) : B --> (A % 2 == 0) ? A : B
+// (A % 2 == 0) ? BinOp (A/2*2), (A/2*2) : B --> (A % 2 == 0) ? BinOp A, A : B
+static Value *foldSelectWithIcmpEqAndPattern(ICmpInst *Cmp, Value *TVal,
+                                             Value *FVal,
+                                             InstCombiner::BuilderTy &Builder) {
+  Value *A;
+  ConstantInt *MaskedConstant;
+  ICmpInst::Predicate Pred = Cmp->getPredicate();
+
+  // Checks if the comparison is (A % 2 == 0) and A is not undef.
+  if (!(Pred == ICmpInst::ICMP_EQ &&
+        match(Cmp->getOperand(0), m_And(m_Value(A), m_SpecificInt(1))) &&
+        match(Cmp->getOperand(1), m_SpecificInt(0)) &&
+        isGuaranteedNotToBeUndef(A)))
+    return nullptr;
+
+  // Checks if true branch matches (A % 2).
+  if (match(TVal,
+            m_OneUse(m_And(m_Specific(A), m_ConstantInt(MaskedConstant)))) &&
+      MaskedConstant->getValue().getSExtValue() == -2)
+    return Builder.CreateSelect(Cmp, A, FVal);
+
+  // Checks if true branch matches nested (A % 2) within a binary operation.
+  Value *MulVal;
+  if (match(TVal, m_OneUse(m_BinOp(m_Value(MulVal), m_Deferred(MulVal)))))
+    if (match(MulVal, m_And(m_Specific(A), m_ConstantInt(MaskedConstant))) &&
+        MaskedConstant->getValue().getSExtValue() == -2) {
+      Instruction::BinaryOps OpCode = cast<BinaryOperator>(TVal)->getOpcode();
+      Value *NewBinop = Builder.CreateBinOp(OpCode, A, A);
+      return Builder.CreateSelect(Cmp, NewBinop, FVal);
+    }
+
+  return nullptr;
+}
+
 /// Fold the following code sequence:
 /// \code
 ///   int a = ctlz(x & -x);
@@ -1933,6 +1968,10 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
   if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder))
     return replaceInstUsesWith(SI, V);
 
+  if (Value *V =
+          foldSelectWithIcmpEqAndPattern(ICI, TrueVal, FalseVal, Builder))
+    return replaceInstUsesWith(SI, V);
+
   return Changed ? &SI : nullptr;
 }
 
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index 2ade6faa99be3..a913552dc6cff 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -1456,6 +1456,78 @@ define <2 x i32> @select_icmp_slt0_xor_vec(<2 x i32> %x) {
   ret <2 x i32> %x.xor
 }
 
+define i8 @select_icmp_eq_mul_and(i8 noundef %a, i8 %b)  {
+; CHECK-LABEL: @select_icmp_eq_mul_and(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A:%.*]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i8 [[A]], [[A]]
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[TMP2]], i8 [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[RETVAL_0]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div7 = and i8 %a, -2
+  %mul = mul i8 %div7, %div7
+  %retval.0 = select i1 %cmp, i8 %mul, i8 %b
+  ret i8 %retval.0
+}
+
+define i8 @select_icmp_eq_shl_and(i8 noundef %a, i8 %b)  {
+; CHECK-LABEL: @select_icmp_eq_shl_and(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A:%.*]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i8 [[A]], [[A]]
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[TMP2]], i8 [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[RETVAL_0]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div7 = and i8 %a, -2
+  %shl = shl i8 %div7, %div7
+  %retval.0 = select i1 %cmp, i8 %shl, i8 %b
+  ret i8 %retval.0
+}
+
+define i8 @select_icmp_eq_and(i8 noundef %a, i8 %b)  {
+; CHECK-LABEL: @select_icmp_eq_and(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A:%.*]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[A]], i8 [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[RETVAL_0]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div7 = and i8 %a, -2
+  %retval.0 = select i1 %cmp, i8 %div7, i8 %b
+  ret i8 %retval.0
+}
+
+;negative test
+define i8 @select_and(i8 noundef %a, i8 %b, i1 %cmp)  {
+; CHECK-LABEL: @select_and(
+; CHECK-NEXT:    [[DIV7:%.*]] = and i8 [[A:%.*]], -2
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP:%.*]], i8 [[DIV7]], i8 [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[RETVAL_0]]
+;
+  %div7 = and i8 %a, -2
+  %retval.0 = select i1 %cmp, i8 %div7, i8 %b
+  ret i8 %retval.0
+}
+
+;negative test
+define i8 @select_mul_and(i8 noundef %a, i8 %b, i1 %cmp)  {
+; CHECK-LABEL: @select_mul_and(
+; CHECK-NEXT:    [[DIV7:%.*]] = and i8 [[A:%.*]], -2
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[DIV7]], [[DIV7]]
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP:%.*]], i8 [[MUL]], i8 [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[RETVAL_0]]
+;
+  %div7 = and i8 %a, -2
+  %mul = mul i8 %div7, %div7
+  %retval.0 = select i1 %cmp, i8 %mul, i8 %b
+  ret i8 %retval.0
+}
+
 define <4 x i32> @canonicalize_to_shuffle(<4 x i32> %a, <4 x i32> %b) {
 ; CHECK-LABEL: @canonicalize_to_shuffle(
 ; CHECK-NEXT:    [[SEL:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> <i32 0, i32 5, i32 6, i32 3>


// Checks if true branch matches the pattern 'A % 2'.
if (match(TVal,
          m_OneUse(m_c_And(m_Value(A), m_ConstantInt(MaskedConstant)))) &&
Contributor

Think you want m_APInt here (and make MaskedConstant a const APInt *).

Contributor Author

Thank you. Will fix it in my next commit.

KnownBits Known;
Known = IC.computeKnownBits(A, 0, &SI);
IC.computeKnownBitsFromCond(A, Cmp, Known, 0, &SI, false);
if (Known.Zero[0])
Contributor

Think it would make more sense to generalize this as (~MaskedConstant).isSubsetOf(Known.Zero)
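
Read at the bit level, this condition says that every bit the mask would clear is already known to be zero in A, so the `and` is a no-op. A minimal sketch on plain 8-bit integers, assuming `knownZero` plays the role of `Known.Zero`; the helper name `maskIsRedundant` is hypothetical and only illustrates the check:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for (~Mask).isSubsetOf(Known.Zero) on 8-bit values:
// true when every bit cleared by `mask` is already known to be zero.
bool maskIsRedundant(uint8_t mask, uint8_t knownZero) {
  uint8_t cleared = static_cast<uint8_t>(~mask); // bits the `and` would clear
  return (cleared & static_cast<uint8_t>(~knownZero)) == 0;
}
```

With `knownZero = 0x01` (A is known to be even), the mask `0xFE` (-2) is redundant, while `0xFC` (-4) is not, because bit 1 of A is still unknown.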


return nullptr;
}

Contributor

In general, I think this would be better as a recursive function.

The base case would match m_c_And(m_Value(A), m_APInt(Mask)), and then you can try to simplify the operands of any binops etc. you find along the way.
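
The shape of such a recursion can be sketched on a toy expression tree. Everything here (`Expr`, `makeVar`, `makeAnd`, `makeBinOp`, `simplifyMasks`, the depth limit) is a hypothetical stand-in for illustration, not LLVM API, and it ignores details a real InstCombine fold has to handle, such as use counts.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>

// Toy expression tree, a hypothetical stand-in for llvm::Value.
struct Expr {
  enum Kind { Var, AndMask, BinOp };
  Kind kind = Var;
  uint8_t mask = 0;               // used by AndMask
  char op = 0;                    // used by BinOp, e.g. '*' or '+'
  std::shared_ptr<Expr> lhs, rhs; // AndMask uses lhs only
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr makeVar() { return std::make_shared<Expr>(); }

ExprPtr makeAnd(ExprPtr a, uint8_t mask) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::AndMask;
  e->mask = mask;
  e->lhs = a;
  return e;
}

ExprPtr makeBinOp(char op, ExprPtr l, ExprPtr r) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::BinOp;
  e->op = op;
  e->lhs = l;
  e->rhs = r;
  return e;
}

// Returns a simplified tree, or nullptr if nothing changed.
ExprPtr simplifyMasks(const ExprPtr &e, const ExprPtr &a, uint8_t knownZero,
                      unsigned depth = 0) {
  constexpr unsigned MaxDepth = 6; // arbitrary recursion limit for the sketch
  if (depth > MaxDepth)
    return nullptr;
  // Base case: (A & Mask) where Mask only clears bits known to be zero.
  if (e->kind == Expr::AndMask && e->lhs == a &&
      (static_cast<uint8_t>(~e->mask) & static_cast<uint8_t>(~knownZero)) == 0)
    return a;
  // Recursive case: try to simplify each operand of a binary operator.
  if (e->kind == Expr::BinOp) {
    ExprPtr l = simplifyMasks(e->lhs, a, knownZero, depth + 1);
    ExprPtr r = simplifyMasks(e->rhs, a, knownZero, depth + 1);
    if (!l && !r)
      return nullptr;
    return makeBinOp(e->op, l ? l : e->lhs, r ? r : e->rhs);
  }
  return nullptr;
}
```

For example, simplifying `(a & 0xFE) * (a & 0xFE)` with `knownZero = 0x01` yields a new `*` node whose operands are both `a`, while with `knownZero = 0x00` the function returns `nullptr` because the mask is not provably redundant.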

Contributor Author

Thank you for your comments. I had thought about something similar but wasn't sure which approach to take. I will work on this.

github-actions bot commented Aug 3, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 59476c99983d3813b412c9b0c0464365644c23a8 70241c0ef98c24fe358c4f9d59614ee46feda4cf --extensions cpp,h -- llvm/include/llvm/Analysis/ValueTracking.h llvm/include/llvm/Transforms/InstCombine/InstCombiner.h llvm/lib/Analysis/ValueTracking.cpp llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 4ce1ed00bd..eaa8faaa2d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1082,7 +1082,7 @@ static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
 /// if all bits that are zero in the negated constant
 /// are also zero in A's known zero bits.
 static Value *foldAndMaskPattern(Value *V, Value *Cmp, SelectInst &SI,
-                                     InstCombinerImpl &IC, unsigned Depth = 0) {
+                                 InstCombinerImpl &IC, unsigned Depth = 0) {
 
   Value *A;
   const APInt *MaskedConstant;
@@ -4166,8 +4166,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     }
   }
 
-  // Attempts to recursively identify and fold (AND A constant) --> A 
-  // in the true branch of the select if all bits 
+  // Attempts to recursively identify and fold (AND A constant) --> A
+  // in the true branch of the select if all bits
   // that are zero in the negated constant are also zero in A's known zero bits.
   if (TrueVal->hasOneUse())
     if (Value *newTrueOp = foldAndMaskPattern(TrueVal, CondVal, SI, *this))
