Skip to content

Commit e3fa7ee

Browse files
authored
VectorCombine: refactor foldShuffleToIdentity (NFC) (#92766)
Lift out the long lambdas into static functions, use C++ structured bindings (destructuring syntax), and fix other minor things to improve the readability of the function.
1 parent bfb5fe2 commit e3fa7ee

File tree

1 file changed

+121
-110
lines changed

1 file changed

+121
-110
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 121 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,6 +1668,86 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
16681668
return true;
16691669
}
16701670

1671+
using InstLane = std::pair<Value *, int>;
1672+
1673+
// Trace lane `Lane` of `V` backwards through any chain of shufflevector
// instructions. Returns the first non-shuffle value reached together with the
// lane of that value that feeds the requested lane, or
// {nullptr, PoisonMaskElem} if a negative mask element is met on the way.
static InstLane lookThroughShuffles(Value *V, int Lane) {
1674+
while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
1675+
// Element count of the shuffle's first operand; mask values at or above it
// are treated below as selecting from the second operand.
unsigned NumElts =
1676+
cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
1677+
int M = SV->getMaskValue(Lane);
1678+
// A negative mask element means there is no source lane to follow.
if (M < 0)
1679+
return {nullptr, PoisonMaskElem};
1680+
if (static_cast<unsigned>(M) < NumElts) {
1681+
V = SV->getOperand(0);
1682+
Lane = M;
1683+
} else {
1684+
// Rebase the mask value so it indexes into the second operand.
V = SV->getOperand(1);
1685+
Lane = M - NumElts;
1686+
}
1687+
}
1688+
return InstLane{V, Lane};
1689+
}
1690+
1691+
// For every (value, lane) entry in `Item`, take operand number `Op` of that
// value and look through any shuffles feeding it at the same lane. Entries
// whose value is null are passed on as {nullptr, PoisonMaskElem}. The result
// has one entry per entry of `Item`, in the same order.
static SmallVector<InstLane>
1692+
generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
1693+
SmallVector<InstLane> NItem;
1694+
for (InstLane IL : Item) {
1695+
auto [V, Lane] = IL;
1696+
// Non-null values are cast<> to Instruction, so callers must only pass
// entries whose value is an instruction (or null).
InstLane OpLane =
1697+
V ? lookThroughShuffles(cast<Instruction>(V)->getOperand(Op), Lane)
1698+
: InstLane{nullptr, PoisonMaskElem};
1699+
NItem.emplace_back(OpLane);
1700+
}
1701+
return NItem;
1702+
}
1703+
1704+
// Recursively build the replacement value for the lanes described by `Item`.
//   Item          - one (value, lane) pair per entry; the front entry decides
//                   which case below applies.
//   Ty            - supplies the element count for new splat masks and for
//                   the intrinsic result type.
//   IdentityLeafs - values that may be returned unchanged.
//   SplatLeafs    - values that are rebuilt as a single-lane splat shuffle.
//   Builder       - used to create instructions; its insertion point is moved
//                   by this function.
// Returns an existing leaf value or a newly created instruction.
static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
1705+
const SmallPtrSet<Value *, 4> &IdentityLeafs,
1706+
const SmallPtrSet<Value *, 4> &SplatLeafs,
1707+
IRBuilder<> &Builder) {
1708+
auto [FrontV, FrontLane] = Item.front();
1709+
1710+
// Identity leaf: every entry after the first is either null or refers to
// FrontV at a lane equal to its own index, so FrontV is returned as is.
// drop_begin skips index 0, so the front entry's lane is not rechecked here.
if (IdentityLeafs.contains(FrontV) &&
1711+
all_of(drop_begin(enumerate(Item)), [Item](const auto &E) {
1712+
// Re-read the front value from the captured Item inside the lambda
// instead of capturing the outer structured binding.
Value *FrontV = Item.front().first;
1713+
auto [V, Lane] = E.value();
1714+
return !V || (V == FrontV && Lane == (int)E.index());
1715+
})) {
1716+
return FrontV;
1717+
}
1718+
// Splat leaf: emit a shuffle that repeats lane FrontLane of FrontV across
// Ty->getNumElements() lanes.
if (SplatLeafs.contains(FrontV)) {
1719+
// Place the new shuffle after the definition of an instruction, or past the
// allocas of the owning function for an argument. For any other kind of
// value the builder's current insertion point is left unchanged.
if (auto *ILI = dyn_cast<Instruction>(FrontV))
1720+
Builder.SetInsertPoint(*ILI->getInsertionPointAfterDef());
1721+
else if (auto *Arg = dyn_cast<Argument>(FrontV))
1722+
Builder.SetInsertPointPastAllocas(Arg->getParent());
1723+
SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
1724+
return Builder.CreateShuffleVector(FrontV, Mask);
1725+
}
1726+
1727+
// Otherwise FrontV must be an instruction; rebuild it from rebuilt operands.
auto *I = cast<Instruction>(FrontV);
1728+
auto *II = dyn_cast<IntrinsicInst>(I);
1729+
// For intrinsics the last operand is left out of the rebuild.
// NOTE(review): presumably this is the callee operand of the call - confirm.
unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
1730+
SmallVector<Value *> Ops(NumOps);
1731+
for (unsigned Idx = 0; Idx < NumOps; Idx++) {
1732+
// Operands reported as scalar for this intrinsic are reused unchanged.
if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx)) {
1733+
Ops[Idx] = II->getOperand(Idx);
1734+
continue;
1735+
}
1736+
// Every other operand is rebuilt by recursing on its per-lane sources.
Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx),
1737+
Ty, IdentityLeafs, SplatLeafs, Builder);
1738+
}
1739+
// The recursive calls may have moved the insertion point, so set it to I
// before creating the replacement.
Builder.SetInsertPoint(I);
1740+
// Vector type with I's scalar element type and Ty's element count; only the
// intrinsic path below uses it.
Type *DstTy =
1741+
FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements());
1742+
if (auto *BI = dyn_cast<BinaryOperator>(I))
1743+
return Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(), Ops[0],
1744+
Ops[1]);
1745+
if (II)
1746+
return Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
1747+
// Anything that is neither a binary operator nor an intrinsic is expected to
// be a unary instruction.
assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
1748+
return Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
1749+
}
1750+
16711751
// Starting from a shuffle, look up through operands tracking the shuffled index
16721752
// of each lane. If we can simplify away the shuffles to identities then
16731753
// do so.
@@ -1677,117 +1757,90 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
16771757
!isa<Instruction>(I.getOperand(1)))
16781758
return false;
16791759

