Skip to content

Commit 0d0369e

Browse files
fhahnpawosm-arm
authored andcommitted
[VPlan] Add ComputeFindLastIVResult opcode (NFC). (#132689)
This moves the logic for computing the FindLastIV reduction result to its own opcode. A follow-up patch will update the new opcode to also take the start value, to fix llvm/llvm-project#126836. PR: llvm/llvm-project#132689
1 parent f14add5 commit 0d0369e

File tree

6 files changed

+47
-14
lines changed

6 files changed

+47
-14
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7549,7 +7549,8 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
75497549
BasicBlock *BypassBlock) {
75507550
auto *EpiRedResult = dyn_cast<VPInstruction>(R);
75517551
if (!EpiRedResult ||
7552-
EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult)
7552+
(EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult &&
7553+
EpiRedResult->getOpcode() != VPInstruction::ComputeFindLastIVResult))
75537554
return;
75547555

75557556
auto *EpiRedHeaderPhi =
@@ -9690,8 +9691,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
96909691
Builder.createSelect(Cond, OrigExitingVPV, PhiR, {}, "", FMFs);
96919692
OrigExitingVPV->replaceUsesWithIf(NewExitingVPV, [](VPUser &U, unsigned) {
96929693
return isa<VPInstruction>(&U) &&
9693-
cast<VPInstruction>(&U)->getOpcode() ==
9694-
VPInstruction::ComputeReductionResult;
9694+
(cast<VPInstruction>(&U)->getOpcode() ==
9695+
VPInstruction::ComputeReductionResult ||
9696+
cast<VPInstruction>(&U)->getOpcode() ==
9697+
VPInstruction::ComputeFindLastIVResult);
96959698
});
96969699
if (CM.usePredicatedReductionSelect(
96979700
PhiR->getRecurrenceDescriptor().getOpcode(), PhiTy))
@@ -9736,8 +9739,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
97369739
// also modeled in VPlan.
97379740
VPBuilder::InsertPointGuard Guard(Builder);
97389741
Builder.setInsertPoint(MiddleVPBB, IP);
9739-
auto *FinalReductionResult = Builder.createNaryOp(
9740-
VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL);
9742+
auto *FinalReductionResult =
9743+
Builder.createNaryOp(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
9744+
RdxDesc.getRecurrenceKind())
9745+
? VPInstruction::ComputeFindLastIVResult
9746+
: VPInstruction::ComputeReductionResult,
9747+
{PhiR, NewExitingVPV}, ExitDL);
97419748
// Update all users outside the vector region.
97429749
OrigExitingVPV->replaceUsesWithIf(
97439750
FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
12131213
CanonicalIVIncrementForPart,
12141214
BranchOnCount,
12151215
BranchOnCond,
1216+
ComputeFindLastIVResult,
12161217
ComputeReductionResult,
12171218
// Takes the VPValue to extract from as first operand and the lane or part
12181219
// to extract as second operand, counting from the end starting with 1 for

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
6464
inferScalarType(R->getOperand(1)) &&
6565
"different types inferred for different operands");
6666
return IntegerType::get(Ctx, 1);
67+
case VPInstruction::ComputeFindLastIVResult:
6768
case VPInstruction::ComputeReductionResult: {
6869
auto *PhiR = cast<VPReductionPHIRecipe>(R->getOperand(0));
6970
auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,29 @@ Value *VPInstruction::generate(VPTransformState &State) {
575575
Builder.GetInsertBlock()->getTerminator()->eraseFromParent();
576576
return CondBr;
577577
}
578+
case VPInstruction::ComputeFindLastIVResult: {
579+
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
580+
// and will be removed by breaking up the recipe further.
581+
auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
582+
// Get its reduction variable descriptor.
583+
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
584+
RecurKind RK = RdxDesc.getRecurrenceKind();
585+
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
586+
"Unexpected reduction kind");
587+
assert(!PhiR->isInLoop() &&
588+
"In-loop FindLastIV reduction is not supported yet");
589+
590+
// The recipe's operands are the reduction phi, followed by one operand for
591+
// each part of the reduction.
592+
unsigned UF = getNumOperands() - 1;
593+
Value *ReducedPartRdx = State.get(getOperand(1));
594+
for (unsigned Part = 1; Part < UF; ++Part) {
595+
ReducedPartRdx = createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx,
596+
State.get(getOperand(1 + Part)));
597+
}
598+
599+
return createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
600+
}
578601
case VPInstruction::ComputeReductionResult: {
579602
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
580603
// and will be removed by breaking up the recipe further.
@@ -584,6 +607,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
584607
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
585608

586609
RecurKind RK = RdxDesc.getRecurrenceKind();
610+
assert(!RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
611+
"should be handled by ComputeFindLastIVResult");
587612

588613
Type *PhiTy = OrigPhi->getType();
589614
// The recipe's operands are the reduction phi, followed by one operand for
@@ -619,9 +644,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
619644
if (Op != Instruction::ICmp && Op != Instruction::FCmp)
620645
ReducedPartRdx = Builder.CreateBinOp(
621646
(Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx");
622-
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
623-
ReducedPartRdx =
624-
createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx, RdxPart);
625647
else
626648
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
627649
}
@@ -630,8 +652,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
630652
// Create the reduction after the loop. Note that inloop reductions create
631653
// the target reduction in the loop using a Reduction recipe.
632654
if ((State.VF.isVector() ||
633-
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
634-
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) &&
655+
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) &&
635656
!PhiR->isInLoop()) {
636657
// TODO: Support in-order reductions based on the recurrence descriptor.
637658
// All ops in the reduction inherit fast-math-flags from the recurrence
@@ -642,9 +663,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
642663
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
643664
ReducedPartRdx =
644665
createAnyOfReduction(Builder, ReducedPartRdx, RdxDesc, OrigPhi);
645-
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
646-
ReducedPartRdx =
647-
createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
648666
else
649667
ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
650668

@@ -742,6 +760,7 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
742760

743761
bool VPInstruction::isVectorToScalar() const {
744762
return getOpcode() == VPInstruction::ExtractFromEnd ||
763+
getOpcode() == VPInstruction::ComputeFindLastIVResult ||
745764
getOpcode() == VPInstruction::ComputeReductionResult ||
746765
getOpcode() == VPInstruction::AnyOf;
747766
}
@@ -913,6 +932,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
913932
case VPInstruction::ExtractFromEnd:
914933
O << "extract-from-end";
915934
break;
935+
case VPInstruction::ComputeFindLastIVResult:
936+
O << "compute-find-last-iv-result";
937+
break;
916938
case VPInstruction::ComputeReductionResult:
917939
O << "compute-reduction-result";
918940
break;

llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,8 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
348348
// the parts to compute the final reduction value.
349349
VPValue *Op1;
350350
if (match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
351+
m_VPValue(), m_VPValue(Op1))) ||
352+
match(&R, m_VPInstruction<VPInstruction::ComputeFindLastIVResult>(
351353
m_VPValue(), m_VPValue(Op1)))) {
352354
addUniformForAllParts(cast<VPInstruction>(&R));
353355
for (unsigned Part = 1; Part != UF; ++Part)

llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ define i64 @find_last_iv(ptr %a, i64 %n, i64 %start) {
234234
; CHECK-NEXT: Successor(s): middle.block
235235
; CHECK-EMPTY:
236236
; CHECK-NEXT: middle.block:
237-
; CHECK-NEXT: EMIT vp<[[RDX_RES:%.+]]> = compute-reduction-result ir<%rdx>, ir<%cond>
237+
; CHECK-NEXT: EMIT vp<[[RDX_RES:%.+]]> = compute-find-last-iv-result ir<%rdx>, ir<%cond>
238238
; CHECK-NEXT: EMIT vp<[[EXT:%.+]]> = extract-from-end vp<[[RDX_RES]]>, ir<1>
239239
; CHECK-NEXT: EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<{{.+}}>
240240
; CHECK-NEXT: EMIT branch-on-cond vp<%cmp.n>

0 commit comments

Comments
 (0)