Skip to content

Commit 8ddbc01

Browse files
authored
[VPlan] Manage FindLastIV start value in ComputeFindLastIVResult (NFC) (#132690)
Keep the start value as an operand of ComputeFindLastIVResult. A follow-up patch will use this to make sure the start value is frozen if needed. Depends on #132689. PR: #132690
1 parent fb993cd commit 8ddbc01

File tree

8 files changed

+41
-16
lines changed

8 files changed

+41
-16
lines changed

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -423,7 +423,7 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src,
423423
/// Create a reduction of the given vector \p Src for a reduction of the
424424
/// kind RecurKind::IFindLastIV or RecurKind::FFindLastIV. The reduction
425425
/// operation is described by \p Desc.
426-
Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
426+
Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src, Value *Start,
427427
const RecurrenceDescriptor &Desc);
428428

429429
/// Create an ordered reduction intrinsic using the given recurrence

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,11 +1233,11 @@ Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src,
12331233
}
12341234

12351235
Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
1236+
Value *Start,
12361237
const RecurrenceDescriptor &Desc) {
12371238
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
12381239
Desc.getRecurrenceKind()) &&
12391240
"Unexpected reduction kind");
1240-
Value *StartVal = Desc.getRecurrenceStartValue();
12411241
Value *Sentinel = Desc.getSentinelValue();
12421242
Value *MaxRdx = Src->getType()->isVectorTy()
12431243
? Builder.CreateIntMaxReduce(Src, true)
@@ -1246,7 +1246,7 @@ Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
12461246
// reduction is sentinel value.
12471247
Value *Cmp =
12481248
Builder.CreateCmp(CmpInst::ICMP_NE, MaxRdx, Sentinel, "rdx.select.cmp");
1249-
return Builder.CreateSelect(Cmp, MaxRdx, StartVal, "rdx.select");
1249+
return Builder.CreateSelect(Cmp, MaxRdx, Start, "rdx.select");
12501250
}
12511251

