Skip to content

Commit 7bd2019

Browse files
committed
Reapply "[LoopVectorizer] Add support for chaining partial reductions (llvm#120272)" (llvm#124198)
1 parent d6e0798 commit 7bd2019

File tree

4 files changed

+1072
-25
lines changed

4 files changed

+1072
-25
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8684,12 +8684,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
86848684
/// are valid so recipes can be formed later.
86858685
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
86868686
// Find all possible partial reductions.
8687-
SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
8687+
SmallVector<std::pair<PartialReductionChain, unsigned>>
86888688
PartialReductionChains;
8689-
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8690-
if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8691-
getScaledReduction(Phi, RdxDesc, Range))
8692-
PartialReductionChains.push_back(*Pair);
8689+
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
8690+
if (auto SR = getScaledReduction(Phi, RdxDesc.getLoopExitInstr(), Range))
8691+
PartialReductionChains.append(*SR);
8692+
}
86938693

86948694
// A partial reduction is invalid if any of its extends are used by
86958695
// something that isn't another partial reduction. This is because the
@@ -8717,26 +8717,44 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
87178717
}
87188718
}
87198719

8720-
std::optional<std::pair<PartialReductionChain, unsigned>>
8721-
VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8722-
const RecurrenceDescriptor &Rdx,
8720+
std::optional<SmallVector<std::pair<PartialReductionChain, unsigned>>>
8721+
VPRecipeBuilder::getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
87238722
VFRange &Range) {
8723+
8724+
if (!CM.TheLoop->contains(RdxExitInstr))
8725+
return std::nullopt;
8726+
87248727
// TODO: Allow scaling reductions when predicating. The select at
87258728
// the end of the loop chooses between the phi value and most recent
87268729
// reduction result, both of which have different VFs to the active lane
87278730
// mask when scaling.
8728-
if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
8731+
if (CM.blockNeedsPredicationForAnyReason(RdxExitInstr->getParent()))
87298732
return std::nullopt;
87308733

8731-
auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
8734+
auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
87328735
if (!Update)
87338736
return std::nullopt;
87348737

87358738
Value *Op = Update->getOperand(0);
87368739
Value *PhiOp = Update->getOperand(1);
8737-
if (Op == PHI) {
8738-
Op = Update->getOperand(1);
8739-
PhiOp = Update->getOperand(0);
8740+
if (Op == PHI)
8741+
std::swap(Op, PhiOp);
8742+
8743+
SmallVector<std::pair<PartialReductionChain, unsigned>> Chains;
8744+
8745+
// Try to get a scaled reduction from the first non-phi operand.
8746+
// If one is found, we use the discovered reduction instruction in
8747+
// place of the accumulator for costing.
8748+
if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8749+
if (auto SR0 = getScaledReduction(PHI, OpInst, Range)) {
8750+
Chains.append(*SR0);
8751+
PHI = SR0->rbegin()->first.Reduction;
8752+
8753+
Op = Update->getOperand(0);
8754+
PhiOp = Update->getOperand(1);
8755+
if (Op == PHI)
8756+
std::swap(Op, PhiOp);
8757+
}
87408758
}
87418759
if (PhiOp != PHI)
87428760
return std::nullopt;
@@ -8759,7 +8777,7 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
87598777
TTI::PartialReductionExtendKind OpBExtend =
87608778
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
87618779

8762-
PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
8780+
PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
87638781

87648782
unsigned TargetScaleFactor =
87658783
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
@@ -8774,9 +8792,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
87748792
return Cost.isValid();
87758793
},
87768794
Range))
8777-
return std::make_pair(Chain, TargetScaleFactor);
8795+
Chains.push_back(std::make_pair(Chain, TargetScaleFactor));
87788796

8779-
return std::nullopt;
8797+
return Chains;
87808798
}
87818799

87828800
VPRecipeBase *
@@ -8871,12 +8889,14 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
88718889
"Unexpected number of operands for partial reduction");
88728890

88738891
VPValue *BinOp = Operands[0];
8874-
VPValue *Phi = Operands[1];
8875-
if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
8876-
std::swap(BinOp, Phi);
8877-
8878-
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
8879-
Reduction);
8892+
VPValue *Accumulator = Operands[1];
8893+
VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
8894+
if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8895+
isa<VPPartialReductionRecipe>(BinOpRecipe))
8896+
std::swap(BinOp, Accumulator);
8897+
8898+
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp,
8899+
Accumulator, Reduction);
88808900
}
88818901

88828902
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,

llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ class VPRecipeBuilder {
142142
/// Returns std::nullopt if no scaled reduction was found, otherwise a list of
143143
/// pairs, each holding a struct containing reduction information and the
144144
/// scaling factor between the number of elements in the input and output.
145-
std::optional<std::pair<PartialReductionChain, unsigned>>
146-
getScaledReduction(PHINode *PHI, const RecurrenceDescriptor &Rdx,
145+
std::optional<SmallVector<std::pair<PartialReductionChain, unsigned>>>
146+
getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
147147
VFRange &Range);
148148

149149
public:

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2456,7 +2456,9 @@ class VPPartialReductionRecipe : public VPSingleDefRecipe {
24562456
: VPSingleDefRecipe(VPDef::VPPartialReductionSC,
24572457
ArrayRef<VPValue *>({Op0, Op1}), ReductionInst),
24582458
Opcode(Opcode) {
2459-
assert(isa<VPReductionPHIRecipe>(getOperand(1)->getDefiningRecipe()) &&
2459+
auto *AccumulatorRecipe = getOperand(1)->getDefiningRecipe();
2460+
assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
2461+
isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
24602462
"Unexpected operand order for partial reduction recipe");
24612463
}
24622464
~VPPartialReductionRecipe() override = default;

0 commit comments

Comments
 (0)