@@ -113,6 +113,7 @@ class VectorCombine {
113
113
bool scalarizeLoadExtract (Instruction &I);
114
114
bool foldShuffleOfBinops (Instruction &I);
115
115
bool foldShuffleOfCastops (Instruction &I);
116
+ bool foldShuffleOfShuffles (Instruction &I);
116
117
bool foldShuffleFromReductions (Instruction &I);
117
118
bool foldTruncFromReductions (Instruction &I);
118
119
bool foldSelectShuffle (Instruction &I, bool FromReduction = false );
@@ -1552,6 +1553,86 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
1552
1553
return true ;
1553
1554
}
1554
1555
1556
+ // / Try to convert "shuffle (shuffle x, undef), (shuffle y, undef)"
1557
+ // / into "shuffle x, y".
1558
+ bool VectorCombine::foldShuffleOfShuffles (Instruction &I) {
1559
+ Value *V0, *V1;
1560
+ UndefValue *U0, *U1;
1561
+ ArrayRef<int > OuterMask, InnerMask0, InnerMask1;
1562
+ if (!match (&I, m_Shuffle (m_OneUse (m_Shuffle (m_Value (V0), m_UndefValue (U0),
1563
+ m_Mask (InnerMask0))),
1564
+ m_OneUse (m_Shuffle (m_Value (V1), m_UndefValue (U1),
1565
+ m_Mask (InnerMask1))),
1566
+ m_Mask (OuterMask))))
1567
+ return false ;
1568
+
1569
+ auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType ());
1570
+ auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(V0->getType ());
1571
+ auto *ShuffleImmTy = dyn_cast<FixedVectorType>(I.getOperand (0 )->getType ());
1572
+ if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
1573
+ V0->getType () != V1->getType ())
1574
+ return false ;
1575
+
1576
+ unsigned NumSrcElts = ShuffleSrcTy->getNumElements ();
1577
+ unsigned NumImmElts = ShuffleImmTy->getNumElements ();
1578
+
1579
+ // Bail if either inner masks reference a RHS undef arg.
1580
+ if ((!isa<PoisonValue>(U0) &&
1581
+ any_of (InnerMask0, [&](int M) { return M >= (int )NumSrcElts; })) ||
1582
+ (!isa<PoisonValue>(U1) &&
1583
+ any_of (InnerMask1, [&](int M) { return M >= (int )NumSrcElts; })))
1584
+ return false ;
1585
+
1586
+ // Merge shuffles - replace index to the RHS poison arg with PoisonMaskElem,
1587
+ SmallVector<int , 16 > NewMask (OuterMask.begin (), OuterMask.end ());
1588
+ for (int &M : NewMask) {
1589
+ if (0 <= M && M < (int )NumImmElts) {
1590
+ M = (InnerMask0[M] >= (int )NumSrcElts) ? PoisonMaskElem : InnerMask0[M];
1591
+ } else if (M >= (int )NumImmElts) {
1592
+ if (InnerMask1[M - NumImmElts] >= (int )NumSrcElts)
1593
+ M = PoisonMaskElem;
1594
+ else
1595
+ M = InnerMask1[M - NumImmElts] + (V0 == V1 ? 0 : NumSrcElts);
1596
+ }
1597
+ }
1598
+
1599
+ // Have we folded to an Identity shuffle?
1600
+ if (ShuffleVectorInst::isIdentityMask (NewMask, NumSrcElts)) {
1601
+ replaceValue (I, *V0);
1602
+ return true ;
1603
+ }
1604
+
1605
+ // Try to merge the shuffles if the new shuffle is not costly.
1606
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1607
+
1608
+ InstructionCost OldCost =
1609
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
1610
+ InnerMask0, CostKind) +
1611
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
1612
+ InnerMask1, CostKind) +
1613
+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy,
1614
+ OuterMask, CostKind, 0 , nullptr , std::nullopt, &I);
1615
+
1616
+ InstructionCost NewCost = TTI.getShuffleCost (
1617
+ TargetTransformInfo::SK_PermuteTwoSrc, ShuffleSrcTy, NewMask, CostKind);
1618
+
1619
+ LLVM_DEBUG (dbgs () << " Found a shuffle feeding two shuffles: " << I
1620
+ << " \n OldCost: " << OldCost << " vs NewCost: " << NewCost
1621
+ << " \n " );
1622
+ if (NewCost > OldCost)
1623
+ return false ;
1624
+
1625
+ // Clear unused sources to poison.
1626
+ if (none_of (NewMask, [&](int M) { return 0 <= M && M < (int )NumSrcElts; }))
1627
+ V0 = PoisonValue::get (ShuffleSrcTy);
1628
+ if (none_of (NewMask, [&](int M) { return (int )NumSrcElts <= M; }))
1629
+ V1 = PoisonValue::get (ShuffleSrcTy);
1630
+
1631
+ Value *Shuf = Builder.CreateShuffleVector (V0, V1, NewMask);
1632
+ replaceValue (I, *Shuf);
1633
+ return true ;
1634
+ }
1635
+
1555
1636
// / Given a commutative reduction, the order of the input lanes does not alter
1556
1637
// / the results. We can use this to remove certain shuffles feeding the
1557
1638
// / reduction, removing the need to shuffle at all.
@@ -2107,6 +2188,7 @@ bool VectorCombine::run() {
2107
2188
case Instruction::ShuffleVector:
2108
2189
MadeChange |= foldShuffleOfBinops (I);
2109
2190
MadeChange |= foldShuffleOfCastops (I);
2191
+ MadeChange |= foldShuffleOfShuffles (I);
2110
2192
MadeChange |= foldSelectShuffle (I);
2111
2193
break ;
2112
2194
case Instruction::BitCast:
0 commit comments