Skip to content

Commit e4c27f0

Browse files
committed
[VPlan] Dispatch to multiple exit blocks via middle blocks.
A more lightweight variant of llvm#109193, which dispatches to multiple exit blocks via the middle blocks.
1 parent 79d5c6a commit e4c27f0

File tree

12 files changed

+514
-59
lines changed

12 files changed

+514
-59
lines changed

llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,8 @@ class LoopVectorizationLegality {
287287
/// we can use in-order reductions.
288288
bool canVectorizeFPMath(bool EnableStrictReductions);
289289

290+
bool canVectorizeMultiCond() const;
291+
290292
/// Return true if we can vectorize this loop while folding its tail by
291293
/// masking.
292294
bool canFoldTailByMasking() const;

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ AllowStridedPointerIVs("lv-strided-pointer-ivs", cl::init(false), cl::Hidden,
4343
cl::desc("Enable recognition of non-constant strided "
4444
"pointer induction variables."));
4545

46+
// Hidden boolean command-line option, initialized to false.
// LoopVectorizationLegality::canVectorizeMultiCond() returns false whenever
// this option is not set.
static cl::opt<bool> EnableMultiCond("enable-multi-cond-vectorization",
47+
cl::init(false), cl::Hidden, cl::desc(""));
48+
4649
namespace llvm {
4750
cl::opt<bool>
4851
HintsAllowReordering("hints-allow-reordering", cl::init(true), cl::Hidden,
@@ -1378,6 +1381,8 @@ bool LoopVectorizationLegality::isFixedOrderRecurrence(
13781381
}
13791382

13801383
/// Returns true if \p BB needs predication.
/// When canVectorizeMultiCond() holds, every block other than the loop header
/// is reported as needing predication; otherwise the answer is delegated to
/// LoopAccessInfo::blockNeedsPredication.
bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) const {
1384+
// Multi-cond case: only the header is exempt from predication.
if (canVectorizeMultiCond() && BB != TheLoop->getHeader())
1385+
return true;
13811386
return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
13821387
}
13831388

@@ -1514,6 +1519,35 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() {
15141519
return true;
15151520
}
15161521

1522+
/// Returns true if the loop matches the restricted two-exit shape checked
/// below. All of the following must hold:
///  - the EnableMultiCond option is set;
///  - the loop has exactly two exiting blocks, reported as the header first
///    and the latch second;
///  - no instruction in the header may read from memory or have side effects;
///  - the header terminator is a branch on an integer compare whose predicate
///    is neither ICMP_EQ nor ICMP_NE;
///  - no instruction inside the loop has a user outside the loop.
bool LoopVectorizationLegality::canVectorizeMultiCond() const {
1523+
// Feature is opt-in via the command-line flag.
if (!EnableMultiCond)
1524+
return false;
1525+
SmallVector<BasicBlock *> Exiting;
1526+
TheLoop->getExitingBlocks(Exiting);
1527+
// Require exactly {header, latch} as exiting blocks, in that order, and a
// header free of memory reads and side effects.
if (Exiting.size() != 2 || Exiting[0] != TheLoop->getHeader() ||
1528+
Exiting[1] != TheLoop->getLoopLatch() ||
1529+
any_of(*TheLoop->getHeader(), [](Instruction &I) {
1530+
return I.mayReadFromMemory() || I.mayHaveSideEffects();
1531+
}))
1532+
return false;
1533+
CmpInst::Predicate Pred;
1534+
// A and B capture the compare operands for the pattern match only; they are
// not read afterwards in this function.
Value *A, *B;
1535+
// The header must end in a branch on an icmp; equality and inequality
// predicates are rejected.
if (!match(
1536+
TheLoop->getHeader()->getTerminator(),
1537+
m_Br(m_ICmp(Pred, m_Value(A), m_Value(B)), m_Value(), m_Value())) ||
1538+
Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE)
1539+
return false;
1540+
// Reject the loop if any instruction in it is used by an instruction whose
// parent block is outside the loop.
if (any_of(TheLoop->getBlocks(), [this](BasicBlock *BB) {
1541+
return any_of(*BB, [this](Instruction &I) {
1542+
return any_of(I.users(), [this](User *U) {
1543+
return !TheLoop->contains(cast<Instruction>(U)->getParent());
1544+
});
1545+
});
1546+
}))
1547+
return false;
1548+
return true;
1549+
}
1550+
15171551
// Helper function to canVectorizeLoopNestCFG.
15181552
bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp,
15191553
bool UseVPlanNativePath) {

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 59 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,9 +1362,11 @@ class LoopVectorizationCostModel {
13621362
// If we might exit from anywhere but the latch, must run the exiting
13631363
// iteration in scalar form.
13641364
if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) {
1365-
LLVM_DEBUG(
1366-
dbgs() << "LV: Loop requires scalar epilogue: multiple exits\n");
1367-
return true;
1365+
if (!Legal->canVectorizeMultiCond()) {
1366+
LLVM_DEBUG(
1367+
dbgs() << "LV: Loop requires scalar epilogue: multiple exits\n");
1368+
return true;
1369+
}
13681370
}
13691371
if (IsVectorizing && InterleaveInfo.requiresScalarEpilogue()) {
13701372
LLVM_DEBUG(dbgs() << "LV: Loop requires scalar epilogue: "
@@ -2535,8 +2537,17 @@ void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) {
25352537
LoopVectorPreHeader = OrigLoop->getLoopPreheader();
25362538
assert(LoopVectorPreHeader && "Invalid loop structure");
25372539
LoopExitBlock = OrigLoop->getUniqueExitBlock(); // may be nullptr
2538-
assert((LoopExitBlock || Cost->requiresScalarEpilogue(VF.isVector())) &&
2539-
"multiple exit loop without required epilogue?");
2540+
if (Legal->canVectorizeMultiCond()) {
2541+
BasicBlock *Latch = OrigLoop->getLoopLatch();
2542+
BasicBlock *TrueSucc =
2543+
cast<BranchInst>(Latch->getTerminator())->getSuccessor(0);
2544+
BasicBlock *FalseSucc =
2545+
cast<BranchInst>(Latch->getTerminator())->getSuccessor(1);
2546+
LoopExitBlock = OrigLoop->contains(TrueSucc) ? FalseSucc : TrueSucc;
2547+
} else {
2548+
assert((LoopExitBlock || Cost->requiresScalarEpilogue(VF.isVector())) &&
2549+
"multiple exit loop without required epilogue?");
2550+
}
25402551

25412552
LoopMiddleBlock =
25422553
SplitBlock(LoopVectorPreHeader, LoopVectorPreHeader->getTerminator(), DT,
@@ -2910,7 +2921,8 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State,
29102921
for (PHINode &PN : Exit->phis())
29112922
PSE.getSE()->forgetLcssaPhiWithNewPredecessor(OrigLoop, &PN);
29122923

2913-
if (Cost->requiresScalarEpilogue(VF.isVector())) {
2924+
if (Legal->canVectorizeMultiCond() ||
2925+
Cost->requiresScalarEpilogue(VF.isVector())) {
29142926
// No edge from the middle block to the unique exit block has been inserted
29152927
// and there is nothing to fix from vector loop; phis should have incoming
29162928
// from scalar loop only.
@@ -3554,7 +3566,8 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) {
35543566
TheLoop->getExitingBlocks(Exiting);
35553567
for (BasicBlock *E : Exiting) {
35563568
auto *Cmp = dyn_cast<Instruction>(E->getTerminator()->getOperand(0));
3557-
if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse())
3569+
if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse() &&
3570+
(TheLoop->getLoopLatch() == E || !Legal->canVectorizeMultiCond()))
35583571
AddToWorklistIfAllowed(Cmp);
35593572
}
35603573

@@ -7643,12 +7656,15 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
76437656
BestVPlan.execute(&State);
76447657

76457658
// 2.5 Collect reduction resume values.
7646-
auto *ExitVPBB =
7647-
cast<VPBasicBlock>(BestVPlan.getVectorLoopRegion()->getSingleSuccessor());
7648-
for (VPRecipeBase &R : *ExitVPBB) {
7649-
createAndCollectMergePhiForReduction(
7650-
dyn_cast<VPInstruction>(&R), State, OrigLoop,
7651-
State.CFG.VPBB2IRBB[ExitVPBB], ExpandedSCEVs);
7659+
VPBasicBlock *ExitVPBB = nullptr;
7660+
if (BestVPlan.getVectorLoopRegion()->getSingleSuccessor()) {
7661+
ExitVPBB = cast<VPBasicBlock>(
7662+
BestVPlan.getVectorLoopRegion()->getSingleSuccessor());
7663+
for (VPRecipeBase &R : *ExitVPBB) {
7664+
createAndCollectMergePhiForReduction(
7665+
dyn_cast<VPInstruction>(&R), State, OrigLoop,
7666+
State.CFG.VPBB2IRBB[ExitVPBB], ExpandedSCEVs);
7667+
}
76527668
}
76537669

76547670
// 2.6. Maintain Loop Hints
@@ -7674,6 +7690,7 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
76747690
LoopVectorizeHints Hints(L, true, *ORE);
76757691
Hints.setAlreadyVectorized();
76767692
}
7693+
76777694
TargetTransformInfo::UnrollingPreferences UP;
76787695
TTI.getUnrollingPreferences(L, *PSE.getSE(), UP, ORE);
76797696
if (!UP.UnrollVectorizedLoop || CanonicalIVStartValue)
@@ -7686,15 +7703,17 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
76867703
ILV.printDebugTracesAtEnd();
76877704

76887705
// 4. Adjust branch weight of the branch in the middle block.
7689-
auto *MiddleTerm =
7690-
cast<BranchInst>(State.CFG.VPBB2IRBB[ExitVPBB]->getTerminator());
7691-
if (MiddleTerm->isConditional() &&
7692-
hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
7693-
// Assume that `Count % VectorTripCount` is equally distributed.
7694-
unsigned TripCount = BestVPlan.getUF() * State.VF.getKnownMinValue();
7695-
assert(TripCount > 0 && "trip count should not be zero");
7696-
const uint32_t Weights[] = {1, TripCount - 1};
7697-
setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false);
7706+
if (ExitVPBB) {
7707+
auto *MiddleTerm =
7708+
cast<BranchInst>(State.CFG.VPBB2IRBB[ExitVPBB]->getTerminator());
7709+
if (MiddleTerm->isConditional() &&
7710+
hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
7711+
// Assume that `Count % VectorTripCount` is equally distributed.
7712+
unsigned TripCount = BestVPlan.getUF() * State.VF.getKnownMinValue();
7713+
assert(TripCount > 0 && "trip count should not be zero");
7714+
const uint32_t Weights[] = {1, TripCount - 1};
7715+
setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false);
7716+
}
76987717
}
76997718

77007719
return State.ExpandedSCEVs;
@@ -8079,7 +8098,7 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
80798098
// If source is an exiting block, we know the exit edge is dynamically dead
80808099
// in the vector loop, and thus we don't need to restrict the mask. Avoid
80818100
// adding uses of an otherwise potentially dead instruction.
8082-
if (OrigLoop->isLoopExiting(Src))
8101+
if (!Legal->canVectorizeMultiCond() && OrigLoop->isLoopExiting(Src))
80838102
return EdgeMaskCache[Edge] = SrcMask;
80848103

80858104
VPValue *EdgeMask = getVPValueOrAddLiveIn(BI->getCondition());
@@ -8729,6 +8748,8 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, bool HasNUW,
87298748
static SetVector<VPIRInstruction *> collectUsersInExitBlock(
87308749
Loop *OrigLoop, VPRecipeBuilder &Builder, VPlan &Plan,
87318750
const MapVector<PHINode *, InductionDescriptor> &Inductions) {
8751+
if (!Plan.getVectorLoopRegion()->getSingleSuccessor())
8752+
return {};
87328753
auto *MiddleVPBB =
87338754
cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSingleSuccessor());
87348755
// No edge from the middle block to the unique exit block has been inserted
@@ -8814,6 +8835,8 @@ static void addLiveOutsForFirstOrderRecurrences(
88148835
// TODO: Should be replaced by
88158836
// Plan->getScalarLoopRegion()->getSinglePredecessor() in the future once the
88168837
// scalar region is modeled as well.
8838+
if (!VectorRegion->getSingleSuccessor())
8839+
return;
88178840
auto *MiddleVPBB = cast<VPBasicBlock>(VectorRegion->getSingleSuccessor());
88188841
VPBasicBlock *ScalarPHVPBB = nullptr;
88198842
if (MiddleVPBB->getNumSuccessors() == 2) {
@@ -9100,10 +9123,15 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
91009123
"VPBasicBlock");
91019124
RecipeBuilder.fixHeaderPhis();
91029125

9103-
SetVector<VPIRInstruction *> ExitUsersToFix = collectUsersInExitBlock(
9104-
OrigLoop, RecipeBuilder, *Plan, Legal->getInductionVars());
9105-
addLiveOutsForFirstOrderRecurrences(*Plan, ExitUsersToFix);
9106-
addUsersInExitBlock(*Plan, ExitUsersToFix);
9126+
if (Legal->canVectorizeMultiCond()) {
9127+
VPlanTransforms::convertToMultiCond(*Plan, *PSE.getSE(), OrigLoop,
9128+
RecipeBuilder);
9129+
} else {
9130+
SetVector<VPIRInstruction *> ExitUsersToFix = collectUsersInExitBlock(
9131+
OrigLoop, RecipeBuilder, *Plan, Legal->getInductionVars());
9132+
addLiveOutsForFirstOrderRecurrences(*Plan, ExitUsersToFix);
9133+
addUsersInExitBlock(*Plan, ExitUsersToFix);
9134+
}
91079135

91089136
// ---------------------------------------------------------------------------
91099137
// Transform initial VPlan: Apply previously taken decisions, in order, to
@@ -9231,8 +9259,6 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
92319259
using namespace VPlanPatternMatch;
92329260
VPRegionBlock *VectorLoopRegion = Plan->getVectorLoopRegion();
92339261
VPBasicBlock *Header = VectorLoopRegion->getEntryBasicBlock();
9234-
VPBasicBlock *MiddleVPBB =
9235-
cast<VPBasicBlock>(VectorLoopRegion->getSingleSuccessor());
92369262
for (VPRecipeBase &R : Header->phis()) {
92379263
auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
92389264
if (!PhiR || !PhiR->isInLoop() || (MinVF.isScalar() && !PhiR->isOrdered()))
@@ -9251,8 +9277,6 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
92519277
for (VPUser *U : Cur->users()) {
92529278
auto *UserRecipe = cast<VPSingleDefRecipe>(U);
92539279
if (!UserRecipe->getParent()->getEnclosingLoopRegion()) {
9254-
assert(UserRecipe->getParent() == MiddleVPBB &&
9255-
"U must be either in the loop region or the middle block.");
92569280
continue;
92579281
}
92589282
Worklist.insert(UserRecipe);
@@ -9357,6 +9381,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
93579381
}
93589382
VPBasicBlock *LatchVPBB = VectorLoopRegion->getExitingBasicBlock();
93599383
Builder.setInsertPoint(&*LatchVPBB->begin());
9384+
if (!VectorLoopRegion->getSingleSuccessor())
9385+
return;
9386+
VPBasicBlock *MiddleVPBB =
9387+
cast<VPBasicBlock>(VectorLoopRegion->getSingleSuccessor());
93609388
VPBasicBlock::iterator IP = MiddleVPBB->getFirstNonPhi();
93619389
for (VPRecipeBase &R :
93629390
Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {

llvm/lib/Transforms/Vectorize/VPlan.cpp

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,14 @@ void VPIRBasicBlock::execute(VPTransformState *State) {
474474
// backedges. A backward successor is set when the branch is created.
475475
const auto &PredVPSuccessors = PredVPBB->getHierarchicalSuccessors();
476476
unsigned idx = PredVPSuccessors.front() == this ? 0 : 1;
477+
if (TermBr->getSuccessor(idx) &&
478+
PredVPBlock == getPlan()->getVectorLoopRegion() &&
479+
PredVPBlock->getNumSuccessors()) {
480+
// Update PredBB and TermBr for BranchOnMultiCond in predecessor.
481+
PredBB = TermBr->getSuccessor(1);
482+
TermBr = cast<BranchInst>(PredBB->getTerminator());
483+
idx = 0;
484+
}
477485
assert(!TermBr->getSuccessor(idx) &&
478486
"Trying to reset an existing successor block.");
479487
TermBr->setSuccessor(idx, IRBB);
@@ -908,8 +916,8 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy,
908916
VPBasicBlock *MiddleVPBB = new VPBasicBlock("middle.block");
909917
VPBlockUtils::insertBlockAfter(MiddleVPBB, TopRegion);
910918

911-
VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
912919
if (!RequiresScalarEpilogueCheck) {
920+
VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
913921
VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
914922
return Plan;
915923
}
@@ -923,10 +931,14 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy,
923931
// we unconditionally branch to the scalar preheader. Do nothing.
924932
// 3) Otherwise, construct a runtime check.
925933
BasicBlock *IRExitBlock = TheLoop->getUniqueExitBlock();
926-
auto *VPExitBlock = VPIRBasicBlock::fromBasicBlock(IRExitBlock);
927-
// The connection order corresponds to the operands of the conditional branch.
928-
VPBlockUtils::insertBlockAfter(VPExitBlock, MiddleVPBB);
929-
VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
934+
if (IRExitBlock) {
935+
auto *VPExitBlock = VPIRBasicBlock::fromBasicBlock(IRExitBlock);
936+
// The connection order corresponds to the operands of the conditional
937+
// branch.
938+
VPBlockUtils::insertBlockAfter(VPExitBlock, MiddleVPBB);
939+
VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
940+
VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
941+
}
930942

931943
auto *ScalarLatchTerm = TheLoop->getLoopLatch()->getTerminator();
932944
// Here we use the same DebugLoc as the scalar loop latch terminator instead
@@ -1035,7 +1047,9 @@ void VPlan::execute(VPTransformState *State) {
10351047
// VPlan execution rather than earlier during VPlan construction.
10361048
BasicBlock *MiddleBB = State->CFG.ExitBB;
10371049
VPBasicBlock *MiddleVPBB =
1038-
cast<VPBasicBlock>(getVectorLoopRegion()->getSingleSuccessor());
1050+
getVectorLoopRegion()->getNumSuccessors() == 1
1051+
? cast<VPBasicBlock>(getVectorLoopRegion()->getSuccessors()[0])
1052+
: cast<VPBasicBlock>(getVectorLoopRegion()->getSuccessors()[1]);
10391053
// Find the VPBB for the scalar preheader, relying on the current structure
10401054
// when creating the middle block and its successors: if there's a single
10411055
// predecessor, it must be the scalar preheader. Otherwise, the second
@@ -1048,6 +1062,10 @@ void VPlan::execute(VPTransformState *State) {
10481062
MiddleSuccs.size() == 1 ? MiddleSuccs[0] : MiddleSuccs[1]);
10491063
assert(!isa<VPIRBasicBlock>(ScalarPhVPBB) &&
10501064
"scalar preheader cannot be wrapped already");
1065+
if (ScalarPhVPBB->getNumSuccessors() != 0) {
1066+
ScalarPhVPBB = cast<VPBasicBlock>(ScalarPhVPBB->getSuccessors()[1]);
1067+
MiddleVPBB = cast<VPBasicBlock>(MiddleVPBB->getSuccessors()[1]);
1068+
}
10511069
replaceVPBBWithIRVPBB(ScalarPhVPBB, ScalarPh);
10521070
replaceVPBBWithIRVPBB(MiddleVPBB, MiddleBB);
10531071

@@ -1069,6 +1087,10 @@ void VPlan::execute(VPTransformState *State) {
10691087
VPBasicBlock *LatchVPBB = getVectorLoopRegion()->getExitingBasicBlock();
10701088
BasicBlock *VectorLatchBB = State->CFG.VPBB2IRBB[LatchVPBB];
10711089

1090+
if (!getVectorLoopRegion()->getSingleSuccessor())
1091+
VectorLatchBB =
1092+
cast<BranchInst>(VectorLatchBB->getTerminator())->getSuccessor(1);
1093+
10721094
// Fix the latch value of canonical, reduction and first-order recurrences
10731095
// phis in the vector loop.
10741096
VPBasicBlock *Header = getVectorLoopRegion()->getEntryBasicBlock();
@@ -1095,7 +1117,10 @@ void VPlan::execute(VPTransformState *State) {
10951117
// Move the last step to the end of the latch block. This ensures
10961118
// consistent placement of all induction updates.
10971119
Instruction *Inc = cast<Instruction>(Phi->getIncomingValue(1));
1098-
Inc->moveBefore(VectorLatchBB->getTerminator()->getPrevNode());
1120+
if (VectorLatchBB->getTerminator() == &*VectorLatchBB->getFirstNonPHI())
1121+
Inc->moveBefore(VectorLatchBB->getTerminator());
1122+
else
1123+
Inc->moveBefore(VectorLatchBB->getTerminator()->getPrevNode());
10991124

11001125
// Use the steps for the last part as backedge value for the induction.
11011126
if (auto *IV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&R))

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,6 +1249,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
12491249
// operand). Only generates scalar values (either for the first lane only or
12501250
// for all lanes, depending on its uses).
12511251
PtrAdd,
1252+
AnyOf,
12521253
};
12531254

12541255
private:

0 commit comments

Comments
 (0)