Skip to content

[VPlan] Add ComputeFindLastIVResult opcode (NFC). #132689

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 26, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7612,7 +7612,8 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
BasicBlock *BypassBlock) {
auto *EpiRedResult = dyn_cast<VPInstruction>(R);
if (!EpiRedResult ||
EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult)
(EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult &&
EpiRedResult->getOpcode() != VPInstruction::ComputeFindLastIVResult))
return;

auto *EpiRedHeaderPhi =
Expand Down Expand Up @@ -9817,8 +9818,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
Builder.createSelect(Cond, OrigExitingVPV, PhiR, {}, "", FMFs);
OrigExitingVPV->replaceUsesWithIf(NewExitingVPV, [](VPUser &U, unsigned) {
return isa<VPInstruction>(&U) &&
cast<VPInstruction>(&U)->getOpcode() ==
VPInstruction::ComputeReductionResult;
(cast<VPInstruction>(&U)->getOpcode() ==
VPInstruction::ComputeReductionResult ||
cast<VPInstruction>(&U)->getOpcode() ==
VPInstruction::ComputeFindLastIVResult);
});
if (CM.usePredicatedReductionSelect(
PhiR->getRecurrenceDescriptor().getOpcode(), PhiTy))
Expand Down Expand Up @@ -9863,8 +9866,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
// also modeled in VPlan.
VPBuilder::InsertPointGuard Guard(Builder);
Builder.setInsertPoint(MiddleVPBB, IP);
auto *FinalReductionResult = Builder.createNaryOp(
VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL);
auto *FinalReductionResult =
Builder.createNaryOp(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
RdxDesc.getRecurrenceKind())
? VPInstruction::ComputeFindLastIVResult
: VPInstruction::ComputeReductionResult,
{PhiR, NewExitingVPV}, ExitDL);
// Update all users outside the vector region.
OrigExitingVPV->replaceUsesWithIf(
FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
BranchOnCount,
BranchOnCond,
Broadcast,
ComputeFindLastIVResult,
ComputeReductionResult,
// Takes the VPValue to extract from as first operand and the lane or part
// to extract as second operand, counting from the end starting with 1 for
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
inferScalarType(R->getOperand(1)) &&
"different types inferred for different operands");
return IntegerType::get(Ctx, 1);
case VPInstruction::ComputeFindLastIVResult:
case VPInstruction::ComputeReductionResult: {
auto *PhiR = cast<VPReductionPHIRecipe>(R->getOperand(0));
auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
Expand Down
36 changes: 28 additions & 8 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,27 @@ Value *VPInstruction::generate(VPTransformState &State) {
return Builder.CreateVectorSplat(
State.VF, State.get(getOperand(0), /*IsScalar*/ true), "broadcast");
}
case VPInstruction::ComputeFindLastIVResult: {
// The recipe's operands are the reduction phi, followed by one operand for
// each part of the reduction.
unsigned UF = getNumOperands() - 1;
Value *ReducedPartRdx = State.get(getOperand(1));
for (unsigned Part = 1; Part < UF; ++Part) {
ReducedPartRdx = createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx,
State.get(getOperand(1 + Part)));
}

// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
// and will be removed by breaking up the recipe further.
auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
// Get its reduction variable descriptor.
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
RecurKind RK = RdxDesc.getRecurrenceKind();

assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK));
assert(!PhiR->isInLoop());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK));
assert(!PhiR->isInLoop());
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) && "Unexpected reduction kind");
assert(!PhiR->isInLoop() && "In-loop FindLastIV reduction is not supported yet");

and move assertions to the front of the switch case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done thanks

return createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
}
case VPInstruction::ComputeReductionResult: {
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
// and will be removed by breaking up the recipe further.
Expand All @@ -623,6 +644,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();

RecurKind RK = RdxDesc.getRecurrenceKind();
assert(!RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
"should be handled by ComputeFindLastIVResult");

Type *PhiTy = OrigPhi->getType();
// The recipe's operands are the reduction phi, followed by one operand for
Expand Down Expand Up @@ -658,9 +681,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
if (Op != Instruction::ICmp && Op != Instruction::FCmp)
ReducedPartRdx = Builder.CreateBinOp(
(Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx");
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
ReducedPartRdx =
createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx, RdxPart);
else
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
}
Expand All @@ -669,8 +689,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
// Create the reduction after the loop. Note that inloop reductions create
// the target reduction in the loop using a Reduction recipe.
if ((State.VF.isVector() ||
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) &&
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) &&
!PhiR->isInLoop()) {
// TODO: Support in-order reductions based on the recurrence descriptor.
// All ops in the reduction inherit fast-math-flags from the recurrence
Expand All @@ -681,9 +700,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
ReducedPartRdx =
createAnyOfReduction(Builder, ReducedPartRdx, RdxDesc, OrigPhi);
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
ReducedPartRdx =
createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
else
ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);

Expand Down Expand Up @@ -829,6 +845,7 @@ bool VPInstruction::isVectorToScalar() const {
return getOpcode() == VPInstruction::ExtractFromEnd ||
getOpcode() == Instruction::ExtractElement ||
getOpcode() == VPInstruction::FirstActiveLane ||
getOpcode() == VPInstruction::ComputeFindLastIVResult ||
getOpcode() == VPInstruction::ComputeReductionResult ||
getOpcode() == VPInstruction::AnyOf;
}
Expand Down Expand Up @@ -1011,6 +1028,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
case VPInstruction::ExtractFromEnd:
O << "extract-from-end";
break;
case VPInstruction::ComputeFindLastIVResult:
O << "compute-find-last-iv-result";
break;
case VPInstruction::ComputeReductionResult:
O << "compute-reduction-result";
break;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
// the parts to compute the final reduction value.
VPValue *Op1;
if (match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
m_VPValue(), m_VPValue(Op1))) ||
match(&R, m_VPInstruction<VPInstruction::ComputeFindLastIVResult>(
m_VPValue(), m_VPValue(Op1)))) {
addUniformForAllParts(cast<VPInstruction>(&R));
for (unsigned Part = 1; Part != UF; ++Part)
Expand Down