Skip to content

Commit 420c056

Browse files
authored
[VPlan] Add ComputeFindLastIVResult opcode (NFC). (#132689)
This moves the logic for computing the FindLastIV reduction result to its own opcode. A follow-up patch will update the new opcode to also take the start value, to fix #126836. PR: #132689
1 parent 1a7402d commit 420c056

File tree

6 files changed

+47
-14
lines changed

6 files changed

+47
-14
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7610,7 +7610,8 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
76107610
BasicBlock *BypassBlock) {
76117611
auto *EpiRedResult = dyn_cast<VPInstruction>(R);
76127612
if (!EpiRedResult ||
7613-
EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult)
7613+
(EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult &&
7614+
EpiRedResult->getOpcode() != VPInstruction::ComputeFindLastIVResult))
76147615
return;
76157616

76167617
auto *EpiRedHeaderPhi =
@@ -9819,8 +9820,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
98199820
Builder.createSelect(Cond, OrigExitingVPV, PhiR, {}, "", FMFs);
98209821
OrigExitingVPV->replaceUsesWithIf(NewExitingVPV, [](VPUser &U, unsigned) {
98219822
return isa<VPInstruction>(&U) &&
9822-
cast<VPInstruction>(&U)->getOpcode() ==
9823-
VPInstruction::ComputeReductionResult;
9823+
(cast<VPInstruction>(&U)->getOpcode() ==
9824+
VPInstruction::ComputeReductionResult ||
9825+
cast<VPInstruction>(&U)->getOpcode() ==
9826+
VPInstruction::ComputeFindLastIVResult);
98249827
});
98259828
if (CM.usePredicatedReductionSelect(
98269829
PhiR->getRecurrenceDescriptor().getOpcode(), PhiTy))
@@ -9865,8 +9868,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
98659868
// also modeled in VPlan.
98669869
VPBuilder::InsertPointGuard Guard(Builder);
98679870
Builder.setInsertPoint(MiddleVPBB, IP);
9868-
auto *FinalReductionResult = Builder.createNaryOp(
9869-
VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL);
9871+
auto *FinalReductionResult =
9872+
Builder.createNaryOp(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
9873+
RdxDesc.getRecurrenceKind())
9874+
? VPInstruction::ComputeFindLastIVResult
9875+
: VPInstruction::ComputeReductionResult,
9876+
{PhiR, NewExitingVPV}, ExitDL);
98709877
// Update all users outside the vector region.
98719878
OrigExitingVPV->replaceUsesWithIf(
98729879
FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
865865
BranchOnCount,
866866
BranchOnCond,
867867
Broadcast,
868+
ComputeFindLastIVResult,
868869
ComputeReductionResult,
869870
// Takes the VPValue to extract from as first operand and the lane or part
870871
// to extract as second operand, counting from the end starting with 1 for

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
6666
inferScalarType(R->getOperand(1)) &&
6767
"different types inferred for different operands");
6868
return IntegerType::get(Ctx, 1);
69+
case VPInstruction::ComputeFindLastIVResult:
6970
case VPInstruction::ComputeReductionResult: {
7071
auto *PhiR = cast<VPReductionPHIRecipe>(R->getOperand(0));
7172
auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,29 @@ Value *VPInstruction::generate(VPTransformState &State) {
613613
return Builder.CreateVectorSplat(
614614
State.VF, State.get(getOperand(0), /*IsScalar*/ true), "broadcast");
615615
}
616+
case VPInstruction::ComputeFindLastIVResult: {
617+
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
618+
// and will be removed by breaking up the recipe further.
619+
auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
620+
// Get its reduction variable descriptor.
621+
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
622+
RecurKind RK = RdxDesc.getRecurrenceKind();
623+
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
624+
"Unexpected reduction kind");
625+
assert(!PhiR->isInLoop() &&
626+
"In-loop FindLastIV reduction is not supported yet");
627+
628+
// The recipe's operands are the reduction phi, followed by one operand for
629+
// each part of the reduction.
630+
unsigned UF = getNumOperands() - 1;
631+
Value *ReducedPartRdx = State.get(getOperand(1));
632+
for (unsigned Part = 1; Part < UF; ++Part) {
633+
ReducedPartRdx = createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx,
634+
State.get(getOperand(1 + Part)));
635+
}
636+
637+
return createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
638+
}
616639
case VPInstruction::ComputeReductionResult: {
617640
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
618641
// and will be removed by breaking up the recipe further.
@@ -622,6 +645,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
622645
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
623646

624647
RecurKind RK = RdxDesc.getRecurrenceKind();
648+
assert(!RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
649+
"should be handled by ComputeFindLastIVResult");
625650

626651
Type *PhiTy = OrigPhi->getType();
627652
// The recipe's operands are the reduction phi, followed by one operand for
@@ -657,9 +682,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
657682
if (Op != Instruction::ICmp && Op != Instruction::FCmp)
658683
ReducedPartRdx = Builder.CreateBinOp(
659684
(Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx");
660-
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
661-
ReducedPartRdx =
662-
createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx, RdxPart);
663685
else
664686
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
665687
}
@@ -668,8 +690,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
668690
// Create the reduction after the loop. Note that inloop reductions create
669691
// the target reduction in the loop using a Reduction recipe.
670692
if ((State.VF.isVector() ||
671-
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
672-
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) &&
693+
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) &&
673694
!PhiR->isInLoop()) {
674695
// TODO: Support in-order reductions based on the recurrence descriptor.
675696
// All ops in the reduction inherit fast-math-flags from the recurrence
@@ -680,9 +701,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
680701
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
681702
ReducedPartRdx =
682703
createAnyOfReduction(Builder, ReducedPartRdx, RdxDesc, OrigPhi);
683-
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
684-
ReducedPartRdx =
685-
createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
686704
else
687705
ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
688706

@@ -828,6 +846,7 @@ bool VPInstruction::isVectorToScalar() const {
828846
return getOpcode() == VPInstruction::ExtractFromEnd ||
829847
getOpcode() == Instruction::ExtractElement ||
830848
getOpcode() == VPInstruction::FirstActiveLane ||
849+
getOpcode() == VPInstruction::ComputeFindLastIVResult ||
831850
getOpcode() == VPInstruction::ComputeReductionResult ||
832851
getOpcode() == VPInstruction::AnyOf;
833852
}
@@ -1010,6 +1029,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
10101029
case VPInstruction::ExtractFromEnd:
10111030
O << "extract-from-end";
10121031
break;
1032+
case VPInstruction::ComputeFindLastIVResult:
1033+
O << "compute-find-last-iv-result";
1034+
break;
10131035
case VPInstruction::ComputeReductionResult:
10141036
O << "compute-reduction-result";
10151037
break;

llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,8 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
348348
// the parts to compute the final reduction value.
349349
VPValue *Op1;
350350
if (match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
351+
m_VPValue(), m_VPValue(Op1))) ||
352+
match(&R, m_VPInstruction<VPInstruction::ComputeFindLastIVResult>(
351353
m_VPValue(), m_VPValue(Op1)))) {
352354
addUniformForAllParts(cast<VPInstruction>(&R));
353355
for (unsigned Part = 1; Part != UF; ++Part)

llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ define i64 @find_last_iv(ptr %a, i64 %n, i64 %start) {
234234
; CHECK-NEXT: Successor(s): middle.block
235235
; CHECK-EMPTY:
236236
; CHECK-NEXT: middle.block:
237-
; CHECK-NEXT: EMIT vp<[[RDX_RES:%.+]]> = compute-reduction-result ir<%rdx>, ir<%cond>
237+
; CHECK-NEXT: EMIT vp<[[RDX_RES:%.+]]> = compute-find-last-iv-result ir<%rdx>, ir<%cond>
238238
; CHECK-NEXT: EMIT vp<[[EXT:%.+]]> = extract-from-end vp<[[RDX_RES]]>, ir<1>
239239
; CHECK-NEXT: EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<{{.+}}>
240240
; CHECK-NEXT: EMIT branch-on-cond vp<%cmp.n>

0 commit comments

Comments
 (0)