Skip to content

[InstCombine] Fold max(max(x, c1) << c2, c3) —> max(x << c2, c3) when c3 >= c1 * 2 ^ c2 #140526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
162 changes: 162 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1174,6 +1174,163 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1))
: BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
}


/// Returns true when this helper considers the binary opcode \p ROp to
/// distribute over the min/max intrinsic \p LOp, i.e. when
///   (X LOp Y) ROp Z  ->  (X ROp Z) LOp (Y ROp Z)
/// is treated as a valid rewrite.
///
/// \param ROp     opcode of the outer binary operator.
/// \param HasNUW  whether that operator carries the no-unsigned-wrap flag.
/// \param HasNSW  whether that operator carries the no-signed-wrap flag.
/// \param LOp     intrinsic ID of the inner operation; only umax, umin, smax
///                and smin can yield true.
///
/// Accepted combinations: umax/umin with NUW and Add, Shl or Mul; smax/smin
/// with NSW and Add or Mul. Everything else returns false.
///
/// NOTE(review): the review thread embedded in this listing asks for an alive2
/// proof of the identity and reports a counterexample for the signed case, so
/// the smax/smin answers below should be treated as unverified.
static bool rightDistributesOverLeft(Instruction::BinaryOps ROp, bool HasNUW,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add some header comments for this function.

Copy link
Member

@dtcxzyw dtcxzyw May 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you provide alive2 proof for this function ((X LOp Y) ROp Z -> (X ROp Z) LOp (Y ROp Z))?
Reference: #140526 (comment)

bool HasNSW, Intrinsic::ID LOp) {
switch (LOp) {
case Intrinsic::umax:
case Intrinsic::umin:
// Unsigned min/max distribute over addition and left shift if no unsigned
// wrap.
if (HasNUW && (ROp == Instruction::Add || ROp == Instruction::Shl))
return true;
// Multiplication preserves order for unsigned min/max with no unsigned
// wrap.
if (HasNUW && ROp == Instruction::Mul)
return true;
return false;
case Intrinsic::smax:
case Intrinsic::smin:
// Signed min/max distribute over addition if no signed wrap.
if (HasNSW && ROp == Instruction::Add)
return true;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't hold for smin/smax: https://alive2.llvm.org/ce/z/XFf_U8

// Multiplication preserves order for signed min/max with no signed wrap.
// NOTE(review): without signed wrap, multiplying by a negative value reverses
// signed order, and the multiplier is not visible to this helper -- confirm
// that callers reject negative multipliers before relying on this answer.
if (HasNSW && ROp == Instruction::Mul)
return true;
return false;
default:
// Any intrinsic other than the four integer min/max kinds is rejected.
return false;
}
}

/// Try to canonicalize
///   max(max(X, C1) binop C2, C3)  ->  max(X binop C2, C4)
/// where C4 = max(C1 binop C2, C3). Derivation:
///   max(max(X, C1) binop C2, C3)
///   -> max(max(X binop C2, C1 binop C2), C3)   // binop distributes over max
///   -> max(X binop C2, max(C1 binop C2, C3))   // max is associative
///   -> max(X binop C2, C4)                     // constant fold
/// The same shape is handled for umax, umin, smax and smin, provided the inner
/// and outer intrinsic IDs are identical.
///
/// \param II       the outer min/max intrinsic call.
/// \param Builder  used to create the new `X binop C2` instruction.
/// \returns a newly created min/max call (built without an insertion point),
///          or nullptr when the pattern does not match.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

static Instruction *reduceMinMax(IntrinsicInst *II,
InstCombiner::BuilderTy &Builder) {
Intrinsic::ID MinMaxID = II->getIntrinsicID();
assert(isa<MinMaxIntrinsic>(II) && "Expected a min or max intrinsic");

Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1);
Value *InnerMax;
const APInt *C;
// Require Op0 to be a single-use binary operator whose right operand matches
// m_APInt, and Op1 to match m_APInt as well.
// NOTE(review): both patterns bind the same C, and C is never read afterwards;
// InnerMax is only reassigned below and never read either. The constants are
// re-fetched from the operands further down.
if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_APInt(C)))) ||
!match(Op1, m_APInt(C)))
Comment on lines +1220 to +1222
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const APInt *C;
if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_APInt(C)))) ||
!match(Op1, m_APInt(C)))
Constant *C2, *C3;
if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_ImmConstant(C2)))) ||
!match(Op1, m_ImmConstant(C3)))

return nullptr;

auto *BinOpInst = cast<BinaryOperator>(Op0);
Instruction::BinaryOps BinOp = BinOpInst->getOpcode();

InnerMax = BinOpInst->getOperand(0);

// The left operand of the binary operator must itself be a single-use
// min/max intrinsic.
auto *InnerMinMaxInst = dyn_cast<MinMaxIntrinsic>(BinOpInst->getOperand(0));
if (!InnerMinMaxInst || !InnerMinMaxInst->hasOneUse())
return nullptr;

// Signedness is taken from the inner intrinsic; the inner and outer intrinsic
// must be the same kind of min/max.
bool IsSigned = InnerMinMaxInst->isSigned();
if (InnerMinMaxInst->getIntrinsicID() != MinMaxID)
return nullptr;

