Skip to content

Commit bddfbe7

Browse files
authored
[VectorCombine] foldShuffleOfShuffles - fold "shuffle (shuffle x, undef), (shuffle y, undef)" -> "shuffle x, y" (#88743)
Another step towards cleaning up shuffles that have been split, often across bitcasts between SSE intrinsics. If the merged shuffle folds to an identity shuffle, strip the shuffles entirely.
1 parent 5138ccd commit bddfbe7

File tree

7 files changed

+167
-121
lines changed

7 files changed

+167
-121
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,11 @@ struct undef_match {
151151
/// neither undef nor poison, the aggregate is not matched.
152152
inline auto m_Undef() { return undef_match(); }
153153

154+
/// Match an arbitrary UndefValue constant.
155+
inline class_match<UndefValue> m_UndefValue() {
156+
return class_match<UndefValue>();
157+
}
158+
154159
/// Match an arbitrary poison constant.
155160
inline class_match<PoisonValue> m_Poison() {
156161
return class_match<PoisonValue>();
@@ -777,6 +782,9 @@ m_WithOverflowInst(const WithOverflowInst *&I) {
777782
return I;
778783
}
779784

785+
/// Match an UndefValue, capturing the value if we match.
786+
inline bind_ty<UndefValue> m_UndefValue(UndefValue *&U) { return U; }
787+
780788
/// Match a Constant, capturing the value if we match.
781789
inline bind_ty<Constant> m_Constant(Constant *&C) { return C; }
782790

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class VectorCombine {
113113
bool scalarizeLoadExtract(Instruction &I);
114114
bool foldShuffleOfBinops(Instruction &I);
115115
bool foldShuffleOfCastops(Instruction &I);
116+
bool foldShuffleOfShuffles(Instruction &I);
116117
bool foldShuffleFromReductions(Instruction &I);
117118
bool foldTruncFromReductions(Instruction &I);
118119
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
@@ -1552,6 +1553,86 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
15521553
return true;
15531554
}
15541555

1556+
/// Try to convert "shuffle (shuffle x, undef), (shuffle y, undef)"
1557+
/// into "shuffle x, y".
1558+
bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
1559+
Value *V0, *V1;
1560+
UndefValue *U0, *U1;
1561+
ArrayRef<int> OuterMask, InnerMask0, InnerMask1;
1562+
if (!match(&I, m_Shuffle(m_OneUse(m_Shuffle(m_Value(V0), m_UndefValue(U0),
1563+
m_Mask(InnerMask0))),
1564+
m_OneUse(m_Shuffle(m_Value(V1), m_UndefValue(U1),
1565+
m_Mask(InnerMask1))),
1566+
m_Mask(OuterMask))))
1567+
return false;
1568+
1569+
auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1570+
auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(V0->getType());
1571+
auto *ShuffleImmTy = dyn_cast<FixedVectorType>(I.getOperand(0)->getType());
1572+
if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
1573+
V0->getType() != V1->getType())
1574+
return false;
1575+
1576+
unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
1577+
unsigned NumImmElts = ShuffleImmTy->getNumElements();
1578+
1579+
// Bail if either inner masks reference a RHS undef arg.
1580+
if ((!isa<PoisonValue>(U0) &&
1581+
any_of(InnerMask0, [&](int M) { return M >= (int)NumSrcElts; })) ||
1582+
(!isa<PoisonValue>(U1) &&
1583+
any_of(InnerMask1, [&](int M) { return M >= (int)NumSrcElts; })))
1584+
return false;
1585+
1586+
// Merge shuffles - replace index to the RHS poison arg with PoisonMaskElem,
1587+
SmallVector<int, 16> NewMask(OuterMask.begin(), OuterMask.end());
1588+
for (int &M : NewMask) {
1589+
if (0 <= M && M < (int)NumImmElts) {
1590+
M = (InnerMask0[M] >= (int)NumSrcElts) ? PoisonMaskElem : InnerMask0[M];
1591+
} else if (M >= (int)NumImmElts) {
1592+
if (InnerMask1[M - NumImmElts] >= (int)NumSrcElts)
1593+
M = PoisonMaskElem;
1594+
else
1595+
M = InnerMask1[M - NumImmElts] + (V0 == V1 ? 0 : NumSrcElts);
1596+
}
1597+
}
1598+
1599+
// Have we folded to an Identity shuffle?
1600+
if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) {
1601+
replaceValue(I, *V0);
1602+
return true;
1603+
}
1604+
1605+
// Try to merge the shuffles if the new shuffle is not costly.
1606+
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1607+
1608+
InstructionCost OldCost =
1609+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
1610+
InnerMask0, CostKind) +
1611+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
1612+
InnerMask1, CostKind) +
1613+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy,
1614+
OuterMask, CostKind, 0, nullptr, std::nullopt, &I);
1615+
1616+
InstructionCost NewCost = TTI.getShuffleCost(
1617+
TargetTransformInfo::SK_PermuteTwoSrc, ShuffleSrcTy, NewMask, CostKind);
1618+
1619+
LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
1620+
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
1621+
<< "\n");
1622+
if (NewCost > OldCost)
1623+
return false;
1624+
1625+
// Clear unused sources to poison.
1626+
if (none_of(NewMask, [&](int M) { return 0 <= M && M < (int)NumSrcElts; }))
1627+
V0 = PoisonValue::get(ShuffleSrcTy);
1628+
if (none_of(NewMask, [&](int M) { return (int)NumSrcElts <= M; }))
1629+
V1 = PoisonValue::get(ShuffleSrcTy);
1630+
1631+
Value *Shuf = Builder.CreateShuffleVector(V0, V1, NewMask);
1632+
replaceValue(I, *Shuf);
1633+
return true;
1634+
}
1635+
15551636
/// Given a commutative reduction, the order of the input lanes does not alter
15561637
/// the results. We can use this to remove certain shuffles feeding the
15571638
/// reduction, removing the need to shuffle at all.
@@ -2107,6 +2188,7 @@ bool VectorCombine::run() {
21072188
case Instruction::ShuffleVector:
21082189
MadeChange |= foldShuffleOfBinops(I);
21092190
MadeChange |= foldShuffleOfCastops(I);
2191+
MadeChange |= foldShuffleOfShuffles(I);
21102192
MadeChange |= foldSelectShuffle(I);
21112193
break;
21122194
case Instruction::BitCast:

0 commit comments

Comments
 (0)