-
Notifications
You must be signed in to change notification settings - Fork 13.5k
VectorCombine: refactor foldShuffleToIdentity (NFC) #92766
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Lift out the long lambdas into static functions, use C++ destructing syntax, and fix other minor things to improve the readability of the function.
@llvm/pr-subscribers-llvm-transforms Author: Ramkumar Ramachandra (artagnon) ChangesLift out the long lambdas into static functions, use C++ destructing syntax, and fix other minor things to improve the readability of the function. Full diff: https://github.com/llvm/llvm-project/pull/92766.diff 1 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 15deaf908422d..5d45c012b4b87 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1668,6 +1668,86 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
return true;
}
+using InstLane = std::pair<Value *, int>;
+
+static InstLane lookThroughShuffles(Value *V, int Lane) {
+ while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
+ unsigned NumElts =
+ cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
+ int M = SV->getMaskValue(Lane);
+ if (M < 0)
+ return {nullptr, PoisonMaskElem};
+ if (static_cast<unsigned>(M) < NumElts) {
+ V = SV->getOperand(0);
+ Lane = M;
+ } else {
+ V = SV->getOperand(1);
+ Lane = M - NumElts;
+ }
+ }
+ return InstLane{V, Lane};
+}
+
+static SmallVector<InstLane>
+generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
+ SmallVector<InstLane> NItem;
+ for (InstLane IL : Item) {
+ auto [V, Lane] = IL;
+ InstLane OpLane =
+ V ? lookThroughShuffles(cast<Instruction>(V)->getOperand(Op), Lane)
+ : InstLane{nullptr, PoisonMaskElem};
+ NItem.emplace_back(OpLane);
+ }
+ return NItem;
+}
+
+static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
+ const SmallPtrSet<Value *, 4> &IdentityLeafs,
+ const SmallPtrSet<Value *, 4> &SplatLeafs,
+ IRBuilder<> &Builder) {
+ auto [FrontV, FrontLane] = Item.front();
+
+ if (IdentityLeafs.contains(FrontV) &&
+ all_of(drop_begin(enumerate(Item)), [Item](const auto &E) {
+ Value *FrontV = Item.front().first;
+ auto [V, Lane] = E.value();
+ return !V || (V == FrontV && Lane == (int)E.index());
+ })) {
+ return FrontV;
+ }
+ if (SplatLeafs.contains(FrontV)) {
+ if (auto *ILI = dyn_cast<Instruction>(FrontV))
+ Builder.SetInsertPoint(*ILI->getInsertionPointAfterDef());
+ else if (auto *Arg = dyn_cast<Argument>(FrontV))
+ Builder.SetInsertPointPastAllocas(Arg->getParent());
+ SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
+ return Builder.CreateShuffleVector(FrontV, Mask);
+ }
+
+ auto *I = cast<Instruction>(FrontV);
+ auto *II = dyn_cast<IntrinsicInst>(I);
+ unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
+ SmallVector<Value *> Ops(NumOps);
+ for (unsigned Idx = 0; Idx < NumOps; Idx++) {
+ if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx)) {
+ Ops[Idx] = II->getOperand(Idx);
+ continue;
+ }
+ Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx),
+ Ty, IdentityLeafs, SplatLeafs, Builder);
+ }
+ Builder.SetInsertPoint(I);
+ Type *DstTy =
+ FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements());
+ if (auto *BI = dyn_cast<BinaryOperator>(I))
+ return Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(), Ops[0],
+ Ops[1]);
+ if (II)
+ return Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
+ assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
+ return Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
+}
+
// Starting from a shuffle, look up through operands tracking the shuffled index
// of each lane. If we can simplify away the shuffles to identities then
// do so.
@@ -1677,42 +1757,9 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
!isa<Instruction>(I.getOperand(1)))
return false;
- using InstLane = std::pair<Value *, int>;
-
- auto LookThroughShuffles = [](Value *V, int Lane) -> InstLane {
- while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
- unsigned NumElts =
- cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
- int M = SV->getMaskValue(Lane);
- if (M < 0)
- return {nullptr, PoisonMaskElem};
- else if (M < (int)NumElts) {
- V = SV->getOperand(0);
- Lane = M;
- } else {
- V = SV->getOperand(1);
- Lane = M - NumElts;
- }
- }
- return InstLane{V, Lane};
- };
-
- auto GenerateInstLaneVectorFromOperand =
- [&LookThroughShuffles](ArrayRef<InstLane> Item, int Op) {
- SmallVector<InstLane> NItem;
- for (InstLane V : Item) {
- NItem.emplace_back(
- !V.first
- ? InstLane{nullptr, PoisonMaskElem}
- : LookThroughShuffles(
- cast<Instruction>(V.first)->getOperand(Op), V.second));
- }
- return NItem;
- };
-
SmallVector<InstLane> Start(Ty->getNumElements());
for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
- Start[M] = LookThroughShuffles(&I, M);
+ Start[M] = lookThroughShuffles(&I, M);
SmallVector<SmallVector<InstLane>> Worklist;
Worklist.push_back(Start);
@@ -1721,73 +1768,78 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
while (!Worklist.empty()) {
SmallVector<InstLane> Item = Worklist.pop_back_val();
+ auto [FrontV, FrontLane] = Item.front();
if (++NumVisited > MaxInstrsToScan)
return false;
// If we found an undef first lane then bail out to keep things simple.
- if (!Item[0].first)
+ if (!FrontV)
return false;
// Look for an identity value.
- if (Item[0].second == 0 &&
- cast<FixedVectorType>(Item[0].first->getType())->getNumElements() ==
+ if (!FrontLane &&
+ cast<FixedVectorType>(FrontV->getType())->getNumElements() ==
Ty->getNumElements() &&
- all_of(drop_begin(enumerate(Item)), [&](const auto &E) {
- return !E.value().first || (E.value().first == Item[0].first &&
+ all_of(drop_begin(enumerate(Item)), [Item](const auto &E) {
+ Value *FrontV = Item.front().first;
+ return !E.value().first || (E.value().first == FrontV &&
E.value().second == (int)E.index());
})) {
- IdentityLeafs.insert(Item[0].first);
+ IdentityLeafs.insert(FrontV);
continue;
}
// Look for a splat value.
- if (all_of(drop_begin(Item), [&](InstLane &IL) {
- return !IL.first ||
- (IL.first == Item[0].first && IL.second == Item[0].second);
+ if (all_of(drop_begin(Item), [Item](InstLane &IL) {
+ auto [FrontV, FrontLane] = Item.front();
+ auto [V, Lane] = IL;
+ return !V || (V == FrontV && Lane == FrontLane);
})) {
- SplatLeafs.insert(Item[0].first);
+ SplatLeafs.insert(FrontV);
continue;
}
// We need each element to be the same type of value, and check that each
// element has a single use.
- if (!all_of(drop_begin(Item), [&](InstLane IL) {
- if (!IL.first)
+ if (!all_of(drop_begin(Item), [Item](InstLane IL) {
+ Value *FrontV = Item.front().first;
+ Value *V = IL.first;
+ if (!V)
return true;
- if (auto *I = dyn_cast<Instruction>(IL.first); I && !I->hasOneUse())
+ if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse())
return false;
- if (IL.first->getValueID() != Item[0].first->getValueID())
+ if (V->getValueID() != FrontV->getValueID())
return false;
- if (isa<CallInst>(IL.first) && !isa<IntrinsicInst>(IL.first))
+ if (isa<CallInst>(V) && !isa<IntrinsicInst>(V))
return false;
- auto *II = dyn_cast<IntrinsicInst>(IL.first);
- return !II ||
- (isa<IntrinsicInst>(Item[0].first) &&
- II->getIntrinsicID() ==
- cast<IntrinsicInst>(Item[0].first)->getIntrinsicID());
+ auto *II = dyn_cast<IntrinsicInst>(V);
+ return !II || (isa<IntrinsicInst>(FrontV) &&
+ II->getIntrinsicID() ==
+ cast<IntrinsicInst>(FrontV)->getIntrinsicID());
}))
return false;
// Check the operator is one that we support. We exclude div/rem in case
// they hit UB from poison lanes.
- if (isa<BinaryOperator>(Item[0].first) &&
- !cast<BinaryOperator>(Item[0].first)->isIntDivRem()) {
- Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 0));
- Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 1));
- } else if (isa<UnaryOperator>(Item[0].first)) {
- Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 0));
- } else if (auto *II = dyn_cast<IntrinsicInst>(Item[0].first);
+ if (isa<BinaryOperator>(FrontV) &&
+ !cast<BinaryOperator>(FrontV)->isIntDivRem()) {
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
+ } else if (isa<UnaryOperator>(FrontV)) {
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+ } else if (auto *II = dyn_cast<IntrinsicInst>(FrontV);
II && isTriviallyVectorizable(II->getIntrinsicID())) {
for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op)) {
- if (!all_of(drop_begin(Item), [&](InstLane &IL) {
- return !IL.first ||
- (cast<Instruction>(IL.first)->getOperand(Op) ==
- cast<Instruction>(Item[0].first)->getOperand(Op));
+ if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
+ Value *FrontV = Item.front().first;
+ Value *V = IL.first;
+ return !V || (cast<Instruction>(V)->getOperand(Op) ==
+ cast<Instruction>(FrontV)->getOperand(Op));
}))
return false;
continue;
}
- Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, Op));
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op));
}
} else {
return false;
@@ -1799,49 +1851,7 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
// If we got this far, we know the shuffles are superfluous and can be
// removed. Scan through again and generate the new tree of instructions.
- std::function<Value *(ArrayRef<InstLane>)> Generate =
- [&](ArrayRef<InstLane> Item) -> Value * {
- if (IdentityLeafs.contains(Item[0].first) &&
- all_of(drop_begin(enumerate(Item)), [&](const auto &E) {
- return !E.value().first || (E.value().first == Item[0].first &&
- E.value().second == (int)E.index());
- })) {
- return Item[0].first;
- }
- if (SplatLeafs.contains(Item[0].first)) {
- if (auto ILI = dyn_cast<Instruction>(Item[0].first))
- Builder.SetInsertPoint(*ILI->getInsertionPointAfterDef());
- else if (isa<Argument>(Item[0].first))
- Builder.SetInsertPointPastAllocas(I.getParent()->getParent());
- SmallVector<int, 16> Mask(Ty->getNumElements(), Item[0].second);
- return Builder.CreateShuffleVector(Item[0].first, Mask);
- }
-
- auto *I = cast<Instruction>(Item[0].first);
- auto *II = dyn_cast<IntrinsicInst>(I);
- unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
- SmallVector<Value *> Ops(NumOps);
- for (unsigned Idx = 0; Idx < NumOps; Idx++) {
- if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx)) {
- Ops[Idx] = II->getOperand(Idx);
- continue;
- }
- Ops[Idx] = Generate(GenerateInstLaneVectorFromOperand(Item, Idx));
- }
- Builder.SetInsertPoint(I);
- Type *DstTy = FixedVectorType::get(I->getType()->getScalarType(),
- Ty->getNumElements());
- if (auto BI = dyn_cast<BinaryOperator>(I))
- return Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
- Ops[0], Ops[1]);
- if (II)
- return Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
- assert(isa<UnaryInstruction>(I) &&
- "Unexpected instruction type in Generate");
- return Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
- };
-
- Value *V = Generate(Start);
+ Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs, Builder);
replaceValue(I, *V);
return true;
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. The variable names part of this seems like an improvement. LGTM with a couple of minor suggestions.
Other than some additional checks needed for compare predicates and selects with scalar condition operands, these are relatively simple additions to what already exists. I will rebase over llvm#92766, but already had the patch for this version.
This just adds splat constants, which can be treated like any other splat which hopefully makes them very simple. I will rebase over llvm#92766, but already had the patch for this version.
Lift out the long lambdas into static functions, use C++ destructing syntax, and fix other minor things to improve the readability of the function.