Skip to content

Commit ef2d6da

Browse files
authored
[InstCombine] Transform (fcmp + fadd + sel) into (fcmp + sel + fadd) (#106492)
Transform `fcmp + fadd + sel` into `fcmp + sel + fadd` which enables the possibility of transforming `fcmp + sel` into `maxnum/minnum` intrinsics. Alive2 results: https://alive2.llvm.org/ce/z/2cmimW https://alive2.llvm.org/ce/z/Qh9ZJt https://alive2.llvm.org/ce/z/vtLj3R
1 parent b816c26 commit ef2d6da

File tree

3 files changed

+745
-0
lines changed

3 files changed

+745
-0
lines changed

llvm/include/llvm/IR/FMF.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,20 @@ class FastMathFlags {
108108

109109
/// Print fast-math flags to \p O.
110110
void print(raw_ostream &O) const;
111+
112+
/// Intersect rewrite-based flags
113+
static inline FastMathFlags intersectRewrite(FastMathFlags LHS,
114+
FastMathFlags RHS) {
115+
const unsigned RewriteMask =
116+
AllowReassoc | AllowReciprocal | AllowContract | ApproxFunc;
117+
return FastMathFlags(RewriteMask & LHS.Flags & RHS.Flags);
118+
}
119+
120+
/// Union value flags
121+
static inline FastMathFlags unionValue(FastMathFlags LHS, FastMathFlags RHS) {
122+
const unsigned ValueMask = NoNaNs | NoInfs | NoSignedZeros;
123+
return FastMathFlags(ValueMask & (LHS.Flags | RHS.Flags));
124+
}
111125
};
112126

113127
inline FastMathFlags operator|(FastMathFlags LHS, FastMathFlags RHS) {

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3645,6 +3645,60 @@ static bool hasAffectedValue(Value *V, SmallPtrSetImpl<Value *> &Affected,
36453645
return false;
36463646
}
36473647

3648+
// This transformation enables the possibility of transforming fcmp + sel into
3649+
// a fmaxnum/fminnum intrinsic.
3650+
static Value *foldSelectIntoAddConstant(SelectInst &SI,
3651+
InstCombiner::BuilderTy &Builder) {
3652+
// Do this transformation only when select instruction gives NaN and NSZ
3653+
// guarantee.
3654+
auto *SIFOp = dyn_cast<FPMathOperator>(&SI);
3655+
if (!SIFOp || !SIFOp->hasNoSignedZeros() || !SIFOp->hasNoNaNs())
3656+
return nullptr;
3657+
3658+
// select((fcmp Pred, X, 0), (fadd X, C), C)
3659+
// => fadd((select (fcmp Pred, X, 0), X, 0), C)
3660+
//
3661+
// Pred := OGT, OGE, OLT, OLE, UGT, UGE, ULT, and ULE
3662+
Instruction *FAdd;
3663+
Constant *C;
3664+
Value *X, *Z;
3665+
CmpInst::Predicate Pred;
3666+
3667+
// Note: OneUse check for `Cmp` is necessary because it makes sure that other
3668+
// InstCombine folds don't undo this transformation and cause an infinite
3669+
// loop. Furthermore, it could also increase the operation count.
3670+
if (match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))),
3671+
m_OneUse(m_Instruction(FAdd)), m_Constant(C))) ||
3672+
match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))),
3673+
m_Constant(C), m_OneUse(m_Instruction(FAdd))))) {
3674+
// Only these relational predicates can be transformed into maxnum/minnum
3675+
// intrinsic.
3676+
if (!CmpInst::isRelational(Pred) || !match(Z, m_AnyZeroFP()))
3677+
return nullptr;
3678+
3679+
if (!match(FAdd, m_FAdd(m_Specific(X), m_Specific(C))))
3680+
return nullptr;
3681+
3682+
Value *NewSelect = Builder.CreateSelect(SI.getCondition(), X, Z, "", &SI);
3683+
NewSelect->takeName(&SI);
3684+
3685+
Value *NewFAdd = Builder.CreateFAdd(NewSelect, C);
3686+
NewFAdd->takeName(FAdd);
3687+
3688+
// Propagate FastMath flags
3689+
FastMathFlags SelectFMF = SI.getFastMathFlags();
3690+
FastMathFlags FAddFMF = FAdd->getFastMathFlags();
3691+
FastMathFlags NewFMF = FastMathFlags::intersectRewrite(SelectFMF, FAddFMF) |
3692+
FastMathFlags::unionValue(SelectFMF, FAddFMF);
3693+
cast<Instruction>(NewFAdd)->setFastMathFlags(NewFMF);
3694+
cast<Instruction>(NewSelect)->setFastMathFlags(NewFMF);
3695+
3696+
return NewFAdd;
3697+
}
3698+
3699+
return nullptr;
3700+
}
3701+
36483702
Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
36493703
Value *CondVal = SI.getCondition();
36503704
Value *TrueVal = SI.getTrueValue();
@@ -4041,6 +4095,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
40414095
if (Value *V = foldRoundUpIntegerWithPow2Alignment(SI, Builder))
40424096
return replaceInstUsesWith(SI, V);
40434097

4098+
if (Value *V = foldSelectIntoAddConstant(SI, Builder))
4099+
return replaceInstUsesWith(SI, V);
4100+
40444101
// select(mask, mload(,,mask,0), 0) -> mload(,,mask,0)
40454102
// Load inst is intentionally not checked for hasOneUse()
40464103
if (match(FalseVal, m_Zero()) &&

0 commit comments

Comments
 (0)