Skip to content

Commit 3675c5f

Browse files
fhahn authored and pawosm-arm committed
[VPlan] Add ComputeFindLastIVResult opcode (NFC). (llvm#132689)
This moves the logic for computing the FindLastIV reduction result to its own opcode. A follow-up patch will update the new opcode to also take the start value, to fix llvm#126836. PR: llvm#132689
1 parent e4598e0 commit 3675c5f

File tree

6 files changed

+47
-14
lines changed

6 files changed

+47
-14
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7514,7 +7514,8 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
75147514
BasicBlock *BypassBlock) {
75157515
auto *EpiRedResult = dyn_cast<VPInstruction>(R);
75167516
if (!EpiRedResult ||
7517-
EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult)
7517+
(EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult &&
7518+
EpiRedResult->getOpcode() != VPInstruction::ComputeFindLastIVResult))
75187519
return;
75197520

75207521
auto *EpiRedHeaderPhi =
@@ -9643,8 +9644,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
96439644
Builder.createSelect(Cond, OrigExitingVPV, PhiR, {}, "", FMFs);
96449645
OrigExitingVPV->replaceUsesWithIf(NewExitingVPV, [](VPUser &U, unsigned) {
96459646
return isa<VPInstruction>(&U) &&
9646-
cast<VPInstruction>(&U)->getOpcode() ==
9647-
VPInstruction::ComputeReductionResult;
9647+
(cast<VPInstruction>(&U)->getOpcode() ==
9648+
VPInstruction::ComputeReductionResult ||
9649+
cast<VPInstruction>(&U)->getOpcode() ==
9650+
VPInstruction::ComputeFindLastIVResult);
96489651
});
96499652
if (CM.usePredicatedReductionSelect(
96509653
PhiR->getRecurrenceDescriptor().getOpcode(), PhiTy))
@@ -9689,8 +9692,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
96899692
// also modeled in VPlan.
96909693
VPBuilder::InsertPointGuard Guard(Builder);
96919694
Builder.setInsertPoint(MiddleVPBB, IP);
9692-
auto *FinalReductionResult = Builder.createNaryOp(
9693-
VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL);
9695+
auto *FinalReductionResult =
9696+
Builder.createNaryOp(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
9697+
RdxDesc.getRecurrenceKind())
9698+
? VPInstruction::ComputeFindLastIVResult
9699+
: VPInstruction::ComputeReductionResult,
9700+
{PhiR, NewExitingVPV}, ExitDL);
96949701
// Update all users outside the vector region.
96959702
OrigExitingVPV->replaceUsesWithIf(
96969703
FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1209,6 +1209,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
12091209
CanonicalIVIncrementForPart,
12101210
BranchOnCount,
12111211
BranchOnCond,
1212+
ComputeFindLastIVResult,
12121213
ComputeReductionResult,
12131214
// Takes the VPValue to extract from as first operand and the lane or part
12141215
// to extract as second operand, counting from the end starting with 1 for

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
6464
inferScalarType(R->getOperand(1)) &&
6565
"different types inferred for different operands");
6666
return IntegerType::get(Ctx, 1);
67+
case VPInstruction::ComputeFindLastIVResult:
6768
case VPInstruction::ComputeReductionResult: {
6869
auto *PhiR = cast<VPReductionPHIRecipe>(R->getOperand(0));
6970
auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 30 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -575,6 +575,29 @@ Value *VPInstruction::generate(VPTransformState &State) {
575575
Builder.GetInsertBlock()->getTerminator()->eraseFromParent();
576576
return CondBr;
577577
}
578+
case VPInstruction::ComputeFindLastIVResult: {
579+
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
580+
// and will be removed by breaking up the recipe further.
581+
auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
582+
// Get its reduction variable descriptor.
583+
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
584+
RecurKind RK = RdxDesc.getRecurrenceKind();
585+
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
586+
"Unexpected reduction kind");
587+
assert(!PhiR->isInLoop() &&
588+
"In-loop FindLastIV reduction is not supported yet");
589+
590+
// The recipe's operands are the reduction phi, followed by one operand for
591+
// each part of the reduction.
592+
unsigned UF = getNumOperands() - 1;
593+
Value *ReducedPartRdx = State.get(getOperand(1));
594+
for (unsigned Part = 1; Part < UF; ++Part) {
595+
ReducedPartRdx = createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx,
596+
State.get(getOperand(1 + Part)));
597+
}
598+
599+
return createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
600+
}
578601
case VPInstruction::ComputeReductionResult: {
579602
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
580603
// and will be removed by breaking up the recipe further.
@@ -584,6 +607,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
584607
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
585608

586609
RecurKind RK = RdxDesc.getRecurrenceKind();
610+
assert(!RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
611+
"should be handled by ComputeFindLastIVResult");
587612

588613
Type *PhiTy = OrigPhi->getType();
589614
// The recipe's operands are the reduction phi, followed by one operand for
@@ -619,9 +644,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
619644
if (Op != Instruction::ICmp && Op != Instruction::FCmp)
620645
ReducedPartRdx = Builder.CreateBinOp(
621646
(Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx");
622-
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
623-
ReducedPartRdx =
624-
createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx, RdxPart);
625647
else
626648
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
627649
}
@@ -630,8 +652,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
630652
// Create the reduction after the loop. Note that inloop reductions create
631653
// the target reduction in the loop using a Reduction recipe.
632654
if ((State.VF.isVector() ||
633-
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
634-
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) &&
655+
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) &&
635656
!PhiR->isInLoop()) {
636657
// TODO: Support in-order reductions based on the recurrence descriptor.
637658
// All ops in the reduction inherit fast-math-flags from the recurrence
@@ -642,9 +663,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
642663
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
643664
ReducedPartRdx =
644665
createAnyOfReduction(Builder, ReducedPartRdx, RdxDesc, OrigPhi);
645-
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
646-
ReducedPartRdx =
647-
createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
648666
else
649667
ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
650668

@@ -742,6 +760,7 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
742760

743761
bool VPInstruction::isVectorToScalar() const {
744762
return getOpcode() == VPInstruction::ExtractFromEnd ||
763+
getOpcode() == VPInstruction::ComputeFindLastIVResult ||
745764
getOpcode() == VPInstruction::ComputeReductionResult ||
746765
getOpcode() == VPInstruction::AnyOf;
747766
}
@@ -913,6 +932,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
913932
case VPInstruction::ExtractFromEnd:
914933
O << "extract-from-end";
915934
break;
935+
case VPInstruction::ComputeFindLastIVResult:
936+
O << "compute-find-last-iv-result";
937+
break;
916938
case VPInstruction::ComputeReductionResult:
917939
O << "compute-reduction-result";
918940
break;

llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -348,6 +348,8 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
348348
// the parts to compute the final reduction value.
349349
VPValue *Op1;
350350
if (match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
351+
m_VPValue(), m_VPValue(Op1))) ||
352+
match(&R, m_VPInstruction<VPInstruction::ComputeFindLastIVResult>(
351353
m_VPValue(), m_VPValue(Op1)))) {
352354
addUniformForAllParts(cast<VPInstruction>(&R));
353355
for (unsigned Part = 1; Part != UF; ++Part)

llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -234,7 +234,7 @@ define i64 @find_last_iv(ptr %a, i64 %n, i64 %start) {
234234
; CHECK-NEXT: Successor(s): middle.block
235235
; CHECK-EMPTY:
236236
; CHECK-NEXT: middle.block:
237-
; CHECK-NEXT: EMIT vp<[[RDX_RES:%.+]]> = compute-reduction-result ir<%rdx>, ir<%cond>
237+
; CHECK-NEXT: EMIT vp<[[RDX_RES:%.+]]> = compute-find-last-iv-result ir<%rdx>, ir<%cond>
238238
; CHECK-NEXT: EMIT vp<[[EXT:%.+]]> = extract-from-end vp<[[RDX_RES]]>, ir<1>
239239
; CHECK-NEXT: EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<{{.+}}>
240240
; CHECK-NEXT: EMIT branch-on-cond vp<%cmp.n>

0 commit comments

Comments
 (0)