Skip to content

Commit cdea38f

Browse files
Reland "[LoopVectorizer] Add support for chaining partial reductions #120272" (#124282)
Change `getScaledReduction` to take an existing vector and append to it, rather than creating and returning a new one on each call. Rename `getScaledReduction` to `getScaledReductions` to more accurately reflect what it now does. --------- Co-authored-by: Karlo Basioli <[email protected]>
1 parent b29bf3d commit cdea38f

File tree

4 files changed

+1087
-36
lines changed

4 files changed

+1087
-36
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8684,12 +8684,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
86848684
/// are valid so recipes can be formed later.
86858685
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
86868686
// Find all possible partial reductions.
8687-
SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
8687+
SmallVector<std::pair<PartialReductionChain, unsigned>>
86888688
PartialReductionChains;
8689-
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8690-
if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8691-
getScaledReduction(Phi, RdxDesc, Range))
8692-
PartialReductionChains.push_back(*Pair);
8689+
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
8690+
getScaledReductions(Phi, RdxDesc.getLoopExitInstr(), Range,
8691+
PartialReductionChains);
8692+
}
86938693

86948694
// A partial reduction is invalid if any of its extends are used by
86958695
// something that isn't another partial reduction. This is because the
@@ -8717,39 +8717,54 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
87178717
}
87188718
}
87198719

8720-
std::optional<std::pair<PartialReductionChain, unsigned>>
8721-
VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8722-
const RecurrenceDescriptor &Rdx,
8723-
VFRange &Range) {
8720+
bool VPRecipeBuilder::getScaledReductions(
8721+
Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
8722+
SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains) {
8723+
8724+
if (!CM.TheLoop->contains(RdxExitInstr))
8725+
return false;
8726+
87248727
// TODO: Allow scaling reductions when predicating. The select at
87258728
// the end of the loop chooses between the phi value and most recent
87268729
// reduction result, both of which have different VFs to the active lane
87278730
// mask when scaling.
8728-
if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
8729-
return std::nullopt;
8731+
if (CM.blockNeedsPredicationForAnyReason(RdxExitInstr->getParent()))
8732+
return false;
87308733

8731-
auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
8734+
auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
87328735
if (!Update)
8733-
return std::nullopt;
8736+
return false;
87348737

87358738
Value *Op = Update->getOperand(0);
87368739
Value *PhiOp = Update->getOperand(1);
8737-
if (Op == PHI) {
8738-
Op = Update->getOperand(1);
8739-
PhiOp = Update->getOperand(0);
8740+
if (Op == PHI)
8741+
std::swap(Op, PhiOp);
8742+
8743+
// Try to get a scaled reduction from the first non-phi operand.
8744+
// If one is found, we use the discovered reduction instruction in
8745+
// place of the accumulator for costing.
8746+
if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8747+
if (getScaledReductions(PHI, OpInst, Range, Chains)) {
8748+
PHI = Chains.rbegin()->first.Reduction;
8749+
8750+
Op = Update->getOperand(0);
8751+
PhiOp = Update->getOperand(1);
8752+
if (Op == PHI)
8753+
std::swap(Op, PhiOp);
8754+
}
87408755
}
87418756
if (PhiOp != PHI)
8742-
return std::nullopt;
8757+
return false;
87438758

87448759
auto *BinOp = dyn_cast<BinaryOperator>(Op);
87458760
if (!BinOp || !BinOp->hasOneUse())
8746-
return std::nullopt;
8761+
return false;
87478762

87488763
using namespace llvm::PatternMatch;
87498764
Value *A, *B;
87508765
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
87518766
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
8752-
return std::nullopt;
8767+
return false;
87538768

87548769
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
87558770
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
@@ -8759,7 +8774,7 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
87598774
TTI::PartialReductionExtendKind OpBExtend =
87608775
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
87618776

8762-
PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
8777+
PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
87638778

87648779
unsigned TargetScaleFactor =
87658780
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
@@ -8773,10 +8788,12 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
87738788
std::make_optional(BinOp->getOpcode()));
87748789
return Cost.isValid();
87758790
},
8776-
Range))
8777-
return std::make_pair(Chain, TargetScaleFactor);
8791+
Range)) {
8792+
Chains.push_back(std::make_pair(Chain, TargetScaleFactor));
8793+
return true;
8794+
}
87788795

8779-
return std::nullopt;
8796+
return false;
87808797
}
87818798

87828799
VPRecipeBase *
@@ -8871,12 +8888,14 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
88718888
"Unexpected number of operands for partial reduction");
88728889

88738890
VPValue *BinOp = Operands[0];
8874-
VPValue *Phi = Operands[1];
8875-
if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
8876-
std::swap(BinOp, Phi);
8877-
8878-
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
8879-
Reduction);
8891+
VPValue *Accumulator = Operands[1];
8892+
VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
8893+
if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8894+
isa<VPPartialReductionRecipe>(BinOpRecipe))
8895+
std::swap(BinOp, Accumulator);
8896+
8897+
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp,
8898+
Accumulator, Reduction);
88808899
}
88818900

88828901
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,

llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,16 @@ class VPRecipeBuilder {
139139

140140
/// Examines reduction operations to see if the target can use a cheaper
141141
/// operation with a wider per-iteration input VF and narrower PHI VF.
142-
/// Returns null if no scaled reduction was found, otherwise a pair with a
143-
/// struct containing reduction information and the scaling factor between the
144-
/// number of elements in the input and output.
145-
std::optional<std::pair<PartialReductionChain, unsigned>>
146-
getScaledReduction(PHINode *PHI, const RecurrenceDescriptor &Rdx,
147-
VFRange &Range);
142+
/// Each element within Chains is a pair with a struct containing reduction
143+
/// information and the scaling factor between the number of elements in
144+
/// the input and output.
145+
/// Recursively calls itself to identify chained scaled reductions.
146+
/// Returns true if this invocation added an entry to Chains, otherwise false.
147+
/// Note that it also returns false when a subcall adds an entry to Chains,
148+
/// but the top-level call does not.
149+
bool getScaledReductions(
150+
Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
151+
SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains);
148152

149153
public:
150154
VPRecipeBuilder(VPlan &Plan, Loop *OrigLoop, const TargetLibraryInfo *TLI,

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2462,7 +2462,10 @@ class VPPartialReductionRecipe : public VPSingleDefRecipe {
24622462
: VPSingleDefRecipe(VPDef::VPPartialReductionSC,
24632463
ArrayRef<VPValue *>({Op0, Op1}), ReductionInst),
24642464
Opcode(Opcode) {
2465-
assert(isa<VPReductionPHIRecipe>(getOperand(1)->getDefiningRecipe()) &&
2465+
[[maybe_unused]] auto *AccumulatorRecipe =
2466+
getOperand(1)->getDefiningRecipe();
2467+
assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
2468+
isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
24662469
"Unexpected operand order for partial reduction recipe");
24672470
}
24682471
~VPPartialReductionRecipe() override = default;

0 commit comments

Comments
 (0)