Skip to content

[LV] Vectorize selecting index of min/max element. #141431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions llvm/include/llvm/Analysis/IVDescriptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ enum class RecurKind {
FindLastIV, ///< FindLast reduction with select(cmp(),x,y) where one of
///< (x,y) is increasing loop induction, and both x and y are
///< integer type.
FindFirstIVUMin, /// FindFirst reduction with select(icmp(),x,y) where one of
///< (x,y) is a decreasing loop induction, and both x and y
///< are integer type.
FindFirstIVSMin /// FindFirst reduction with select(icmp(),x,y) where one of
///< (x,y) is a decreasing loop induction, and both x and y
///< are integer type.

// clang-format on
// TODO: Any_of and FindLast reduction need not be restricted to integer type
// only.
Expand Down Expand Up @@ -160,12 +167,13 @@ class RecurrenceDescriptor {
/// Returns a struct describing whether the instruction is either a
/// Select(ICmp(A, B), X, Y), or
/// Select(FCmp(A, B), X, Y)
/// where one of (X, Y) is an increasing loop induction variable, and the
/// other is a PHI value.
/// where one of (X, Y) is an increasing (FindLast) or decreasing (FindFirst)
/// loop induction variable, and the other is a PHI value.
// TODO: Support non-monotonic variable. FindLast does not need be restricted
// to increasing loop induction variables.
static InstDesc isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
Instruction *I, ScalarEvolution &SE);
static InstDesc isFindIVPattern(RecurKind Kind, Loop *TheLoop,
PHINode *OrigPhi, Instruction *I,
ScalarEvolution &SE);

/// Returns a struct describing if the instruction is a
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
Expand Down Expand Up @@ -259,19 +267,37 @@ class RecurrenceDescriptor {
return Kind == RecurKind::FindLastIV;
}

/// Returns true if the recurrence kind is of the form
/// select(cmp(),x,y) where one of (x,y) is an increasing or decreasing loop
/// induction.
static bool isFindIVRecurrenceKind(RecurKind Kind) {
return Kind == RecurKind::FindLastIV ||
Kind == RecurKind::FindFirstIVUMin ||
Kind == RecurKind::FindFirstIVSMin;
}

/// Returns the type of the recurrence. This type can be narrower than the
/// actual type of the Phi if the recurrence has been type-promoted.
Type *getRecurrenceType() const { return RecurrenceType; }

/// Returns the sentinel value for FindLastIV recurrences to replace the start
/// value.
/// Returns the sentinel value for FindFirstIV &FindLastIV recurrences to
/// replace the start value.
Value *getSentinelValue() const {
assert(isFindLastIVRecurrenceKind(Kind) && "Unexpected recurrence kind");
Type *Ty = StartValue->getType();
return ConstantInt::get(Ty,
APInt::getSignedMinValue(Ty->getIntegerBitWidth()));
if (isFindLastIVRecurrenceKind(Kind)) {
return ConstantInt::get(
Ty, APInt::getSignedMinValue(Ty->getIntegerBitWidth()));
} else if (Kind == RecurKind::FindFirstIVSMin) {
return ConstantInt::get(
Ty, APInt::getSignedMaxValue(Ty->getIntegerBitWidth()));
} else {
assert(Kind == RecurKind::FindFirstIVUMin);
return ConstantInt::get(Ty, APInt::getMaxValue(Ty->getIntegerBitWidth()));
}
}

void setKind(RecurKind NewKind) { Kind = NewKind; }

