Skip to content

[InstCombine] Fold max(max(x, c1) << c2, c3) —> max(x << c2, c3) when c3 >= c1 * 2 ^ c2 #140526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
83 changes: 83 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,84 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1))
: BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
}

//Try canonicalize min/max(x << shamt, c<<shamt) into max(x, c) << shamt
static Instruction *moveShiftAfterMinMax(IntrinsicInst *II, InstCombiner::BuilderTy &Builder) {
Intrinsic::ID MinMaxID = II->getIntrinsicID();
assert((MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin ||
MinMaxID == Intrinsic::umax || MinMaxID == Intrinsic::umin) &&
"Expected a min or max intrinsic");

Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1);
Value *InnerMax;
const APInt *C;
if (!match(Op0, m_OneUse(m_BinOp(m_Value(InnerMax), m_APInt(C)))) ||
!match(Op1, m_APInt(C)))
return nullptr;

auto* BinOpInst = cast<BinaryOperator>(Op0);
Instruction::BinaryOps BinOp = BinOpInst->getOpcode();
Value *X;
InnerMax = BinOpInst->getOperand(0);
// std::cout<< InnerMax->dump() <<std::endl;
if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::umax>(m_Value(X),m_APInt(C))))){
if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::smax>(m_Value(X),m_APInt(C))))){
if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::umin>(m_Value(X),m_APInt(C))))){
if(!match(InnerMax,m_OneUse(m_Intrinsic<Intrinsic::smin>(m_Value(X),m_APInt(C))))){
return nullptr;
}}}}

auto *InnerMaxInst = cast<IntrinsicInst>(InnerMax);

bool IsSigned = MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin;
if((IsSigned && !BinOpInst->hasNoSignedWrap()) ||
(!IsSigned && !BinOpInst->hasNoUnsignedWrap()))
return nullptr;

// Check if BinOp is a left shift
if (BinOp != Instruction::Shl) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, you are trying to implement solution 2 suggested by me: #139786 (comment)

If it is the case, you should generalize it to handle most of binops (excluding div/rem), then use simplifyBinOp and simplifyBinaryIntrinsic to check if min/max(c1 binop c2, c3) folds to c3.

return nullptr;
}

APInt C2=llvm::dyn_cast<llvm::ConstantInt>(BinOpInst->getOperand(1))->getValue() ;
APInt C3=llvm::dyn_cast<llvm::ConstantInt>(II->getArgOperand(1))->getValue();
APInt C1=llvm::dyn_cast<llvm::ConstantInt>(InnerMaxInst->getOperand(1))->getValue();

// Compute C1 * 2^C2
APInt Two = APInt(C2.getBitWidth(), 2);
APInt Pow2C2 = Two.shl(C2); // 2^C2
APInt C1TimesPow2C2 = C1 * Pow2C2; // C1 * 2^C2

// Check C3 >= C1 * 2^C2
if (C3.ult(C1TimesPow2C2)) {
return nullptr;
}

//Create new x binop c2
Value *NewBinOp = Builder.CreateBinOp(BinOp, InnerMaxInst->getOperand(0), BinOpInst->getOperand(1) );

//Create new min/max intrinsic with new binop and c3

if(IsSigned){
cast<Instruction>(NewBinOp) -> setHasNoSignedWrap(true);
cast<Instruction>(NewBinOp) -> setHasNoUnsignedWrap(false);
}else{
cast<Instruction>(NewBinOp) -> setHasNoUnsignedWrap(true);
cast<Instruction>(NewBinOp) -> setHasNoSignedWrap(false);
}


// Get the intrinsic function for MinMaxID
Type *Ty = II->getType();
Function *MinMaxFn = Intrinsic::getDeclaration(II->getModule(), MinMaxID, {Ty});

// Create new min/max intrinsic: MinMaxID(NewBinOp, C3) (not inserted)
Value *Args[] = {NewBinOp, Op1};
Instruction *NewMax = CallInst::Create(MinMaxFn, Args, "", nullptr);

return NewMax;
}

/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
Type *Ty = MinMax1.getType();
Expand Down Expand Up @@ -2035,6 +2113,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (Instruction *I = moveAddAfterMinMax(II, Builder))
return I;

// minmax(x << shamt , c << shamt) -> minmax(x, c) << shamt
if (Instruction *I = moveShiftAfterMinMax(II, Builder))
return I;


// minmax (X & NegPow2C, Y & NegPow2C) --> minmax(X, Y) & NegPow2C
const APInt *RHSC;
if (match(I0, m_OneUse(m_And(m_Value(X), m_NegatedPower2(RHSC)))) &&
Expand Down
27 changes: 27 additions & 0 deletions llvm/test/Transforms/InstCombine/shift-binop.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add more negative tests/vector tests/multi-use tests, as suggested by https://llvm.org/docs/InstCombineContributorGuide.html#tests?

; RUN: opt < %s -passes=instcombine -S | FileCheck %s

define i8 @src(i8 %arg0) {
; CHECK-LABEL: @src(
; CHECK-NEXT: [[TMP1:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.umax.i8(i8 [[TMP1]], i8 16)
; CHECK-NEXT: ret i8 [[TMP2]]
;
%1 = call i8 @llvm.umax.i8(i8 %arg0, i8 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use named values.

%2 = shl nuw i8 %1, 1
%3 = call i8 @llvm.umax.i8(i8 %2, i8 16)
ret i8 %3
}

define i8 @tgt(i8 %arg0) {
; CHECK-LABEL: @tgt(
; CHECK-NEXT: [[TMP1:%.*]] = shl nuw i8 [[ARG0:%.*]], 1
; CHECK-NEXT: [[TMP2:%.*]] = call i8 @llvm.umax.i8(i8 [[TMP1]], i8 16)
; CHECK-NEXT: ret i8 [[TMP2]]
;
%1 = shl nuw i8 %arg0, 1
%2 = call i8 @llvm.umax.i8(i8 %1, i8 16)
ret i8 %2
}

declare i8 @llvm.umax.i8(i8, i8)
Loading