@@ -8684,12 +8684,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
8684
8684
/// are valid so recipes can be formed later.
8685
8685
void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
8686
8686
// Find all possible partial reductions.
8687
- SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8687
+ SmallVector<std::pair<PartialReductionChain, unsigned >>
8688
8688
PartialReductionChains;
8689
- for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8690
- if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8691
- getScaledReduction (Phi, RdxDesc, Range))
8692
- PartialReductionChains. push_back (*Pair);
8689
+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ()) {
8690
+ if (auto SR = getScaledReduction (Phi, RdxDesc. getLoopExitInstr (), Range))
8691
+ PartialReductionChains. append (*SR);
8692
+ }
8693
8693
8694
8694
// A partial reduction is invalid if any of its extends are used by
8695
8695
// something that isn't another partial reduction. This is because the
@@ -8717,26 +8717,44 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8717
8717
}
8718
8718
}
8719
8719
8720
- std::optional<std::pair<PartialReductionChain, unsigned >>
8721
- VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8722
- const RecurrenceDescriptor &Rdx,
8720
+ std::optional<SmallVector<std::pair<PartialReductionChain, unsigned >>>
8721
+ VPRecipeBuilder::getScaledReduction (Instruction *PHI, Instruction *RdxExitInstr,
8723
8722
VFRange &Range) {
8723
+
8724
+ if (!CM.TheLoop ->contains (RdxExitInstr))
8725
+ return std::nullopt;
8726
+
8724
8727
// TODO: Allow scaling reductions when predicating. The select at
8725
8728
// the end of the loop chooses between the phi value and most recent
8726
8729
// reduction result, both of which have different VFs to the active lane
8727
8730
// mask when scaling.
8728
- if (CM.blockNeedsPredicationForAnyReason (Rdx. getLoopExitInstr () ->getParent ()))
8731
+ if (CM.blockNeedsPredicationForAnyReason (RdxExitInstr ->getParent ()))
8729
8732
return std::nullopt;
8730
8733
8731
- auto *Update = dyn_cast<BinaryOperator>(Rdx. getLoopExitInstr () );
8734
+ auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr );
8732
8735
if (!Update)
8733
8736
return std::nullopt;
8734
8737
8735
8738
Value *Op = Update->getOperand (0 );
8736
8739
Value *PhiOp = Update->getOperand (1 );
8737
- if (Op == PHI) {
8738
- Op = Update->getOperand (1 );
8739
- PhiOp = Update->getOperand (0 );
8740
+ if (Op == PHI)
8741
+ std::swap (Op, PhiOp);
8742
+
8743
+ SmallVector<std::pair<PartialReductionChain, unsigned >> Chains;
8744
+
8745
+ // Try and get a scaled reduction from the first non-phi operand.
8746
+ // If one is found, we use the discovered reduction instruction in
8747
+ // place of the accumulator for costing.
8748
+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8749
+ if (auto SR0 = getScaledReduction (PHI, OpInst, Range)) {
8750
+ Chains.append (*SR0);
8751
+ PHI = SR0->rbegin ()->first .Reduction ;
8752
+
8753
+ Op = Update->getOperand (0 );
8754
+ PhiOp = Update->getOperand (1 );
8755
+ if (Op == PHI)
8756
+ std::swap (Op, PhiOp);
8757
+ }
8740
8758
}
8741
8759
if (PhiOp != PHI)
8742
8760
return std::nullopt;
@@ -8759,7 +8777,7 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8759
8777
TTI::PartialReductionExtendKind OpBExtend =
8760
8778
TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8761
8779
8762
- PartialReductionChain Chain (Rdx. getLoopExitInstr () , ExtA, ExtB, BinOp);
8780
+ PartialReductionChain Chain (RdxExitInstr , ExtA, ExtB, BinOp);
8763
8781
8764
8782
unsigned TargetScaleFactor =
8765
8783
PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
@@ -8774,9 +8792,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8774
8792
return Cost.isValid ();
8775
8793
},
8776
8794
Range))
8777
- return std::make_pair (Chain, TargetScaleFactor);
8795
+ Chains. push_back ( std::make_pair (Chain, TargetScaleFactor) );
8778
8796
8779
- return std::nullopt ;
8797
+ return Chains ;
8780
8798
}
8781
8799
8782
8800
VPRecipeBase *
@@ -8871,12 +8889,14 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
8871
8889
" Unexpected number of operands for partial reduction" );
8872
8890
8873
8891
VPValue *BinOp = Operands[0 ];
8874
- VPValue *Phi = Operands[1 ];
8875
- if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
8876
- std::swap (BinOp, Phi);
8877
-
8878
- return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
8879
- Reduction);
8892
+ VPValue *Accumulator = Operands[1 ];
8893
+ VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe ();
8894
+ if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8895
+ isa<VPPartialReductionRecipe>(BinOpRecipe))
8896
+ std::swap (BinOp, Accumulator);
8897
+
8898
+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp,
8899
+ Accumulator, Reduction);
8880
8900
}
8881
8901
8882
8902
void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
0 commit comments