// The binary operator must carry the wrap flag that matches the signedness:
// nsw for signed min/max, nuw for unsigned min/max.
if ((IsSigned && !BinOpInst->hasNoSignedWrap()) ||
(!IsSigned && !BinOpInst->hasNoUnsignedWrap()))
return nullptr;

// Bail out unless the helper accepts this opcode / intrinsic combination.
if (!rightDistributesOverLeft(BinOp, BinOpInst->hasNoUnsignedWrap(),
BinOpInst->hasNoSignedWrap(),
InnerMinMaxInst->getIntrinsicID()))
return nullptr;

// Get constant values
// NOTE(review): each dyn_cast result is dereferenced without a null check, and
// operand 1 of the inner intrinsic has not been checked to be a constant
// anywhere above. TODO confirm how operands that are not ConstantInt (for
// example vector constants) are meant to be handled.
APInt C1 = llvm::dyn_cast<llvm::ConstantInt>(InnerMinMaxInst->getOperand(1))
->getValue();
APInt C2 =
llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue();
APInt C3 =
llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue();

// Constant fold: Compute C1 binop C2
// NOTE(review): the overflow flag is written by the overflow-reporting APInt
// calls below but is never read, so a wrapped C1 binop C2 is not rejected.
APInt C1BinOpC2, Two, Pow2C2, C1TimesPow2C2;
bool overflow = false;
switch (BinOp) {
case Instruction::Add:
C1BinOpC2 = IsSigned ? C1.sadd_ov(C2, overflow) : C1.uadd_ov(C2, overflow);
break;
case Instruction::Mul:
C1BinOpC2 = IsSigned ? C1.smul_ov(C2, overflow) : C1.umul_ov(C2, overflow);
break;
// NOTE(review): rightDistributesOverLeft as shown in this listing never
// accepts Sub, so this case looks unreachable.
case Instruction::Sub:
C1BinOpC2 = IsSigned ? C1.ssub_ov(C2, overflow) : C1.usub_ov(C2, overflow);
break;
case Instruction::Shl:
// Compute C1 * 2^C2
// NOTE(review): Two.shl(C2) evaluates 2 << C2, which is 2^(C2 + 1) rather
// than the 2^C2 named in these comments -- confirm the intended bound.
Two = APInt(C2.getBitWidth(), 2);
Pow2C2 = Two.shl(C2); // 2^C2
C1TimesPow2C2 = C1 * Pow2C2; // C1 * 2^C2

// Check C3 >= C1 * 2^C2
// NOTE(review): the comparison is unsigned, is applied to min and max alike,
// and the multiplication above is not checked for wrap -- confirm.
if (C3.ult(C1TimesPow2C2)) {
return nullptr;
} else {
C1BinOpC2 = C1.shl(C2);
}
break;
default:
return nullptr; // Unsupported binary operation
}

