@@ -8684,12 +8684,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
8684
8684
/// are valid so recipes can be formed later.
8685
8685
void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
8686
8686
// Find all possible partial reductions.
8687
- SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8687
+ SmallVector<std::pair<PartialReductionChain, unsigned >>
8688
8688
PartialReductionChains;
8689
- for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8690
- if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8691
- getScaledReduction (Phi, RdxDesc, Range))
8692
- PartialReductionChains. push_back (*Pair);
8689
+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ()) {
8690
+ getScaledReductions (Phi, RdxDesc. getLoopExitInstr (), Range,
8691
+ PartialReductionChains);
8692
+ }
8693
8693
8694
8694
// A partial reduction is invalid if any of its extends are used by
8695
8695
// something that isn't another partial reduction. This is because the
@@ -8717,39 +8717,54 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8717
8717
}
8718
8718
}
8719
8719
8720
- std::optional<std::pair<PartialReductionChain, unsigned >>
8721
- VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8722
- const RecurrenceDescriptor &Rdx,
8723
- VFRange &Range) {
8720
+ bool VPRecipeBuilder::getScaledReductions (
8721
+ Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
8722
+ SmallVectorImpl<std::pair<PartialReductionChain, unsigned >> &Chains) {
8723
+
8724
+ if (!CM.TheLoop ->contains (RdxExitInstr))
8725
+ return false ;
8726
+
8724
8727
// TODO: Allow scaling reductions when predicating. The select at
8725
8728
// the end of the loop chooses between the phi value and most recent
8726
8729
// reduction result, both of which have different VFs to the active lane
8727
8730
// mask when scaling.
8728
- if (CM.blockNeedsPredicationForAnyReason (Rdx. getLoopExitInstr () ->getParent ()))
8729
- return std::nullopt ;
8731
+ if (CM.blockNeedsPredicationForAnyReason (RdxExitInstr ->getParent ()))
8732
+ return false ;
8730
8733
8731
- auto *Update = dyn_cast<BinaryOperator>(Rdx. getLoopExitInstr () );
8734
+ auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr );
8732
8735
if (!Update)
8733
- return std::nullopt ;
8736
+ return false ;
8734
8737
8735
8738
Value *Op = Update->getOperand (0 );
8736
8739
Value *PhiOp = Update->getOperand (1 );
8737
- if (Op == PHI) {
8738
- Op = Update->getOperand (1 );
8739
- PhiOp = Update->getOperand (0 );
8740
+ if (Op == PHI)
8741
+ std::swap (Op, PhiOp);
8742
+
8743
+ // Try and get a scaled reduction from the first non-phi operand.
8744
+ // If one is found, we use the discovered reduction instruction in
8745
+ // place of the accumulator for costing.
8746
+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8747
+ if (getScaledReductions (PHI, OpInst, Range, Chains)) {
8748
+ PHI = Chains.rbegin ()->first .Reduction ;
8749
+
8750
+ Op = Update->getOperand (0 );
8751
+ PhiOp = Update->getOperand (1 );
8752
+ if (Op == PHI)
8753
+ std::swap (Op, PhiOp);
8754
+ }
8740
8755
}
8741
8756
if (PhiOp != PHI)
8742
- return std::nullopt ;
8757
+ return false ;
8743
8758
8744
8759
auto *BinOp = dyn_cast<BinaryOperator>(Op);
8745
8760
if (!BinOp || !BinOp->hasOneUse ())
8746
- return std::nullopt ;
8761
+ return false ;
8747
8762
8748
8763
using namespace llvm ::PatternMatch;
8749
8764
Value *A, *B;
8750
8765
if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
8751
8766
!match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8752
- return std::nullopt ;
8767
+ return false ;
8753
8768
8754
8769
Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
8755
8770
Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
@@ -8759,7 +8774,7 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8759
8774
TTI::PartialReductionExtendKind OpBExtend =
8760
8775
TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8761
8776
8762
- PartialReductionChain Chain (Rdx. getLoopExitInstr () , ExtA, ExtB, BinOp);
8777
+ PartialReductionChain Chain (RdxExitInstr , ExtA, ExtB, BinOp);
8763
8778
8764
8779
unsigned TargetScaleFactor =
8765
8780
PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
@@ -8773,10 +8788,12 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8773
8788
std::make_optional (BinOp->getOpcode ()));
8774
8789
return Cost.isValid ();
8775
8790
},
8776
- Range))
8777
- return std::make_pair (Chain, TargetScaleFactor);
8791
+ Range)) {
8792
+ Chains.push_back (std::make_pair (Chain, TargetScaleFactor));
8793
+ return true ;
8794
+ }
8778
8795
8779
- return std::nullopt ;
8796
+ return false ;
8780
8797
}
8781
8798
8782
8799
VPRecipeBase *
@@ -8871,12 +8888,14 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
8871
8888
" Unexpected number of operands for partial reduction" );
8872
8889
8873
8890
VPValue *BinOp = Operands[0 ];
8874
- VPValue *Phi = Operands[1 ];
8875
- if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
8876
- std::swap (BinOp, Phi);
8877
-
8878
- return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
8879
- Reduction);
8891
+ VPValue *Accumulator = Operands[1 ];
8892
+ VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe ();
8893
+ if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8894
+ isa<VPPartialReductionRecipe>(BinOpRecipe))
8895
+ std::swap (BinOp, Accumulator);
8896
+
8897
+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp,
8898
+ Accumulator, Reduction);
8880
8899
}
8881
8900
8882
8901
void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
0 commit comments