/// Returns a reference to the instructions used for type-promoting the
/// recurrence.
const SmallPtrSet<Instruction *, 8> &getCastInsts() const { return CastInsts; }
Expand Down Expand Up @@ -303,6 +329,10 @@ class RecurrenceDescriptor {
/// AddReductionVar method, this field will be assigned the last met store.
StoreInst *IntermediateStore = nullptr;

/// True if this recurrence is used by another recurrence in the loop. Users
/// need to ensure that the final code-gen accounts for the use in the loop.
bool IsUsedByOtherRecurrence = false;

private:
// The starting value of the recurrence.
// It does not have to be zero!
Expand Down
114 changes: 90 additions & 24 deletions llvm/lib/Analysis/IVDescriptors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
case RecurKind::UMin:
case RecurKind::AnyOf:
case RecurKind::FindLastIV:
case RecurKind::FindFirstIVUMin:
case RecurKind::FindFirstIVSMin:
return true;
}
return false;
Expand Down Expand Up @@ -253,7 +255,7 @@ bool RecurrenceDescriptor::AddReductionVar(
// Data used for determining if the recurrence has been type-promoted.
Type *RecurrenceType = Phi->getType();
SmallPtrSet<Instruction *, 4> CastInsts;
unsigned MinWidthCastToRecurrenceType;
unsigned MinWidthCastToRecurrenceType = -1ull;
Instruction *Start = Phi;
bool IsSigned = false;

Expand Down Expand Up @@ -308,6 +310,7 @@ bool RecurrenceDescriptor::AddReductionVar(
// This is either:
// * An instruction type other than PHI or the reduction operation.
// * A PHI in the header other than the initial PHI.
bool IsUsedByOtherRecurrence = false;
while (!Worklist.empty()) {
Instruction *Cur = Worklist.pop_back_val();

Expand Down Expand Up @@ -369,15 +372,37 @@ bool RecurrenceDescriptor::AddReductionVar(

// Any reduction instruction must be of one of the allowed kinds. We ignore
// the starting value (the Phi or an AND instruction if the Phi has been
// type-promoted).
// type-promoted) and other in-loop users, if they form a FindLastIV
// reduction. In the latter case, the user of the IVDescriptors must account
// for that during codegen.
if (Cur != Start) {
ReduxDesc =
isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF, SE);
ExactFPMathInst = ExactFPMathInst == nullptr
? ReduxDesc.getExactFPMathInst()
: ExactFPMathInst;
if (!ReduxDesc.isRecurrence())
if (!ReduxDesc.isRecurrence()) {
if (isMinMaxRecurrenceKind(Kind)) {
// If the current recurrence is Min/Max, check if the current user is
// a select that is a FindLastIV reduction. During codegen, this
// recurrence needs to be turned into one that finds the first IV, as
// the value to compare against is a Min/Max recurrence.
auto *Sel = dyn_cast<SelectInst>(Cur);
if (!Sel || !Sel->getType()->isIntegerTy())
return false;
auto *OtherPhi = dyn_cast<PHINode>(Sel->getOperand(2));
if (!OtherPhi)
return false;
auto NewReduxDesc =
isRecurrenceInstr(TheLoop, OtherPhi, Cur, RecurKind::FindLastIV,
ReduxDesc, FuncFMF, SE);
if (NewReduxDesc.isRecurrence()) {
IsUsedByOtherRecurrence = true;
continue;
}
}
return false;
}
// FIXME: FMF is allowed on phi, but propagation is not handled correctly.
if (isa<FPMathOperator>(ReduxDesc.getPatternInst()) && !IsAPhi) {
FastMathFlags CurFMF = ReduxDesc.getPatternInst()->getFastMathFlags();
Expand Down Expand Up @@ -501,7 +526,7 @@ bool RecurrenceDescriptor::AddReductionVar(
// pattern or more than just a select and cmp. Zero implies that we saw a
// llvm.min/max intrinsic, which is always OK.
if (isMinMaxRecurrenceKind(Kind) && NumCmpSelectPatternInst != 2 &&
NumCmpSelectPatternInst != 0)
NumCmpSelectPatternInst != 0 && !IsUsedByOtherRecurrence)
return false;

if (isAnyOfRecurrenceKind(Kind) && NumCmpSelectPatternInst != 1)
Expand Down Expand Up @@ -533,7 +558,13 @@ bool RecurrenceDescriptor::AddReductionVar(
ExitInstruction = cast<Instruction>(IntermediateStore->getValueOperand());
}

if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction)
if (!FoundStartPHI || !FoundReduxOp)
return false;

if (IsUsedByOtherRecurrence) {
if (ExitInstruction)
return false;
} else if (!ExitInstruction)
return false;

const bool IsOrdered =
Expand Down Expand Up @@ -584,8 +615,9 @@ bool RecurrenceDescriptor::AddReductionVar(
// without needing a white list of instructions to ignore.
// This may also be useful for the inloop reductions, if it can be
// kept simple enough.
collectCastInstrs(TheLoop, ExitInstruction, RecurrenceType, CastInsts,
MinWidthCastToRecurrenceType);
if (ExitInstruction)
collectCastInstrs(TheLoop, ExitInstruction, RecurrenceType, CastInsts,
MinWidthCastToRecurrenceType);

// We found a reduction var if we have reached the original phi node and we
// only have a single instruction with out-of-loop users.
Expand All @@ -598,7 +630,7 @@ bool RecurrenceDescriptor::AddReductionVar(
FMF, ExactFPMathInst, RecurrenceType, IsSigned,
IsOrdered, CastInsts, MinWidthCastToRecurrenceType);
RedDes = RD;

RedDes.IsUsedByOtherRecurrence = IsUsedByOtherRecurrence;
return true;
}

Expand Down Expand Up @@ -683,8 +715,9 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
// value of the data type or a non-constant value by using mask and multiple
// reduction operations.
RecurrenceDescriptor::InstDesc
RecurrenceDescriptor::isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
Instruction *I, ScalarEvolution &SE) {
RecurrenceDescriptor::isFindIVPattern(RecurKind Kind, Loop *TheLoop,
PHINode *OrigPhi, Instruction *I,
ScalarEvolution &SE) {
// TODO: Support the vectorization of FindLastIV when the reduction phi is
// used by more than one select instruction. This vectorization is only
// performed when the SCEV of each increasing induction variable used by the
Expand All @@ -700,7 +733,7 @@ RecurrenceDescriptor::isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
m_Value(NonRdxPhi)))))
return InstDesc(false, I);

auto IsIncreasingLoopInduction = [&](Value *V) {
auto IsSupportedLoopInduction = [&](Value *V, RecurKind Kind) {
Type *Ty = V->getType();
if (!SE.isSCEVable(Ty))
return false;
Expand All @@ -710,21 +743,39 @@ RecurrenceDescriptor::isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
return false;

const SCEV *Step = AR->getStepRecurrence(SE);
if (!SE.isKnownPositive(Step))
if (Kind == RecurKind::FindFirstIVUMin ||
Kind == RecurKind::FindFirstIVSMin) {
if (!SE.isKnownNegative(Step))
return false;
} else if (!SE.isKnownPositive(Step))
return false;

const ConstantRange IVRange = SE.getSignedRange(AR);
unsigned NumBits = Ty->getIntegerBitWidth();
// Keep the minimum value of the recurrence type as the sentinel value.
// The maximum acceptable range for the increasing induction variable,
// called the valid range, will be defined as
// Keep the minimum (FindLast) or maximum (FindFirst) value of the
// recurrence type as the sentinel value. The maximum acceptable range for
// the induction variable, called the valid range, will be defined as
// [<sentinel value> + 1, <sentinel value>)
// where <sentinel value> is SignedMin(<recurrence type>)
// where <sentinel value> is SignedMin(<recurrence type>) for FindLast or
// UnsignedMax(<recurrence type>) for FindFirst.
// TODO: This range restriction can be lifted by adding an additional
// virtual OR reduction.
const APInt Sentinel = APInt::getSignedMinValue(NumBits);
const ConstantRange ValidRange =
ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
const APInt Sentinel = Kind == RecurKind::FindFirstIVUMin
? APInt::getMaxValue(NumBits)
: (Kind == RecurKind::FindFirstIVSMin
? APInt::getSignedMaxValue(NumBits)
: APInt::getSignedMinValue(NumBits));
ConstantRange ValidRange = ConstantRange::getEmpty(NumBits);
if (Kind == RecurKind::FindFirstIVSMin)
ValidRange =
ConstantRange::getNonEmpty(APInt::getSignedMinValue(NumBits),
APInt::getSignedMaxValue(NumBits) - 1);
else {
const APInt Sentinel = Kind == RecurKind::FindFirstIVUMin
? APInt::getMaxValue(NumBits)
: APInt::getSignedMinValue(NumBits);
ValidRange = ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
}
LLVM_DEBUG(dbgs() << "LV: FindLastIV valid range is " << ValidRange
<< ", and the signed range of " << *AR << " is "
<< IVRange << "\n");
Expand All @@ -736,11 +787,18 @@ RecurrenceDescriptor::isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
// We are looking for selects of the form:
// select(cmp(), phi, increasing_loop_induction) or
// select(cmp(), increasing_loop_induction, phi)
// TODO: Support for monotonically decreasing induction variable
if (!IsIncreasingLoopInduction(NonRdxPhi))
if (Kind == RecurKind::FindLastIV) {
if (IsSupportedLoopInduction(NonRdxPhi, Kind))
return InstDesc(I, Kind);
return InstDesc(false, I);
}

if (IsSupportedLoopInduction(NonRdxPhi, RecurKind::FindFirstIVSMin))
return InstDesc(I, RecurKind::FindFirstIVSMin);
if (IsSupportedLoopInduction(NonRdxPhi, RecurKind::FindFirstIVUMin))
return InstDesc(I, RecurKind::FindFirstIVUMin);

return InstDesc(I, RecurKind::FindLastIV);
return InstDesc(false, I);
}

RecurrenceDescriptor::InstDesc
Expand Down Expand Up @@ -875,8 +933,8 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
if (Kind == RecurKind::FAdd || Kind == RecurKind::FMul ||
Kind == RecurKind::Add || Kind == RecurKind::Mul)
return isConditionalRdxPattern(Kind, I);
if (isFindLastIVRecurrenceKind(Kind) && SE)
return isFindLastIVPattern(L, OrigPhi, I, *SE);
if (isFindIVRecurrenceKind(Kind) && SE)
return isFindIVPattern(Kind, L, OrigPhi, I, *SE);
[[fallthrough]];
case Instruction::FCmp:
case Instruction::ICmp:
Expand Down Expand Up @@ -990,6 +1048,12 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
LLVM_DEBUG(dbgs() << "Found a FindLastIV reduction PHI." << *Phi << "\n");
return true;
}
if (AddReductionVar(Phi, RecurKind::FindFirstIVUMin, TheLoop, FMF, RedDes, DB,
AC, DT, SE)) {
LLVM_DEBUG(dbgs() << "Found a FindFirstV reduction PHI." << *Phi << "\n");
return true;
}

if (AddReductionVar(Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT,
SE)) {
LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n");
Expand Down Expand Up @@ -1153,6 +1217,8 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
case RecurKind::SMin:
case RecurKind::UMax:
case RecurKind::UMin:
case RecurKind::FindFirstIVUMin:
case RecurKind::FindFirstIVSMin:
return Instruction::ICmp;
case RecurKind::FMax:
case RecurKind::FMin:
Expand Down
16 changes: 10 additions & 6 deletions llvm/lib/Transforms/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1244,12 +1244,16 @@ Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src,
Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
Value *Start,
const RecurrenceDescriptor &Desc) {
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
Desc.getRecurrenceKind()) &&
"Unexpected reduction kind");
assert(
RecurrenceDescriptor::isFindIVRecurrenceKind(Desc.getRecurrenceKind()) &&
"Unexpected reduction kind");
Value *Sentinel = Desc.getSentinelValue();
Value *MaxRdx = Src->getType()->isVectorTy()
? Builder.CreateIntMaxReduce(Src, true)
? (Desc.getRecurrenceKind() == RecurKind::FindLastIV
? Builder.CreateIntMaxReduce(Src, true)
: Builder.CreateIntMinReduce(
Src, Desc.getRecurrenceKind() ==
RecurKind::FindFirstIVSMin))
: Src;
// Correct the final reduction result back to the start value if the maximum
// reduction is sentinel value.
Expand Down Expand Up @@ -1345,8 +1349,8 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
RecurKind Kind) {
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
"AnyOf or FindLastIV reductions are not supported.");
!RecurrenceDescriptor::isFindIVRecurrenceKind(Kind) &&
"AnyOf, FindFirstIV and FindLastIV reductions are not supported.");
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
auto *SrcTy = cast<VectorType>(Src->getType());
Type *SrcEltTy = SrcTy->getElementType();
Expand Down
Loading
Loading