Skip to content

[LV] Extend FindLastIV to unsigned case #141752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions llvm/include/llvm/Analysis/IVDescriptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,12 @@ enum class RecurKind {
FMulAdd, ///< Sum of float products with llvm.fmuladd(a * b + sum).
AnyOf, ///< AnyOf reduction with select(cmp(),x,y) where one of (x,y) is
///< loop invariant, and both x and y are integer type.
FindLastIV, ///< FindLast reduction with select(cmp(),x,y) where one of
///< (x,y) is increasing loop induction, and both x and y are
///< integer type.
FindLastIVSMin, ///< FindLast reduction with select(cmp(),x,y) where one of
///< (x,y) is increasing loop induction, and both x and y
///< are integer type, producing a SMin reduction.
FindLastIVUMin, ///< FindLast reduction with select(cmp(),x,y) where one of
///< (x,y) is increasing loop induction, and both x and y
///< are integer type, producting a UMin reduction.
// clang-format on
// TODO: Any_of and FindLast reduction need not be restricted to integer type
// only.
Expand Down Expand Up @@ -260,7 +263,14 @@ class RecurrenceDescriptor {
/// Returns true if the recurrence kind is of the form
/// select(cmp(),x,y) where one of (x,y) is increasing loop induction.
static bool isFindLastIVRecurrenceKind(RecurKind Kind) {
return Kind == RecurKind::FindLastIV;
return Kind == RecurKind::FindLastIVSMin ||
Kind == RecurKind::FindLastIVUMin;
}

/// Returns true if recurrece kind is a signed redux kind.
static bool isSignedRecurrenceKind(RecurKind Kind) {
return Kind == RecurKind::SMax || Kind == RecurKind::SMin ||
Kind == RecurKind::FindLastIVSMin;
}

/// Returns the type of the recurrence. This type can be narrower than the
Expand All @@ -272,8 +282,10 @@ class RecurrenceDescriptor {
Value *getSentinelValue() const {
assert(isFindLastIVRecurrenceKind(Kind) && "Unexpected recurrence kind");
Type *Ty = StartValue->getType();
return ConstantInt::get(Ty,
APInt::getSignedMinValue(Ty->getIntegerBitWidth()));
unsigned BW = Ty->getIntegerBitWidth();
return ConstantInt::get(Ty, isSignedRecurrenceKind(Kind)
? APInt::getSignedMinValue(BW)
: APInt::getMinValue(BW));
}

/// Returns a reference to the instructions used for type-promoting the
Expand Down
60 changes: 37 additions & 23 deletions llvm/lib/Analysis/IVDescriptors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
case RecurKind::UMax:
case RecurKind::UMin:
case RecurKind::AnyOf:
case RecurKind::FindLastIV:
case RecurKind::FindLastIVSMin:
case RecurKind::FindLastIVUMin:
return true;
}
return false;
Expand Down Expand Up @@ -700,47 +701,59 @@ RecurrenceDescriptor::isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
m_Value(NonRdxPhi)))))
return InstDesc(false, I);

auto IsIncreasingLoopInduction = [&](Value *V) {
// Returns a non-nullopt boolean indicating the signedness of the recurrence
// when a valid FindLastIV pattern is found.
auto GetRecurKind = [&](Value *V) -> std::optional<RecurKind> {
Type *Ty = V->getType();
if (!SE.isSCEVable(Ty))
return false;
return std::nullopt;

auto *AR = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(V));
if (!AR || AR->getLoop() != TheLoop)
return false;
return std::nullopt;

const SCEV *Step = AR->getStepRecurrence(SE);
if (!SE.isKnownPositive(Step))
return false;
return std::nullopt;

const ConstantRange IVRange = SE.getSignedRange(AR);
unsigned NumBits = Ty->getIntegerBitWidth();
// Keep the minimum value of the recurrence type as the sentinel value.
// The maximum acceptable range for the increasing induction variable,
// called the valid range, will be defined as
// [<sentinel value> + 1, <sentinel value>)
// where <sentinel value> is SignedMin(<recurrence type>)
// where <sentinel value> is [Signed|Unsigned]Min(<recurrence type>)
// TODO: This range restriction can be lifted by adding an additional
// virtual OR reduction.
const APInt Sentinel = APInt::getSignedMinValue(NumBits);
const ConstantRange ValidRange =
ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
LLVM_DEBUG(dbgs() << "LV: FindLastIV valid range is " << ValidRange
<< ", and the signed range of " << *AR << " is "
<< IVRange << "\n");
// Ensure the induction variable does not wrap around by verifying that its
// range is fully contained within the valid range.
return ValidRange.contains(IVRange);
auto CheckRange = [&](bool IsSigned) {
const ConstantRange IVRange =
IsSigned ? SE.getSignedRange(AR) : SE.getUnsignedRange(AR);
unsigned NumBits = Ty->getIntegerBitWidth();
const APInt Sentinel = IsSigned ? APInt::getSignedMinValue(NumBits)
: APInt::getMinValue(NumBits);
const ConstantRange ValidRange =
ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
LLVM_DEBUG(dbgs() << "LV: FindLastIV valid range is " << ValidRange
<< ", and the range of " << *AR << " is " << IVRange
<< "\n");

// Ensure the induction variable does not wrap around by verifying that
// its range is fully contained within the valid range.
return ValidRange.contains(IVRange);
};
if (CheckRange(true))
return RecurKind::FindLastIVSMin;
if (CheckRange(false))
return RecurKind::FindLastIVUMin;
return std::nullopt;
};