12521252
Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty,

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9890,14 +9890,19 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
98909890
// bc.merge.rdx phi nodes, hence it needs to be created unconditionally here
98919891
// even for in-loop reductions, until the reduction resume value handling is
98929892
// also modeled in VPlan.
9893+
VPInstruction *FinalReductionResult;
98939894
VPBuilder::InsertPointGuard Guard(Builder);
98949895
Builder.setInsertPoint(MiddleVPBB, IP);
9895-
auto *FinalReductionResult =
9896-
Builder.createNaryOp(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
9897-
RdxDesc.getRecurrenceKind())
9898-
? VPInstruction::ComputeFindLastIVResult
9899-
: VPInstruction::ComputeReductionResult,
9900-
{PhiR, NewExitingVPV}, ExitDL);
9896+
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(
9897+
RdxDesc.getRecurrenceKind())) {
9898+
VPValue *Start = PhiR->getStartValue();
9899+
FinalReductionResult =
9900+
Builder.createNaryOp(VPInstruction::ComputeFindLastIVResult,
9901+
{PhiR, Start, NewExitingVPV}, ExitDL);
9902+
} else {
9903+
FinalReductionResult = Builder.createNaryOp(
9904+
VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL);
9905+
}
99019906
// Update all users outside the vector region.
99029907
OrigExitingVPV->replaceUsesWithIf(
99039908
FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
5151

5252
switch (Opcode) {
5353
case Instruction::ExtractElement:
54+
case Instruction::Freeze:
5455
return inferScalarType(R->getOperand(0));
5556
case Instruction::Select: {
5657
Type *ResTy = inferScalarType(R->getOperand(1));

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,16 @@ using BinaryVPInstruction_match =
216216
BinaryRecipe_match<Op0_t, Op1_t, Opcode, /*Commutative*/ false,
217217
VPInstruction>;
218218

219+
template <typename Op0_t, typename Op1_t, typename Op2_t, unsigned Opcode,
220+
bool Commutative, typename... RecipeTys>
221+
using TernaryRecipe_match = Recipe_match<std::tuple<Op0_t, Op1_t, Op2_t>,
222+
Opcode, Commutative, RecipeTys...>;
223+
224+
template <typename Op0_t, typename Op1_t, typename Op2_t, unsigned Opcode>
225+
using TernaryVPInstruction_match =
226+
TernaryRecipe_match<Op0_t, Op1_t, Op2_t, Opcode, /*Commutative*/ false,
227+
VPInstruction>;
228+
219229
template <typename Op0_t, typename Op1_t, unsigned Opcode,
220230
bool Commutative = false>
221231
using AllBinaryRecipe_match =
@@ -234,6 +244,13 @@ m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) {
234244
return BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>(Op0, Op1);
235245
}
236246

247+
template <unsigned Opcode, typename Op0_t, typename Op1_t, typename Op2_t>
248+
inline TernaryVPInstruction_match<Op0_t, Op1_t, Op2_t, Opcode>
249+
m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
250+
return TernaryVPInstruction_match<Op0_t, Op1_t, Op2_t, Opcode>(
251+
{Op0, Op1, Op2});
252+
}
253+
237254
template <typename Op0_t>
238255
inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not>
239256
m_Not(const Op0_t &Op0) {

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -627,14 +627,15 @@ Value *VPInstruction::generate(VPTransformState &State) {
627627

628628
// The recipe's operands are the reduction phi, followed by one operand for
629629
// each part of the reduction.
630-
unsigned UF = getNumOperands() - 1;
631-
Value *ReducedPartRdx = State.get(getOperand(1));
630+
unsigned UF = getNumOperands() - 2;
631+
Value *ReducedPartRdx = State.get(getOperand(2));
632632
for (unsigned Part = 1; Part < UF; ++Part) {
633633
ReducedPartRdx = createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx,
634-
State.get(getOperand(1 + Part)));
634+
State.get(getOperand(2 + Part)));
635635
}
636636

637-
return createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
637+
return createFindLastIVReduction(Builder, ReducedPartRdx,
638+
State.get(getOperand(1), true), RdxDesc);
638639
}
639640
case VPInstruction::ComputeReductionResult: {
640641
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
@@ -951,6 +952,8 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
951952
return true;
952953
case VPInstruction::PtrAdd:
953954
return Op == getOperand(0) || vputils::onlyFirstLaneUsed(this);
955+
case VPInstruction::ComputeFindLastIVResult:
956+
return Op == getOperand(1);
954957
};
955958
llvm_unreachable("switch should return");
956959
}
@@ -1592,7 +1595,6 @@ void VPWidenRecipe::execute(VPTransformState &State) {
15921595
}
15931596
case Instruction::Freeze: {
15941597
Value *Op = State.get(getOperand(0));
1595-
15961598
Value *Freeze = Builder.CreateFreeze(Op);
15971599
State.set(this, Freeze);
15981600
break;

llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
350350
if (match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
351351
m_VPValue(), m_VPValue(Op1))) ||
352352
match(&R, m_VPInstruction<VPInstruction::ComputeFindLastIVResult>(
353-
m_VPValue(), m_VPValue(Op1)))) {
353+
m_VPValue(), m_VPValue(), m_VPValue(Op1)))) {
354354
addUniformForAllParts(cast<VPInstruction>(&R));
355355
for (unsigned Part = 1; Part != UF; ++Part)
356356
R.addOperand(getValueForPart(Op1, Part));

llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ define i64 @find_last_iv(ptr %a, i64 %n, i64 %start) {
234234
; CHECK-NEXT: Successor(s): middle.block
235235
; CHECK-EMPTY:
236236
; CHECK-NEXT: middle.block:
237-
; CHECK-NEXT: EMIT vp<[[RDX_RES:%.+]]> = compute-find-last-iv-result ir<%rdx>, ir<%cond>
237+
; CHECK-NEXT: EMIT vp<[[RDX_RES:%.+]]> = compute-find-last-iv-result ir<%rdx>, ir<%start>, ir<%cond>
238238
; CHECK-NEXT: EMIT vp<[[EXT:%.+]]> = extract-from-end vp<[[RDX_RES]]>, ir<1>
239239
; CHECK-NEXT: EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<{{.+}}>
240240
; CHECK-NEXT: EMIT branch-on-cond vp<%cmp.n>

0 commit comments

Comments
 (0)