Skip to content

[VectorCombine] foldShuffleOfShuffles - fold "shuffle (shuffle x, undef), (shuffle y, undef)" -> "shuffle x, y" #88743

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llvm/include/llvm/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ struct undef_match {
/// neither undef nor poison, the aggregate is not matched.
inline auto m_Undef() { return undef_match(); }

/// Build a matcher for any UndefValue constant. Nothing is captured; use the
/// overload taking an `UndefValue *&` when the matched value is needed.
inline class_match<UndefValue> m_UndefValue() {
  auto Matcher = class_match<UndefValue>();
  return Matcher;
}

/// Match an arbitrary poison constant.
inline class_match<PoisonValue> m_Poison() {
return class_match<PoisonValue>();
Expand Down Expand Up @@ -777,6 +782,9 @@ m_WithOverflowInst(const WithOverflowInst *&I) {
return I;
}

/// Match an UndefValue and, when the match succeeds, bind it to \p U.
inline bind_ty<UndefValue> m_UndefValue(UndefValue *&U) {
  return bind_ty<UndefValue>(U);
}

/// Match a Constant and, when the match succeeds, bind it to \p C.
inline bind_ty<Constant> m_Constant(Constant *&C) {
  return bind_ty<Constant>(C);
}

Expand Down
82 changes: 82 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class VectorCombine {
bool scalarizeLoadExtract(Instruction &I);
bool foldShuffleOfBinops(Instruction &I);
bool foldShuffleOfCastops(Instruction &I);
bool foldShuffleOfShuffles(Instruction &I);
bool foldShuffleFromReductions(Instruction &I);
bool foldTruncFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
Expand Down Expand Up @@ -1552,6 +1553,86 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
return true;
}

/// Try to convert "shuffle (shuffle x, undef), (shuffle y, undef)"
/// into "shuffle x, y".
///
/// \p I must be a shuffle whose two operands are both single-use shuffles with
/// an UndefValue as their second operand. The inner masks are composed with
/// the outer mask; the result either collapses to x itself (identity mask) or
/// becomes one two-source shuffle, provided the target cost model does not
/// rate the merged shuffle as more expensive than the three original ones.
///
/// \returns true if \p I was replaced, false if the pattern did not match or
/// the fold was rejected.
bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
  Value *V0, *V1;
  UndefValue *U0, *U1;
  ArrayRef<int> OuterMask, InnerMask0, InnerMask1;
  if (!match(&I, m_Shuffle(m_OneUse(m_Shuffle(m_Value(V0), m_UndefValue(U0),
                                              m_Mask(InnerMask0))),
                           m_OneUse(m_Shuffle(m_Value(V1), m_UndefValue(U1),
                                              m_Mask(InnerMask1))),
                           m_Mask(OuterMask))))
    return false;

  // Only fixed-width vectors are handled, and both inner sources must share a
  // type so that they can feed a single merged shuffle.
  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
  auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(V0->getType());
  auto *ShuffleImmTy = dyn_cast<FixedVectorType>(I.getOperand(0)->getType());
  if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
      V0->getType() != V1->getType())
    return false;

  unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
  unsigned NumImmElts = ShuffleImmTy->getNumElements();

  // Bail if either inner mask references a RHS arg that is undef rather than
  // poison: such lanes cannot be rewritten to PoisonMaskElem below.
  if ((!isa<PoisonValue>(U0) &&
       any_of(InnerMask0, [&](int M) { return M >= (int)NumSrcElts; })) ||
      (!isa<PoisonValue>(U1) &&
       any_of(InnerMask1, [&](int M) { return M >= (int)NumSrcElts; })))
    return false;

  // Merge shuffles - compose each outer lane with the inner mask it selects
  // from. Lanes that land on an inner poison lane or on the inner RHS poison
  // arg become PoisonMaskElem.
  SmallVector<int, 16> NewMask(OuterMask.begin(), OuterMask.end());
  // When both inner shuffles read the same value, every lane can index the
  // first operand of the merged shuffle.
  const int Src1Offset = (V0 == V1) ? 0 : (int)NumSrcElts;
  for (int &M : NewMask) {
    if (M < 0)
      continue; // Outer lane is already poison.
    const bool FromFirst = M < (int)NumImmElts;
    const int Inner = FromFirst ? InnerMask0[M] : InnerMask1[M - NumImmElts];
    // The negative check matters for the second source: adding Src1Offset to
    // a poison lane (-1) would otherwise turn it into a reference to the last
    // lane of V0.
    if (Inner < 0 || Inner >= (int)NumSrcElts)
      M = PoisonMaskElem;
    else
      M = Inner + (FromFirst ? 0 : Src1Offset);
  }

  // Have we folded to an Identity shuffle?
  if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) {
    replaceValue(I, *V0);
    return true;
  }

  // Try to merge the shuffles if the new shuffle is not costly.
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;

  InstructionCost OldCost =
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
                         InnerMask0, CostKind) +
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
                         InnerMask1, CostKind) +
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy,
                         OuterMask, CostKind, 0, nullptr, std::nullopt, &I);

  InstructionCost NewCost = TTI.getShuffleCost(
      TargetTransformInfo::SK_PermuteTwoSrc, ShuffleSrcTy, NewMask, CostKind);

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
  // Deliberately '>' rather than '>=': at equal cost the fold still goes
  // ahead, because replacing three shuffles (the inner ones are single-use)
  // with one reduces the instruction count, which was argued in review to help
  // further folds.
  if (NewCost > OldCost)
    return false;

  // Clear unused sources to poison.
  if (none_of(NewMask, [&](int M) { return 0 <= M && M < (int)NumSrcElts; }))
    V0 = PoisonValue::get(ShuffleSrcTy);
  if (none_of(NewMask, [&](int M) { return (int)NumSrcElts <= M; }))
    V1 = PoisonValue::get(ShuffleSrcTy);

  Value *Shuf = Builder.CreateShuffleVector(V0, V1, NewMask);
  replaceValue(I, *Shuf);
  return true;
}

/// Given a commutative reduction, the order of the input lanes does not alter
/// the results. We can use this to remove certain shuffles feeding the
/// reduction, removing the need to shuffle at all.
Expand Down Expand Up @@ -2107,6 +2188,7 @@ bool VectorCombine::run() {
case Instruction::ShuffleVector:
MadeChange |= foldShuffleOfBinops(I);
MadeChange |= foldShuffleOfCastops(I);
MadeChange |= foldShuffleOfShuffles(I);
MadeChange |= foldSelectShuffle(I);
break;
case Instruction::BitCast:
Expand Down
Loading
Loading