1680-
using InstLane = std::pair<Value *, int>;
1681-
1682-
auto LookThroughShuffles = [](Value *V, int Lane) -> InstLane {
1683-
while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
1684-
unsigned NumElts =
1685-
cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
1686-
int M = SV->getMaskValue(Lane);
1687-
if (M < 0)
1688-
return {nullptr, PoisonMaskElem};
1689-
else if (M < (int)NumElts) {
1690-
V = SV->getOperand(0);
1691-
Lane = M;
1692-
} else {
1693-
V = SV->getOperand(1);
1694-
Lane = M - NumElts;
1695-
}
1696-
}
1697-
return InstLane{V, Lane};
1698-
};
1699-
1700-
auto GenerateInstLaneVectorFromOperand =
1701-
[&LookThroughShuffles](ArrayRef<InstLane> Item, int Op) {
1702-
SmallVector<InstLane> NItem;
1703-
for (InstLane V : Item) {
1704-
NItem.emplace_back(
1705-
!V.first
1706-
? InstLane{nullptr, PoisonMaskElem}
1707-
: LookThroughShuffles(
1708-
cast<Instruction>(V.first)->getOperand(Op), V.second));
1709-
}
1710-
return NItem;
1711-
};
1712-
17131760
SmallVector<InstLane> Start(Ty->getNumElements());
17141761
for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
1715-
Start[M] = LookThroughShuffles(&I, M);
1762+
Start[M] = lookThroughShuffles(&I, M);
17161763