// Constant fold: Compute MinMaxID(C1 binop C2, C3) to get C4
APInt C4;
switch (MinMaxID) {
case Intrinsic::umax:
C4 = APIntOps::umax(C1BinOpC2, C3);
break;
case Intrinsic::umin:
C4 = APIntOps::umin(C1BinOpC2, C3);
break;
case Intrinsic::smax:
C4 = APIntOps::smax(C1BinOpC2, C3);
break;
case Intrinsic::smin:
C4 = APIntOps::smin(C1BinOpC2, C3);
break;
default:
return nullptr; // Unsupported intrinsic
}
Comment on lines +1247 to +1302
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Get constant values
APInt C1 = llvm::dyn_cast<llvm::ConstantInt>(InnerMinMaxInst->getOperand(1))
->getValue();
APInt C2 =
llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue();
APInt C3 =
llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue();
// Constant fold: Compute C1 binop C2
APInt C1BinOpC2, Two, Pow2C2, C1TimesPow2C2;
bool overflow = false;
switch (BinOp) {
case Instruction::Add:
C1BinOpC2 = IsSigned ? C1.sadd_ov(C2, overflow) : C1.uadd_ov(C2, overflow);
break;
case Instruction::Mul:
C1BinOpC2 = IsSigned ? C1.smul_ov(C2, overflow) : C1.umul_ov(C2, overflow);
break;
case Instruction::Sub:
C1BinOpC2 = IsSigned ? C1.ssub_ov(C2, overflow) : C1.usub_ov(C2, overflow);
break;
case Instruction::Shl:
// Compute C1 * 2^C2
Two = APInt(C2.getBitWidth(), 2);
Pow2C2 = Two.shl(C2); // 2^C2
C1TimesPow2C2 = C1 * Pow2C2; // C1 * 2^C2
// Check C3 >= C1 * 2^C2
if (C3.ult(C1TimesPow2C2)) {
return nullptr;
} else {
C1BinOpC2 = C1.shl(C2);
}
break;
default:
return nullptr; // Unsupported binary operation
}
// Constant fold: Compute MinMaxID(C1 binop C2, C3) to get C4
APInt C4;
switch (MinMaxID) {
case Intrinsic::umax:
C4 = APIntOps::umax(C1BinOpC2, C3);
break;
case Intrinsic::umin:
C4 = APIntOps::umin(C1BinOpC2, C3);
break;
case Intrinsic::smax:
C4 = APIntOps::smax(C1BinOpC2, C3);
break;
case Intrinsic::smin:
C4 = APIntOps::smin(C1BinOpC2, C3);
break;
default:
return nullptr; // Unsupported intrinsic
}
Constant *C1;
if (!match(InnerMinMaxInst->getRHS(), m_ImmConstant(C1))
return nullptr;
Constant *C1BinOpC2 = ConstantFoldBinaryOpOperands(BinOp, C1, C2, DL);
Constant *C4 = ConstantFoldBinaryIntrinsic(MinMaxID, C1BinOpC2, C3, C3->getType(), nullptr);


// Create new X binop C2
Value *NewBinOp = Builder.CreateBinOp(BinOp, InnerMinMaxInst->getOperand(0),
BinOpInst->getOperand(1));

// Set overflow flags on new binary operation
// Only the flag matching the signedness is kept; the other one is cleared.
// NOTE(review): the flag now applies to X binop C2 instead of
// minmax(X, C1) binop C2 -- confirm that carrying it over is sound, in
// particular for min, where X may lie outside the clamped range.
if (auto *NewBinInst = dyn_cast<Instruction>(NewBinOp)) {
if (IsSigned) {
NewBinInst->setHasNoSignedWrap(true);
NewBinInst->setHasNoUnsignedWrap(false);
} else {
NewBinInst->setHasNoUnsignedWrap(true);
NewBinInst->setHasNoSignedWrap(false);
}
}

// Create constant for C4
Value *C4Val = ConstantInt::get(II->getType(), C4);

// Get the intrinsic function for MinMaxID
Type *Ty = II->getType();
Function *MinMaxFn =
Intrinsic::getDeclaration(II->getModule(), MinMaxID, {Ty});

// Create new min/max intrinsic: MinMaxID(NewBinOp, C4)
// The call is created with a null insertion point and handed back to the
// caller rather than being inserted here.
Value *Args[] = {NewBinOp, C4Val};
Instruction *NewMax = CallInst::Create(MinMaxFn, Args, "", nullptr);

return NewMax;
}

/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
Type *Ty = MinMax1.getType();
Expand Down Expand Up @@ -2038,6 +2195,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (Instruction *I = moveAddAfterMinMax(II, Builder))
return I;

// max(max(X,C1) binop C2, C3) -> max(X binop C2, max(C1 binop C2, C3)) -> max(X binop C2, C4)
if (Instruction *I = reduceMinMax(II, Builder))
return I;


// minmax (X & NegPow2C, Y & NegPow2C) --> minmax(X, Y) & NegPow2C
const APInt *RHSC;
if (match(I0, m_OneUse(m_And(m_Value(X), m_NegatedPower2(RHSC)))) &&
Expand Down
27 changes: 27 additions & 0 deletions llvm/test/Transforms/InstCombine/shift-binop.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add more negative tests/vector tests/multi-use tests, as suggested by https://llvm.org/docs/InstCombineContributorGuide.html#tests?

; RUN: opt < %s -passes=instcombine -S | FileCheck %s

; Positive test: umin(umin(%arg0, 2) << 2 nuw, 16) is expected to become
; umin(%arg0 << 2 nuw, 8), i.e. the inner clamp is folded into the outer
; constant.
; NOTE(review): in the expected output the shl nuw acts on the unclamped
; argument instead of the clamped value; confirm (for example with alive2)
; that this is a valid refinement for large inputs.
define i32 @src1(i32 %arg0) {
; CHECK-LABEL: @src1(
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 [[ARG0:%.*]], 2
; CHECK-NEXT: [[OUTMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[SHL]], i32 8)
; CHECK-NEXT: ret i32 [[OUTMIN]]
;
%1 = call i32 @llvm.umin.i32(i32 %arg0, i32 2)
%2 = shl nuw i32 %1, 2
%3 = call i32 @llvm.umin.i32(i32 %2, i32 16)
ret i32 %3
}

; Negative test: smax(smax(%arg0, 2) << 18 nsw, 10) uses the signed intrinsic
; with a left shift, and the expected output keeps all three instructions
; unchanged.
define i32 @src2(i32 %arg0) {
; CHECK-LABEL: @src2(
; CHECK-NEXT: [[INMAX:%.*]] = call i32 @llvm.smax.i32(i32 [[ARG0:%.*]], i32 2)
; CHECK-NEXT: [[SHL:%.*]] = shl nsw i32 [[INMAX]], 18
; CHECK-NEXT: [[OUTMAX:%.*]] = call i32 @llvm.smax.i32(i32 [[SHL]], i32 10)
; CHECK-NEXT: ret i32 [[OUTMAX]]
;
%1 = call i32 @llvm.smax.i32(i32 %arg0, i32 2)
%2 = shl nsw i32 %1, 18
%3 = call i32 @llvm.smax.i32(i32 %2, i32 10)
ret i32 %3
}
Loading