// We are looking for selects of the form:
// select(cmp(), phi, increasing_loop_induction) or
// select(cmp(), increasing_loop_induction, phi)
// TODO: Support for monotonically decreasing induction variable
if (!IsIncreasingLoopInduction(NonRdxPhi))
return InstDesc(false, I);
if (auto RK = GetRecurKind(NonRdxPhi))
return InstDesc(I, *RK);

return InstDesc(I, RecurKind::FindLastIV);
return InstDesc(false, I);
}

RecurrenceDescriptor::InstDesc
Expand Down Expand Up @@ -985,8 +998,8 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
<< "\n");
return true;
}
if (AddReductionVar(Phi, RecurKind::FindLastIV, TheLoop, FMF, RedDes, DB, AC,
DT, SE)) {
if (AddReductionVar(Phi, RecurKind::FindLastIVSMin, TheLoop, FMF, RedDes, DB,
AC, DT, SE)) {
LLVM_DEBUG(dbgs() << "Found a FindLastIV reduction PHI." << *Phi << "\n");
return true;
}
Expand Down Expand Up @@ -1137,7 +1150,8 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
case RecurKind::Mul:
return Instruction::Mul;
case RecurKind::AnyOf:
case RecurKind::FindLastIV:
case RecurKind::FindLastIVSMin:
case RecurKind::FindLastIVUMin:
case RecurKind::Or:
return Instruction::Or;
case RecurKind::And:
Expand Down
9 changes: 6 additions & 3 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23105,7 +23105,8 @@ class HorizontalReduction {
case RecurKind::FMul:
case RecurKind::FMulAdd:
case RecurKind::AnyOf:
case RecurKind::FindLastIV:
case RecurKind::FindLastIVSMin:
case RecurKind::FindLastIVUMin:
case RecurKind::FMaximumNum:
case RecurKind::FMinimumNum:
case RecurKind::None:
Expand Down Expand Up @@ -23239,7 +23240,8 @@ class HorizontalReduction {
case RecurKind::FMul:
case RecurKind::FMulAdd:
case RecurKind::AnyOf:
case RecurKind::FindLastIV:
case RecurKind::FindLastIVSMin:
case RecurKind::FindLastIVUMin:
case RecurKind::FMaximumNum:
case RecurKind::FMinimumNum:
case RecurKind::None:
Expand Down Expand Up @@ -23338,7 +23340,8 @@ class HorizontalReduction {
case RecurKind::FMul:
case RecurKind::FMulAdd:
case RecurKind::AnyOf:
case RecurKind::FindLastIV:
case RecurKind::FindLastIVSMin:
case RecurKind::FindLastIVUMin:
case RecurKind::FMaximumNum:
case RecurKind::FMinimumNum:
case RecurKind::None:
Expand Down
5 changes: 4 additions & 1 deletion llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,11 @@ Value *VPInstruction::generate(VPTransformState &State) {
// each part of the reduction.
unsigned UF = getNumOperands() - 2;
Value *ReducedPartRdx = State.get(getOperand(2));
auto MinMaxOp = RecurrenceDescriptor::isSignedRecurrenceKind(RK)
? RecurKind::SMax
: RecurKind::UMax;
for (unsigned Part = 1; Part < UF; ++Part) {
ReducedPartRdx = createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx,
ReducedPartRdx = createMinMaxOp(Builder, MinMaxOp, ReducedPartRdx,
State.get(getOperand(2 + Part)));
}

Expand Down
Loading