Skip to content

[VPlan] Add ComputeFindLastIVResult opcode (NFC). #132689

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
4 commits were merged on Mar 26, 2025.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7610,7 +7610,8 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
BasicBlock *BypassBlock) {
auto *EpiRedResult = dyn_cast<VPInstruction>(R);
if (!EpiRedResult ||
EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult)
(EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult &&
EpiRedResult->getOpcode() != VPInstruction::ComputeFindLastIVResult))
return;

auto *EpiRedHeaderPhi =
Expand Down Expand Up @@ -9819,8 +9820,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
Builder.createSelect(Cond, OrigExitingVPV, PhiR, {}, "", FMFs);
OrigExitingVPV->replaceUsesWithIf(NewExitingVPV, [](VPUser &U, unsigned) {
return isa<VPInstruction>(&U) &&
cast<VPInstruction>(&U)->getOpcode() ==
VPInstruction::ComputeReductionResult;
(cast<VPInstruction>(&U)->getOpcode() ==
VPInstruction::ComputeReductionResult ||
cast<VPInstruction>(&U)->getOpcode() ==
VPInstruction::ComputeFindLastIVResult);
});
if (CM.usePredicatedReductionSelect(
PhiR->getRecurrenceDescriptor().getOpcode(), PhiTy))
Expand Down Expand Up @@ -9865,8 +9868,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
// also modeled in VPlan.
VPBuilder::InsertPointGuard Guard(Builder);
Builder.setInsertPoint(MiddleVPBB, IP);
auto *FinalReductionResult = Builder.createNaryOp(
VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL);
auto *FinalReductionResult =
Builder.createNaryOp(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
RdxDesc.getRecurrenceKind())
? VPInstruction::ComputeFindLastIVResult
: VPInstruction::ComputeReductionResult,
{PhiR, NewExitingVPV}, ExitDL);
// Update all users outside the vector region.
OrigExitingVPV->replaceUsesWithIf(
FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
BranchOnCount,
BranchOnCond,
Broadcast,
ComputeFindLastIVResult,
ComputeReductionResult,
// Takes the VPValue to extract from as first operand and the lane or part
// to extract as second operand, counting from the end starting with 1 for
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
inferScalarType(R->getOperand(1)) &&
"different types inferred for different operands");
return IntegerType::get(Ctx, 1);
case VPInstruction::ComputeFindLastIVResult:
case VPInstruction::ComputeReductionResult: {
auto *PhiR = cast<VPReductionPHIRecipe>(R->getOperand(0));
auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
Expand Down
38 changes: 30 additions & 8 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,29 @@ Value *VPInstruction::generate(VPTransformState &State) {
return Builder.CreateVectorSplat(
State.VF, State.get(getOperand(0), /*IsScalar*/ true), "broadcast");
}
case VPInstruction::ComputeFindLastIVResult: {
// Generates the final value for a FindLastIV reduction: the per-part operands
// are first folded into a single value with signed-max operations, and that
// value is then handed to createFindLastIVReduction together with the
// recurrence descriptor taken from the reduction phi.
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
// and will be removed by breaking up the recipe further.
auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
// Get its reduction variable descriptor.
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
RecurKind RK = RdxDesc.getRecurrenceKind();
// This opcode is only expected for FindLastIV recurrence kinds.
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
"Unexpected reduction kind");
// Only reductions whose phi is not marked in-loop are handled here.
assert(!PhiR->isInLoop() &&
"In-loop FindLastIV reduction is not supported yet");

// The recipe's operands are the reduction phi, followed by one operand for
// each part of the reduction.
unsigned UF = getNumOperands() - 1;
// Start from the first part (operand 1) and fold each remaining part
// (operands 2..UF) into the running value using RecurKind::SMax.
Value *ReducedPartRdx = State.get(getOperand(1));
for (unsigned Part = 1; Part < UF; ++Part) {
ReducedPartRdx = createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx,
State.get(getOperand(1 + Part)));
}

// NOTE(review): presumably this reduces the combined value to the final
// scalar result; createFindLastIVReduction is defined elsewhere - confirm.
return createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
}
case VPInstruction::ComputeReductionResult: {
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
// and will be removed by breaking up the recipe further.
Expand All @@ -622,6 +645,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();

RecurKind RK = RdxDesc.getRecurrenceKind();
assert(!RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
"should be handled by ComputeFindLastIVResult");

Type *PhiTy = OrigPhi->getType();
// The recipe's operands are the reduction phi, followed by one operand for
Expand Down Expand Up @@ -657,9 +682,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
if (Op != Instruction::ICmp && Op != Instruction::FCmp)
ReducedPartRdx = Builder.CreateBinOp(
(Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx");
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
ReducedPartRdx =
createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx, RdxPart);
else
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
}
Expand All @@ -668,8 +690,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
// Create the reduction after the loop. Note that inloop reductions create
// the target reduction in the loop using a Reduction recipe.
if ((State.VF.isVector() ||
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) &&
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) &&
!PhiR->isInLoop()) {
// TODO: Support in-order reductions based on the recurrence descriptor.
// All ops in the reduction inherit fast-math-flags from the recurrence
Expand All @@ -680,9 +701,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
ReducedPartRdx =
createAnyOfReduction(Builder, ReducedPartRdx, RdxDesc, OrigPhi);
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
ReducedPartRdx =
createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
else
ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);

Expand Down Expand Up @@ -828,6 +846,7 @@ bool VPInstruction::isVectorToScalar() const {
return getOpcode() == VPInstruction::ExtractFromEnd ||
getOpcode() == Instruction::ExtractElement ||
getOpcode() == VPInstruction::FirstActiveLane ||
getOpcode() == VPInstruction::ComputeFindLastIVResult ||
getOpcode() == VPInstruction::ComputeReductionResult ||
getOpcode() == VPInstruction::AnyOf;
}
Expand Down Expand Up @@ -1010,6 +1029,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
case VPInstruction::ExtractFromEnd:
O << "extract-from-end";
break;
case VPInstruction::ComputeFindLastIVResult:
O << "compute-find-last-iv-result";
break;
case VPInstruction::ComputeReductionResult:
O << "compute-reduction-result";
break;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
// the parts to compute the final reduction value.
VPValue *Op1;
if (match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
m_VPValue(), m_VPValue(Op1))) ||
match(&R, m_VPInstruction<VPInstruction::ComputeFindLastIVResult>(
m_VPValue(), m_VPValue(Op1)))) {
addUniformForAllParts(cast<VPInstruction>(&R));
for (unsigned Part = 1; Part != UF; ++Part)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ define i64 @find_last_iv(ptr %a, i64 %n, i64 %start) {
; CHECK-NEXT: Successor(s): middle.block
; CHECK-EMPTY:
; CHECK-NEXT: middle.block:
; CHECK-NEXT: EMIT vp<[[RDX_RES:%.+]]> = compute-reduction-result ir<%rdx>, ir<%cond>
; CHECK-NEXT: EMIT vp<[[RDX_RES:%.+]]> = compute-find-last-iv-result ir<%rdx>, ir<%cond>
; CHECK-NEXT: EMIT vp<[[EXT:%.+]]> = extract-from-end vp<[[RDX_RES]]>, ir<1>
; CHECK-NEXT: EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<{{.+}}>
; CHECK-NEXT: EMIT branch-on-cond vp<%cmp.n>
Expand Down