17171764
SmallVector<SmallVector<InstLane>> Worklist;
17181765
Worklist.push_back(Start);
17191766
SmallPtrSet<Value *, 4> IdentityLeafs, SplatLeafs;
17201767
unsigned NumVisited = 0;
17211768

17221769
while (!Worklist.empty()) {
1723-
SmallVector<InstLane> Item = Worklist.pop_back_val();
17241770
if (++NumVisited > MaxInstrsToScan)
17251771
return false;
17261772

1773+
SmallVector<InstLane> Item = Worklist.pop_back_val();
1774+
auto [FrontV, FrontLane] = Item.front();
1775+
17271776
// If we found an undef first lane then bail out to keep things simple.
1728-
if (!Item[0].first)
1777+
if (!FrontV)
17291778
return false;
17301779

17311780
// Look for an identity value.
1732-
if (Item[0].second == 0 &&
1733-
cast<FixedVectorType>(Item[0].first->getType())->getNumElements() ==
1781+
if (!FrontLane &&
1782+
cast<FixedVectorType>(FrontV->getType())->getNumElements() ==
17341783
Ty->getNumElements() &&
1735-
all_of(drop_begin(enumerate(Item)), [&](const auto &E) {
1736-
return !E.value().first || (E.value().first == Item[0].first &&
1784+
all_of(drop_begin(enumerate(Item)), [Item](const auto &E) {
1785+
Value *FrontV = Item.front().first;
1786+
return !E.value().first || (E.value().first == FrontV &&
17371787
E.value().second == (int)E.index());
17381788
})) {
1739-
IdentityLeafs.insert(Item[0].first);
1789+
IdentityLeafs.insert(FrontV);
17401790
continue;
17411791
}
17421792
// Look for a splat value.
1743-
if (all_of(drop_begin(Item), [&](InstLane &IL) {
1744-
return !IL.first ||
1745-
(IL.first == Item[0].first && IL.second == Item[0].second);
1793+
if (all_of(drop_begin(Item), [Item](InstLane &IL) {
1794+
auto [FrontV, FrontLane] = Item.front();
1795+
auto [V, Lane] = IL;
1796+
return !V || (V == FrontV && Lane == FrontLane);
17461797
})) {
1747-
SplatLeafs.insert(Item[0].first);
1798+
SplatLeafs.insert(FrontV);
17481799
continue;
17491800
}
17501801

17511802
// We need each element to be the same type of value, and check that each
17521803
// element has a single use.
1753-
if (!all_of(drop_begin(Item), [&](InstLane IL) {
1754-
if (!IL.first)
1804+
if (!all_of(drop_begin(Item), [Item](InstLane IL) {
1805+
Value *FrontV = Item.front().first;
1806+
Value *V = IL.first;
1807+
if (!V)
17551808
return true;
1756-
if (auto *I = dyn_cast<Instruction>(IL.first); I && !I->hasOneUse())
1809+
if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse())
17571810
return false;
1758-
if (IL.first->getValueID() != Item[0].first->getValueID())
1811+
if (V->getValueID() != FrontV->getValueID())
17591812
return false;
1760-
if (isa<CallInst>(IL.first) && !isa<IntrinsicInst>(IL.first))
1813+
if (isa<CallInst>(V) && !isa<IntrinsicInst>(V))
17611814
return false;
1762-
auto *II = dyn_cast<IntrinsicInst>(IL.first);
1763-
return !II ||
1764-
(isa<IntrinsicInst>(Item[0].first) &&
1765-
II->getIntrinsicID() ==
1766-
cast<IntrinsicInst>(Item[0].first)->getIntrinsicID());
1815+
auto *II = dyn_cast<IntrinsicInst>(V);
1816+
return !II || (isa<IntrinsicInst>(FrontV) &&
1817+
II->getIntrinsicID() ==
1818+
cast<IntrinsicInst>(FrontV)->getIntrinsicID());
17671819
}))
17681820
return false;
17691821

17701822
// Check the operator is one that we support. We exclude div/rem in case
17711823
// they hit UB from poison lanes.
1772-
if (isa<BinaryOperator>(Item[0].first) &&
1773-
!cast<BinaryOperator>(Item[0].first)->isIntDivRem()) {
1774-
Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 0));
1775-
Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 1));
1776-
} else if (isa<UnaryOperator>(Item[0].first)) {
1777-
Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 0));
1778-
} else if (auto *II = dyn_cast<IntrinsicInst>(Item[0].first);
1824+
if (isa<BinaryOperator>(FrontV) &&
1825+
!cast<BinaryOperator>(FrontV)->isIntDivRem()) {
1826+
Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
1827+
Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
1828+
} else if (isa<UnaryOperator>(FrontV)) {
1829+
Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
1830+
} else if (auto *II = dyn_cast<IntrinsicInst>(FrontV);
17791831
II && isTriviallyVectorizable(II->getIntrinsicID())) {
17801832
for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
17811833
if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op)) {
1782-
if (!all_of(drop_begin(Item), [&](InstLane &IL) {
1783-
return !IL.first ||
1784-
(cast<Instruction>(IL.first)->getOperand(Op) ==
1785-
cast<Instruction>(Item[0].first)->getOperand(Op));
1834+
if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
1835+
Value *FrontV = Item.front().first;
1836+
Value *V = IL.first;
1837+
return !V || (cast<Instruction>(V)->getOperand(Op) ==
1838+
cast<Instruction>(FrontV)->getOperand(Op));
17861839
}))
17871840
return false;
17881841
continue;
17891842
}
1790-
Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, Op));
1843+
Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op));
17911844
}
17921845
} else {
17931846
return false;
@@ -1799,49 +1852,7 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
17991852

18001853
// If we got this far, we know the shuffles are superfluous and can be
18011854
// removed. Scan through again and generate the new tree of instructions.
1802-
std::function<Value *(ArrayRef<InstLane>)> Generate =
1803-
[&](ArrayRef<InstLane> Item) -> Value * {
1804-
if (IdentityLeafs.contains(Item[0].first) &&
1805-
all_of(drop_begin(enumerate(Item)), [&](const auto &E) {
1806-
return !E.value().first || (E.value().first == Item[0].first &&
1807-
E.value().second == (int)E.index());
1808-
})) {
1809-
return Item[0].first;
1810-
}
1811-
if (SplatLeafs.contains(Item[0].first)) {
1812-
if (auto ILI = dyn_cast<Instruction>(Item[0].first))
1813-
Builder.SetInsertPoint(*ILI->getInsertionPointAfterDef());
1814-
else if (isa<Argument>(Item[0].first))
1815-
Builder.SetInsertPointPastAllocas(I.getParent()->getParent());
1816-
SmallVector<int, 16> Mask(Ty->getNumElements(), Item[0].second);
1817-
return Builder.CreateShuffleVector(Item[0].first, Mask);
1818-
}
1819-
1820-
auto *I = cast<Instruction>(Item[0].first);
1821-
auto *II = dyn_cast<IntrinsicInst>(I);
1822-
unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
1823-
SmallVector<Value *> Ops(NumOps);
1824-
for (unsigned Idx = 0; Idx < NumOps; Idx++) {
1825-
if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx)) {
1826-
Ops[Idx] = II->getOperand(Idx);
1827-
continue;
1828-
}
1829-
Ops[Idx] = Generate(GenerateInstLaneVectorFromOperand(Item, Idx));
1830-
}
1831-
Builder.SetInsertPoint(I);
1832-
Type *DstTy = FixedVectorType::get(I->getType()->getScalarType(),
1833-
Ty->getNumElements());
1834-
if (auto BI = dyn_cast<BinaryOperator>(I))
1835-
return Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
1836-
Ops[0], Ops[1]);
1837-
if (II)
1838-
return Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
1839-
assert(isa<UnaryInstruction>(I) &&
1840-
"Unexpected instruction type in Generate");
1841-
return Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
1842-
};
1843-
1844-
Value *V = Generate(Start);
1855+
Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs, Builder);
18451856
replaceValue(I, *V);
18461857
return true;
18471858
}

0 commit comments

Comments
 (0)