-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[InstCombine] Missed optimization for select a%2==0, (a/2*2)*(a/2*2), 0 #92658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write If you have received no comments on your PR for a week, you can request a review If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-llvm-transforms Author: Jorge Botto (jf-botto) ChangesThis is my first PR contributing to LLVM. Fixes #71533 Full diff: https://github.com/llvm/llvm-project/pull/92658.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index a3ddb402bf662..e0cb7d951fb5f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1073,6 +1073,41 @@ static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
return nullptr;
}
+// (A % 2 == 0) ? (A/2*2) : B --> (A % 2 == 0) ? A : B
+// (A % 2 == 0) ? BinOp (A/2*2), (A/2*2) : B --> (A % 2 == 0) ? BinOp A, A : B
+static Value *foldSelectWithIcmpEqAndPattern(ICmpInst *Cmp, Value *TVal,
+ Value *FVal,
+ InstCombiner::BuilderTy &Builder) {
+ Value *A;
+ ConstantInt *MaskedConstant;
+ ICmpInst::Predicate Pred = Cmp->getPredicate();
+
+ // Checks if the comparison is (A % 2 == 0) and A is not undef.
+ if (!(Pred == ICmpInst::ICMP_EQ &&
+ match(Cmp->getOperand(0), m_And(m_Value(A), m_SpecificInt(1))) &&
+ match(Cmp->getOperand(1), m_SpecificInt(0)) &&
+ isGuaranteedNotToBeUndef(A)))
+ return nullptr;
+
+ // Checks if true branch matches (A % 2).
+ if (match(TVal,
+ m_OneUse(m_And(m_Specific(A), m_ConstantInt(MaskedConstant)))) &&
+ MaskedConstant->getValue().getSExtValue() == -2)
+ return Builder.CreateSelect(Cmp, A, FVal);
+
+ // Checks if true branch matches nested (A % 2) within a binary operation.
+ Value *MulVal;
+ if (match(TVal, m_OneUse(m_BinOp(m_Value(MulVal), m_Deferred(MulVal)))))
+ if (match(MulVal, m_And(m_Specific(A), m_ConstantInt(MaskedConstant))) &&
+ MaskedConstant->getValue().getSExtValue() == -2) {
+ Instruction::BinaryOps OpCode = cast<BinaryOperator>(TVal)->getOpcode();
+ Value *NewBinop = Builder.CreateBinOp(OpCode, A, A);
+ return Builder.CreateSelect(Cmp, NewBinop, FVal);
+ }
+
+ return nullptr;
+}
+
/// Fold the following code sequence:
/// \code
/// int a = ctlz(x & -x);
@@ -1933,6 +1968,10 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);
+ if (Value *V =
+ foldSelectWithIcmpEqAndPattern(ICI, TrueVal, FalseVal, Builder))
+ return replaceInstUsesWith(SI, V);
+
return Changed ? &SI : nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index 2ade6faa99be3..a913552dc6cff 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -1456,6 +1456,78 @@ define <2 x i32> @select_icmp_slt0_xor_vec(<2 x i32> %x) {
ret <2 x i32> %x.xor
}
+define i8 @select_icmp_eq_mul_and(i8 noundef %a, i8 %b) {
+; CHECK-LABEL: @select_icmp_eq_mul_and(
+; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[A:%.*]], 1
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT: [[TMP2:%.*]] = mul i8 [[A]], [[A]]
+; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[TMP2]], i8 [[B:%.*]]
+; CHECK-NEXT: ret i8 [[RETVAL_0]]
+;
+ %1 = and i8 %a, 1
+ %cmp = icmp eq i8 %1, 0
+ %div7 = and i8 %a, -2
+ %mul = mul i8 %div7, %div7
+ %retval.0 = select i1 %cmp, i8 %mul, i8 %b
+ ret i8 %retval.0
+}
+
+define i8 @select_icmp_eq_shl_and(i8 noundef %a, i8 %b) {
+; CHECK-LABEL: @select_icmp_eq_shl_and(
+; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[A:%.*]], 1
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT: [[TMP2:%.*]] = shl i8 [[A]], [[A]]
+; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[TMP2]], i8 [[B:%.*]]
+; CHECK-NEXT: ret i8 [[RETVAL_0]]
+;
+ %1 = and i8 %a, 1
+ %cmp = icmp eq i8 %1, 0
+ %div7 = and i8 %a, -2
+ %shl = shl i8 %div7, %div7
+ %retval.0 = select i1 %cmp, i8 %shl, i8 %b
+ ret i8 %retval.0
+}
+
+define i8 @select_icmp_eq_and(i8 noundef %a, i8 %b) {
+; CHECK-LABEL: @select_icmp_eq_and(
+; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[A:%.*]], 1
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[A]], i8 [[B:%.*]]
+; CHECK-NEXT: ret i8 [[RETVAL_0]]
+;
+ %1 = and i8 %a, 1
+ %cmp = icmp eq i8 %1, 0
+ %div7 = and i8 %a, -2
+ %retval.0 = select i1 %cmp, i8 %div7, i8 %b
+ ret i8 %retval.0
+}
+
+;negative test
+define i8 @select_and(i8 noundef %a, i8 %b, i1 %cmp) {
+; CHECK-LABEL: @select_and(
+; CHECK-NEXT: [[DIV7:%.*]] = and i8 [[A:%.*]], -2
+; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[CMP:%.*]], i8 [[DIV7]], i8 [[B:%.*]]
+; CHECK-NEXT: ret i8 [[RETVAL_0]]
+;
+ %div7 = and i8 %a, -2
+ %retval.0 = select i1 %cmp, i8 %div7, i8 %b
+ ret i8 %retval.0
+}
+
+;negative test
+define i8 @select_mul_and(i8 noundef %a, i8 %b, i1 %cmp) {
+; CHECK-LABEL: @select_mul_and(
+; CHECK-NEXT: [[DIV7:%.*]] = and i8 [[A:%.*]], -2
+; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[DIV7]], [[DIV7]]
+; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[CMP:%.*]], i8 [[MUL]], i8 [[B:%.*]]
+; CHECK-NEXT: ret i8 [[RETVAL_0]]
+;
+ %div7 = and i8 %a, -2
+ %mul = mul i8 %div7, %div7
+ %retval.0 = select i1 %cmp, i8 %mul, i8 %b
+ ret i8 %retval.0
+}
+
define <4 x i32> @canonicalize_to_shuffle(<4 x i32> %a, <4 x i32> %b) {
; CHECK-LABEL: @canonicalize_to_shuffle(
; CHECK-NEXT: [[SEL:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> <i32 0, i32 5, i32 6, i32 3>
|
5e577d9
to
70553da
Compare
|
||
// Checks if true branche matches the pattern 'A % 2'. | ||
if (match(TVal, | ||
m_OneUse(m_c_And(m_Value(A), m_ConstantInt(MaskedConstant)))) && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Think you want m_APInt
here (and make MaskedConstant
an const APInt *
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you. Will fix it in my next commit.
KnownBits Known; | ||
Known = IC.computeKnownBits(A, 0, &SI); | ||
IC.computeKnownBitsFromCond(A, Cmp, Known, 0, &SI, false); | ||
if (Known.Zero[0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Think it would make more sense to generalize this as (~MaskedConstant).isSubtsetOf(Known.Zero)
|
||
return nullptr; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, I think this would be better as a recursive function.
The base case being matching m_c_And(m_Value(A), m_APInt(Mask))
and then you can try to simplfy operands of binops/etc... you find along the way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your comments, I had thought about something similar but wasn't entirely sure about what approach to go with. Will work on this.
You can test this locally with the following command:git-clang-format --diff 59476c99983d3813b412c9b0c0464365644c23a8 70241c0ef98c24fe358c4f9d59614ee46feda4cf --extensions cpp,h -- llvm/include/llvm/Analysis/ValueTracking.h llvm/include/llvm/Transforms/InstCombine/InstCombiner.h llvm/lib/Analysis/ValueTracking.cpp llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp View the diff from clang-format here.diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 4ce1ed00bd..eaa8faaa2d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1082,7 +1082,7 @@ static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
/// if all bits that are zero in the negated constant
/// are also zero in A's known zero bits.
static Value *foldAndMaskPattern(Value *V, Value *Cmp, SelectInst &SI,
- InstCombinerImpl &IC, unsigned Depth = 0) {
+ InstCombinerImpl &IC, unsigned Depth = 0) {
Value *A;
const APInt *MaskedConstant;
@@ -4166,8 +4166,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
}
}
- // Attempts to recursively identify and fold (AND A constant) --> A
- // in the true branch of the select if all bits
+ // Attempts to recursively identify and fold (AND A constant) --> A
+ // in the true branch of the select if all bits
// that are zero in the negated constant are also zero in A's known zero bits.
if (TrueVal->hasOneUse())
if (Value *newTrueOp = foldAndMaskPattern(TrueVal, CondVal, SI, *this))
|
This is my first PR contributing to LLVM.
I'm more than happy to take any advice on board so that I can improve my future contributions.
Fixes #71533
Proof: https://alive2.llvm.org/ce/z/yDXwdM