@@ -8749,15 +8749,15 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8749
8749
// something that isn't another partial reduction. This is because the
8750
8750
// extends are intended to be lowered along with the reduction itself.
8751
8751
8752
- // Build up a set of partial reduction bin ops for efficient use checking.
8753
- SmallSet<User *, 4 > PartialReductionBinOps ;
8752
+ // Build up a set of partial reduction ops for efficient use checking.
8753
+ SmallSet<User *, 4 > PartialReductionOps ;
8754
8754
for (const auto &[PartialRdx, _] : PartialReductionChains)
8755
- PartialReductionBinOps .insert (PartialRdx.BinOp );
8755
+ PartialReductionOps .insert (PartialRdx.ExtendUser );
8756
8756
8757
8757
auto ExtendIsOnlyUsedByPartialReductions =
8758
- [&PartialReductionBinOps ](Instruction *Extend) {
8758
+ [&PartialReductionOps ](Instruction *Extend) {
8759
8759
return all_of (Extend->users (), [&](const User *U) {
8760
- return PartialReductionBinOps .contains (U);
8760
+ return PartialReductionOps .contains (U);
8761
8761
});
8762
8762
};
8763
8763
@@ -8766,15 +8766,14 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8766
8766
for (auto Pair : PartialReductionChains) {
8767
8767
PartialReductionChain Chain = Pair.first ;
8768
8768
if (ExtendIsOnlyUsedByPartialReductions (Chain.ExtendA ) &&
8769
- ExtendIsOnlyUsedByPartialReductions (Chain.ExtendB ))
8769
+ (!Chain. ExtendB || ExtendIsOnlyUsedByPartialReductions (Chain.ExtendB ) ))
8770
8770
ScaledReductionMap.insert (std::make_pair (Chain.Reduction , Pair.second ));
8771
8771
}
8772
8772
}
8773
8773
8774
8774
bool VPRecipeBuilder::getScaledReductions (
8775
8775
Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
8776
8776
SmallVectorImpl<std::pair<PartialReductionChain, unsigned >> &Chains) {
8777
-
8778
8777
if (!CM.TheLoop ->contains (RdxExitInstr))
8779
8778
return false ;
8780
8779
@@ -8803,40 +8802,69 @@ bool VPRecipeBuilder::getScaledReductions(
8803
8802
if (PhiOp != PHI)
8804
8803
return false ;
8805
8804
8806
- auto *BinOp = dyn_cast<BinaryOperator>(Op);
8807
- if (!BinOp || !BinOp->hasOneUse ())
8808
- return false ;
8809
-
8810
8805
using namespace llvm ::PatternMatch;
8811
- // Use the side-effect of match to replace BinOp only if the pattern is
8812
- // matched, we don't care at this point whether it actually matched.
8813
- match (BinOp, m_Neg (m_BinOp (BinOp)));
8814
8806
8815
- Value *A, *B;
8816
- if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
8817
- !match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8818
- return false ;
8807
+ // If the update is a binary operator, check both of its operands to see if
8808
+ // they are extends. Otherwise, see if the update comes directly from an
8809
+ // extend.
8810
+ Instruction *Exts[2 ] = {nullptr };
8811
+ BinaryOperator *ExtendUser = dyn_cast<BinaryOperator>(Op);
8812
+ std::optional<unsigned > BinOpc;
8813
+ Type *ExtOpTypes[2 ] = {nullptr };
8814
+
8815
+ auto collectExtInfo = [&Exts, &ExtOpTypes](SmallVectorImpl<Value *> &Ops) -> bool {
8816
+ unsigned I = 0 ;
8817
+ for (Value *OpI : Ops) {
8818
+ Value *ExtOp;
8819
+ if (!match (OpI, m_ZExtOrSExt (m_Value (ExtOp))))
8820
+ return false ;
8821
+ Exts[I] = cast<Instruction>(OpI);
8822
+ ExtOpTypes[I] = ExtOp->getType ();
8823
+ I++;
8824
+ }
8825
+ return true ;
8826
+ };
8827
+
8828
+ if (ExtendUser) {
8829
+ if (!ExtendUser->hasOneUse ())
8830
+ return false ;
8831
+
8832
+ // Use the side-effect of match to replace BinOp only if the pattern is
8833
+ // matched, we don't care at this point whether it actually matched.
8834
+ match (ExtendUser, m_Neg (m_BinOp (ExtendUser)));
8819
8835
8820
- Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
8821
- Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
8836
+ SmallVector<Value *> Ops (ExtendUser->operands ());
8837
+ if (!collectExtInfo (Ops))
8838
+ return false ;
8839
+
8840
+ BinOpc = std::make_optional (ExtendUser->getOpcode ());
8841
+ } else if (match (Update, m_Add (m_Value (), m_Value ()))) {
8842
+ // We already know the operands for Update are Op and PhiOp.
8843
+ SmallVector<Value *> Ops ({Op});
8844
+ if (!collectExtInfo (Ops))
8845
+ return false ;
8846
+
8847
+ ExtendUser = Update;
8848
+ BinOpc = std::nullopt;
8849
+ } else
8850
+ return false ;
8822
8851
8823
8852
TTI::PartialReductionExtendKind OpAExtend =
8824
- TargetTransformInfo::getPartialReductionExtendKind (ExtA );
8853
+ TargetTransformInfo::getPartialReductionExtendKind (Exts[ 0 ] );
8825
8854
TTI::PartialReductionExtendKind OpBExtend =
8826
- TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8827
-
8828
- PartialReductionChain Chain (RdxExitInstr, ExtA, ExtB, BinOp );
8855
+ Exts[ 1 ] ? TargetTransformInfo::getPartialReductionExtendKind (Exts[ 1 ])
8856
+ : TargetTransformInfo::PR_None;
8857
+ PartialReductionChain Chain (RdxExitInstr, Exts[ 0 ], Exts[ 1 ], ExtendUser );
8829
8858
8830
8859
unsigned TargetScaleFactor =
8831
8860
PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
8832
- A-> getType () ->getPrimitiveSizeInBits ());
8861
+ ExtOpTypes[ 0 ] ->getPrimitiveSizeInBits ());
8833
8862
8834
8863
if (LoopVectorizationPlanner::getDecisionAndClampRange (
8835
8864
[&](ElementCount VF) {
8836
8865
InstructionCost Cost = TTI->getPartialReductionCost (
8837
- Update->getOpcode (), A->getType (), B->getType (), PHI->getType (),
8838
- VF, OpAExtend, OpBExtend,
8839
- std::make_optional (BinOp->getOpcode ()));
8866
+ Update->getOpcode (), ExtOpTypes[0 ], ExtOpTypes[1 ],
8867
+ PHI->getType (), VF, OpAExtend, OpBExtend, BinOpc);
8840
8868
return Cost.isValid ();
8841
8869
},
8842
8870
Range)) {
0 commit comments