-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[InstCombine] Fold max(max(x, c1) << c2, c3) —> max(x << c2, c3) when c3 >= c1 * 2 ^ c2 #140526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…ue Missed Optimization: max(max(x, c1) << c2, c3) —> max(x << c2, c3) when c3 >= c1 * 2 ^ c2 llvm#139786
…ax(x << c2, c3) when c3 >= c1 * 2 ^ c2 This patch fixes issue llvm#139786, where InstCombine missed the optimization max(max(x, c1) << c2, c3) —> max(x << c2, c3) when c3 >= c1 * 2 ^ c2. Pre-committed test in <commit-hash>. Alive2: https://alive2.llvm.org/ce/z/on2tJE
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using `@` followed by their GitHub username. If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-llvm-transforms Author: None (Charukesh827) Changes: As suggested, the fold has been generalized to max(max(x, c1) binop c2, c3) —> max(x binop c2, c3) if c3 >= c1 * 2 ^ c2. define i8 @src(i8 %arg0) { define i8 @tgt(i8 %arg0) { Full diff: https://github.com/llvm/llvm-project/pull/140526.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 3d35bf753c40e..53dd5f803f97b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1171,6 +1171,84 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1))
: BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
}
+
+//Try canonicalize min/max(x << shamt, c<<shamt) into max(x, c) << shamt
+static Instruction *moveShiftAfterMinMax(IntrinsicInst *II, InstCombiner::BuilderTy &Builder) {
+ Intrinsic::ID MinMaxID = II->getIntrinsicID();
+ assert((MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin ||
+ MinMaxID == Intrinsic::umax || MinMaxID == Intrinsic::umin) &&
+ "Expected a min or max intrinsic");
+
+ Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1);
+ Value *InnerMax;
+ const APInt *C;
+ if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_APInt(C)))) ||
+ !match(Op1, m_APInt(C)))
+ return nullptr;
+
+ auto* BinOpInst = cast<BinaryOperator>(Op0);
+ Instruction::BinaryOps BinOp = BinOpInst->getOpcode();
+ Value *X;
+ InnerMax = BinOpInst->getOperand(0);
+ // std::cout<< InnerMax->dump() <<std::endl;
+ if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::umax>(m_Value(X),m_APInt(C))))){
+ if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::smax>(m_Value(X),m_APInt(C))))){
+ if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::umin>(m_Value(X),m_APInt(C))))){
+ if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::smin>(m_Value(X),m_APInt(C))))){
+ return nullptr;
+ }}}}
+
+ auto *InnerMaxInst = cast<IntrinsicInst>(InnerMax);
+
+ bool IsSigned = MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin;
+ if((IsSigned && !BinOpInst->hasNoSignedWrap()) ||
+ (!IsSigned && !BinOpInst->hasNoUnsignedWrap()))
+ return nullptr;
+
+ // Check if BinOp is a left shift
+ if (BinOp != Instruction::Shl) {
+ return nullptr;
+ }
+
+ APInt C2=llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue() ;
+ APInt C3=llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue();
+ APInt C1=llvm::dyn_cast<llvm::ConstantInt>(InnerMaxInst->getOperand(1))->getValue();
+
+ // Compute C1 * 2^C2
+ APInt Two = APInt(C2.getBitWidth(), 2);
+ APInt Pow2C2 = Two.shl(C2); // 2^C2
+ APInt C1TimesPow2C2 = C1 * Pow2C2; // C1 * 2^C2
+
+ // Check C3 >= C1 * 2^C2
+ if (C3.ult(C1TimesPow2C2)) {
+ return nullptr;
+ }
+
+ //Create new x binop c2
+ Value *NewBinOp = Builder.CreateBinOp(BinOp, InnerMaxInst->getOperand(0), BinOpInst->getOperand(1) );
+
+ //Create new min/max intrinsic with new binop and c3
+
+ if(IsSigned){
+ cast<Instruction>(NewBinOp) -> setHasNoSignedWrap(true);
+ cast<Instruction>(NewBinOp) -> setHasNoUnsignedWrap(false);
+ }else{
+ cast<Instruction>(NewBinOp) -> setHasNoUnsignedWrap(true);
+ cast<Instruction>(NewBinOp) -> setHasNoSignedWrap(false);
+ }
+
+
+ // Get the intrinsic function for MinMaxID
+ Type *Ty = II->getType();
+ Function *MinMaxFn = Intrinsic::getDeclaration(II->getModule(), MinMaxID, {Ty});
+
+ // Create new min/max intrinsic: MinMaxID(NewBinOp, C3) (not inserted)
+ Value *Args[] = {NewBinOp, Op1};
+ Instruction *NewMax = CallInst::Create(MinMaxFn, Args, "", nullptr);
+
+ return NewMax;
+}
+
/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
Type *Ty = MinMax1.getType();
@@ -2035,6 +2113,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (Instruction *I = moveAddAfterMinMax(II, Builder))
return I;
+ // minmax(x << shamt , c << shamt) -> minmax(x, c) << shamt
+ if (Instruction *I = moveShiftAfterMinMax(II, Builder))
+ return I;
+
+
// minmax (X & NegPow2C, Y & NegPow2C) --> minmax(X, Y) & NegPow2C
const APInt *RHSC;
if (match(I0, m_OneUse(m_And(m_Value(X), m_NegatedPower2(RHSC)))) &&
diff --git a/llvm/test/Transforms/InstCombine/shift-binop.ll b/llvm/test/Transforms/InstCombine/shift-binop.ll
new file mode 100644
index 0000000000000..78e9c5ea21181
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/shift-binop.ll
@@ -0,0 +1,27 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i8 @src(i8 %arg0) {
+; CHECK-LABEL: @src(
+; CHECK-NEXT: [[TMP1:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
+; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.umax.i8(i8 [[TMP1]], i8 16)
+; CHECK-NEXT: ret i8 [[TMP2]]
+;
+ %1 = call i8 @llvm.umax.i8(i8 %arg0, i8 1)
+ %2 = shl nuw i8 %1, 1
+ %3 = call i8 @llvm.umax.i8(i8 %2, i8 16)
+ ret i8 %3
+}
+
+define i8 @tgt(i8 %arg0) {
+; CHECK-LABEL: @tgt(
+; CHECK-NEXT: [[TMP1:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
+; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.umax.i8(i8 [[TMP1]], i8 16)
+; CHECK-NEXT: ret i8 [[TMP2]]
+;
+ %1 = shl nuw i8 %arg0, 1
+ %2 = call i8 @llvm.umax.i8(i8 %1, i8 16)
+ ret i8 %2
+}
+
+declare i8 @llvm.umax.i8(i8, i8)
|
You can test this locally with the following command: git-clang-format --diff HEAD~1 HEAD --extensions cpp -- llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp View the diff from clang-format here. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 04c26bad7..5827d46f3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1175,7 +1175,6 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
: BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
}
-
static bool rightDistributesOverLeft(Instruction::BinaryOps ROp, bool HasNUW,
bool HasNSW, Intrinsic::ID LOp) {
switch (LOp) {
@@ -2195,11 +2194,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (Instruction *I = moveAddAfterMinMax(II, Builder))
return I;
- // max(max(X,C1) binop C2, C3) -> max(X binop C2, max(C1 binop C2, C3)) -> max(X binop C2, C4)
- if (Instruction *I = reduceMinMax(II, Builder))
+ // max(max(X,C1) binop C2, C3) -> max(X binop C2, max(C1 binop C2, C3)) ->
+ // max(X binop C2, C4)
+ if (Instruction *I = reduceMinMax(II, Builder))
return I;
-
// minmax (X & NegPow2C, Y & NegPow2C) --> minmax(X, Y) & NegPow2C
const APInt *RHSC;
if (match(I0, m_OneUse(m_And(m_Value(X), m_NegatedPower2(RHSC)))) &&
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add the alive2 proof into the PR description.
return nullptr; | ||
|
||
// Check if BinOp is a left shift | ||
if (BinOp != Instruction::Shl) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC, you are trying to implement solution 2 suggested by me: #139786 (comment)
If that is the case, you should generalize it to handle most binops (excluding div/rem), then use `simplifyBinOp` and `simplifyBinaryIntrinsic` to check if `min/max(c1 binop c2, c3)` folds to `c3`.
@@ -0,0 +1,27 @@ | |||
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please add more negative tests/vector tests/multi-use tests, as suggested by https://llvm.org/docs/InstCombineContributorGuide.html#tests?
; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.umax.i8(i8 [[TMP1]], i8 16) | ||
; CHECK-NEXT: ret i8 [[TMP2]] | ||
; | ||
%1 = call i8 @llvm.umax.i8(i8 %arg0, i8 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use named values.
"If it is the case, you should generalize it to handle most of binops (excluding div/rem), then use simplifyBinOp and simplifyBinaryIntrinsic to check if min/max(c1 binop c2, c3) folds to c3."
Made most of the suggested changes. The only thing I didn't understand is what has to be done to generalize div; please help me with that. About the suggestions you gave: "Solution1: Canonicalize max(x << shamt, c << shamt) into max(x, c) << shamt: https://alive2.llvm.org/ce/z/mQEDAQ Solution2: Generalize to fold max(max(x, c1) binop c2, c3) —> max(x binop c2, c3)." Solution 1 was already available, so I didn't implement it. I only concentrated on solution 2. |
@@ -1174,6 +1174,86 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II, | |||
return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1)) | |||
: BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1)); | |||
} | |||
|
|||
// Try canonicalize max(max(X,C1) binop C2, C3) -> max(X binop C2, C3) | |||
static Instruction *moveShiftAfterMinMax(IntrinsicInst *II, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function name should be updated.
In fact, this fold can be decomposed into two steps:
max(max(X,C1) binop C2, C3) -> // Associative laws
max(max(X binop C2, C1 binop C2), C3) -> // Commutative laws
max(X binop C2, max(C1 binop C2, C3)) -> // Constant fold
max(X binop C2, C4)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`max(X, C1) binop C2 -> max(X binop C2, C1 binop C2)` is not always safe for all binops. You can reuse the helper `leftDistributesOverRight`.
Yeah. I think it is a better solution. |
1)max(X, C1) binop C2 -> max(X binop C2, C1 binop C2) is not always safe for all binops. You can reuse the helper leftDistributesOverRight. 2)The function name should be updated. In fact, this fold can be decomposed into two steps: max(max(X,C1) binop C2, C3) -> // Associative laws max(max(X binop C2, C1 binop C2), C3) -> // Commutative laws max(X binop C2, max(C1 binop C2, C3)) -> // Constant fold max(X binop C2, C4)
I wrote the rightDistributesOverLeft function with the knowledge I have of this optimization; please let me know if anything should be added or removed. I considered only the Add, Sub, Mul, and Shl operations, as I was able to verify only these. Another optimization (which I didn't write) dominates this optimization. |
@@ -1174,6 +1174,163 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II, | |||
return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1)) | |||
: BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1)); | |||
} | |||
|
|||
|
|||
static bool rightDistributesOverLeft(Instruction::BinaryOps ROp, bool HasNUW, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add some header comments for this function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you provide an alive2 proof for this function (`(X LOp Y) ROp Z -> (X ROp Z) LOp (Y ROp Z)`)?
Reference: #140526 (comment)
/// Associative laws max(max(X binop C2, C1 binop C2), C3) -> // Commutative | ||
/// laws max(X binop C2, max(C1 binop C2, C3)) -> // Constant fold max(X binop | ||
/// C2, C4) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
case Intrinsic::smin: | ||
// Signed min/max distribute over addition if no signed wrap. | ||
if (HasNSW && ROp == Instruction::Add) | ||
return true; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It doesn't hold for smin/smax: https://alive2.llvm.org/ce/z/XFf_U8
const APInt *C; | ||
if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_APInt(C)))) || | ||
!match(Op1, m_APInt(C))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const APInt *C; | |
if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_APInt(C)))) || | |
!match(Op1, m_APInt(C))) | |
Constant *C2, *C3; | |
if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_ImmConstant(C2)))) || | |
!match(Op1, m_ImmConstant(C3))) |
// Get constant values | ||
APInt C1 = llvm::dyn_cast<llvm::ConstantInt>(InnerMinMaxInst->getOperand(1)) | ||
->getValue(); | ||
APInt C2 = | ||
llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue(); | ||
APInt C3 = | ||
llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue(); | ||
|
||
// Constant fold: Compute C1 binop C2 | ||
APInt C1BinOpC2, Two, Pow2C2, C1TimesPow2C2; | ||
bool overflow = false; | ||
switch (BinOp) { | ||
case Instruction::Add: | ||
C1BinOpC2 = IsSigned ? C1.sadd_ov(C2, overflow) : C1.uadd_ov(C2, overflow); | ||
break; | ||
case Instruction::Mul: | ||
C1BinOpC2 = IsSigned ? C1.smul_ov(C2, overflow) : C1.umul_ov(C2, overflow); | ||
break; | ||
case Instruction::Sub: | ||
C1BinOpC2 = IsSigned ? C1.ssub_ov(C2, overflow) : C1.usub_ov(C2, overflow); | ||
break; | ||
case Instruction::Shl: | ||
// Compute C1 * 2^C2 | ||
Two = APInt(C2.getBitWidth(), 2); | ||
Pow2C2 = Two.shl(C2); // 2^C2 | ||
C1TimesPow2C2 = C1 * Pow2C2; // C1 * 2^C2 | ||
|
||
// Check C3 >= C1 * 2^C2 | ||
if (C3.ult(C1TimesPow2C2)) { | ||
return nullptr; | ||
} else { | ||
C1BinOpC2 = C1.shl(C2); | ||
} | ||
break; | ||
default: | ||
return nullptr; // Unsupported binary operation | ||
} | ||
|
||
// Constant fold: Compute MinMaxID(C1 binop C2, C3) to get C4 | ||
APInt C4; | ||
switch (MinMaxID) { | ||
case Intrinsic::umax: | ||
C4 = APIntOps::umax(C1BinOpC2, C3); | ||
break; | ||
case Intrinsic::umin: | ||
C4 = APIntOps::umin(C1BinOpC2, C3); | ||
break; | ||
case Intrinsic::smax: | ||
C4 = APIntOps::smax(C1BinOpC2, C3); | ||
break; | ||
case Intrinsic::smin: | ||
C4 = APIntOps::smin(C1BinOpC2, C3); | ||
break; | ||
default: | ||
return nullptr; // Unsupported intrinsic | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// Get constant values | |
APInt C1 = llvm::dyn_cast<llvm::ConstantInt>(InnerMinMaxInst->getOperand(1)) | |
->getValue(); | |
APInt C2 = | |
llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue(); | |
APInt C3 = | |
llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue(); | |
// Constant fold: Compute C1 binop C2 | |
APInt C1BinOpC2, Two, Pow2C2, C1TimesPow2C2; | |
bool overflow = false; | |
switch (BinOp) { | |
case Instruction::Add: | |
C1BinOpC2 = IsSigned ? C1.sadd_ov(C2, overflow) : C1.uadd_ov(C2, overflow); | |
break; | |
case Instruction::Mul: | |
C1BinOpC2 = IsSigned ? C1.smul_ov(C2, overflow) : C1.umul_ov(C2, overflow); | |
break; | |
case Instruction::Sub: | |
C1BinOpC2 = IsSigned ? C1.ssub_ov(C2, overflow) : C1.usub_ov(C2, overflow); | |
break; | |
case Instruction::Shl: | |
// Compute C1 * 2^C2 | |
Two = APInt(C2.getBitWidth(), 2); | |
Pow2C2 = Two.shl(C2); // 2^C2 | |
C1TimesPow2C2 = C1 * Pow2C2; // C1 * 2^C2 | |
// Check C3 >= C1 * 2^C2 | |
if (C3.ult(C1TimesPow2C2)) { | |
return nullptr; | |
} else { | |
C1BinOpC2 = C1.shl(C2); | |
} | |
break; | |
default: | |
return nullptr; // Unsupported binary operation | |
} | |
// Constant fold: Compute MinMaxID(C1 binop C2, C3) to get C4 | |
APInt C4; | |
switch (MinMaxID) { | |
case Intrinsic::umax: | |
C4 = APIntOps::umax(C1BinOpC2, C3); | |
break; | |
case Intrinsic::umin: | |
C4 = APIntOps::umin(C1BinOpC2, C3); | |
break; | |
case Intrinsic::smax: | |
C4 = APIntOps::smax(C1BinOpC2, C3); | |
break; | |
case Intrinsic::smin: | |
C4 = APIntOps::smin(C1BinOpC2, C3); | |
break; | |
default: | |
return nullptr; // Unsupported intrinsic | |
} | |
Constant *C1; | |
if (!match(InnerMinMaxInst->getRHS(), m_ImmConstant(C1)) | |
return nullptr; | |
Constant *C1BinOpC2 = ConstantFoldBinaryOpOperands(BinOp, C1, C2, DL); | |
Constant *C4 = ConstantFoldBinaryIntrinsic(MinMaxID, C1BinOpC2, C3, C3->getType(), nullptr); |
As suggested, the fold has been generalized to max(max(x, c1) binop c2, c3) —> max(x binop c2, c3) if c3 >= c1 * 2 ^ c2.
define i8 @src(i8 %arg0) {
%1 = call i8 @llvm.umax.i8(i8 %arg0, i8 1)
%2 = shl nuw i8 %1, 1
%3 = call i8 @llvm.umax.i8(i8 %2, i8 16)
ret i8 %3
}
define i8 @tgt(i8 %arg0) {
%1 = shl nuw i8 %arg0, 1
%2 = call i8 @llvm.umax.i8(i8 %1, i8 16)
ret i8 %2
}
Closes #139786.