Skip to content

VectorCombine: refactor foldShuffleToIdentity (NFC) #92766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 21, 2024

Conversation

artagnon
Copy link
Contributor

Lift out the long lambdas into static functions, use C++ destructing syntax, and fix other minor things to improve the readability of the function.

Lift out the long lambdas into static functions, use C++ destructing
syntax, and fix other minor things to improve the readability of the
function.
@llvmbot
Copy link
Member

llvmbot commented May 20, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Ramkumar Ramachandra (artagnon)

Changes

Lift out the long lambdas into static functions, use C++ destructing syntax, and fix other minor things to improve the readability of the function.


Full diff: https://github.com/llvm/llvm-project/pull/92766.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+119-109)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 15deaf908422d..5d45c012b4b87 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1668,6 +1668,86 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
   return true;
 }
 
+using InstLane = std::pair<Value *, int>;
+
+static InstLane lookThroughShuffles(Value *V, int Lane) {
+  while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
+    unsigned NumElts =
+        cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
+    int M = SV->getMaskValue(Lane);
+    if (M < 0)
+      return {nullptr, PoisonMaskElem};
+    if (static_cast<unsigned>(M) < NumElts) {
+      V = SV->getOperand(0);
+      Lane = M;
+    } else {
+      V = SV->getOperand(1);
+      Lane = M - NumElts;
+    }
+  }
+  return InstLane{V, Lane};
+}
+
+static SmallVector<InstLane>
+generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
+  SmallVector<InstLane> NItem;
+  for (InstLane IL : Item) {
+    auto [V, Lane] = IL;
+    InstLane OpLane =
+        V ? lookThroughShuffles(cast<Instruction>(V)->getOperand(Op), Lane)
+          : InstLane{nullptr, PoisonMaskElem};
+    NItem.emplace_back(OpLane);
+  }
+  return NItem;
+}
+
+static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
+                                  const SmallPtrSet<Value *, 4> &IdentityLeafs,
+                                  const SmallPtrSet<Value *, 4> &SplatLeafs,
+                                  IRBuilder<> &Builder) {
+  auto [FrontV, FrontLane] = Item.front();
+
+  if (IdentityLeafs.contains(FrontV) &&
+      all_of(drop_begin(enumerate(Item)), [Item](const auto &E) {
+        Value *FrontV = Item.front().first;
+        auto [V, Lane] = E.value();
+        return !V || (V == FrontV && Lane == (int)E.index());
+      })) {
+    return FrontV;
+  }
+  if (SplatLeafs.contains(FrontV)) {
+    if (auto *ILI = dyn_cast<Instruction>(FrontV))
+      Builder.SetInsertPoint(*ILI->getInsertionPointAfterDef());
+    else if (auto *Arg = dyn_cast<Argument>(FrontV))
+      Builder.SetInsertPointPastAllocas(Arg->getParent());
+    SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
+    return Builder.CreateShuffleVector(FrontV, Mask);
+  }
+
+  auto *I = cast<Instruction>(FrontV);
+  auto *II = dyn_cast<IntrinsicInst>(I);
+  unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
+  SmallVector<Value *> Ops(NumOps);
+  for (unsigned Idx = 0; Idx < NumOps; Idx++) {
+    if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx)) {
+      Ops[Idx] = II->getOperand(Idx);
+      continue;
+    }
+    Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx),
+                                   Ty, IdentityLeafs, SplatLeafs, Builder);
+  }
+  Builder.SetInsertPoint(I);
+  Type *DstTy =
+      FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements());
+  if (auto *BI = dyn_cast<BinaryOperator>(I))
+    return Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(), Ops[0],
+                               Ops[1]);
+  if (II)
+    return Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
+  assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
+  return Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
+}
+
 // Starting from a shuffle, look up through operands tracking the shuffled index
 // of each lane. If we can simplify away the shuffles to identities then
 // do so.
@@ -1677,42 +1757,9 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
       !isa<Instruction>(I.getOperand(1)))
     return false;
 
-  using InstLane = std::pair<Value *, int>;
-
-  auto LookThroughShuffles = [](Value *V, int Lane) -> InstLane {
-    while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
-      unsigned NumElts =
-          cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
-      int M = SV->getMaskValue(Lane);
-      if (M < 0)
-        return {nullptr, PoisonMaskElem};
-      else if (M < (int)NumElts) {
-        V = SV->getOperand(0);
-        Lane = M;
-      } else {
-        V = SV->getOperand(1);
-        Lane = M - NumElts;
-      }
-    }
-    return InstLane{V, Lane};
-  };
-
-  auto GenerateInstLaneVectorFromOperand =
-      [&LookThroughShuffles](ArrayRef<InstLane> Item, int Op) {
-        SmallVector<InstLane> NItem;
-        for (InstLane V : Item) {
-          NItem.emplace_back(
-              !V.first
-                  ? InstLane{nullptr, PoisonMaskElem}
-                  : LookThroughShuffles(
-                        cast<Instruction>(V.first)->getOperand(Op), V.second));
-        }
-        return NItem;
-      };
-
   SmallVector<InstLane> Start(Ty->getNumElements());
   for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
-    Start[M] = LookThroughShuffles(&I, M);
+    Start[M] = lookThroughShuffles(&I, M);
 
   SmallVector<SmallVector<InstLane>> Worklist;
   Worklist.push_back(Start);
@@ -1721,73 +1768,78 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
 
   while (!Worklist.empty()) {
     SmallVector<InstLane> Item = Worklist.pop_back_val();
+    auto [FrontV, FrontLane] = Item.front();
     if (++NumVisited > MaxInstrsToScan)
       return false;
 
     // If we found an undef first lane then bail out to keep things simple.
-    if (!Item[0].first)
+    if (!FrontV)
       return false;
 
     // Look for an identity value.
-    if (Item[0].second == 0 &&
-        cast<FixedVectorType>(Item[0].first->getType())->getNumElements() ==
+    if (!FrontLane &&
+        cast<FixedVectorType>(FrontV->getType())->getNumElements() ==
             Ty->getNumElements() &&
-        all_of(drop_begin(enumerate(Item)), [&](const auto &E) {
-          return !E.value().first || (E.value().first == Item[0].first &&
+        all_of(drop_begin(enumerate(Item)), [Item](const auto &E) {
+          Value *FrontV = Item.front().first;
+          return !E.value().first || (E.value().first == FrontV &&
                                       E.value().second == (int)E.index());
         })) {
-      IdentityLeafs.insert(Item[0].first);
+      IdentityLeafs.insert(FrontV);
       continue;
     }
     // Look for a splat value.
-    if (all_of(drop_begin(Item), [&](InstLane &IL) {
-          return !IL.first ||
-                 (IL.first == Item[0].first && IL.second == Item[0].second);
+    if (all_of(drop_begin(Item), [Item](InstLane &IL) {
+          auto [FrontV, FrontLane] = Item.front();
+          auto [V, Lane] = IL;
+          return !V || (V == FrontV && Lane == FrontLane);
         })) {
-      SplatLeafs.insert(Item[0].first);
+      SplatLeafs.insert(FrontV);
       continue;
     }
 
     // We need each element to be the same type of value, and check that each
     // element has a single use.
-    if (!all_of(drop_begin(Item), [&](InstLane IL) {
-          if (!IL.first)
+    if (!all_of(drop_begin(Item), [Item](InstLane IL) {
+          Value *FrontV = Item.front().first;
+          Value *V = IL.first;
+          if (!V)
             return true;
-          if (auto *I = dyn_cast<Instruction>(IL.first); I && !I->hasOneUse())
+          if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse())
             return false;
-          if (IL.first->getValueID() != Item[0].first->getValueID())
+          if (V->getValueID() != FrontV->getValueID())
             return false;
-          if (isa<CallInst>(IL.first) && !isa<IntrinsicInst>(IL.first))
+          if (isa<CallInst>(V) && !isa<IntrinsicInst>(V))
             return false;
-          auto *II = dyn_cast<IntrinsicInst>(IL.first);
-          return !II ||
-                 (isa<IntrinsicInst>(Item[0].first) &&
-                  II->getIntrinsicID() ==
-                      cast<IntrinsicInst>(Item[0].first)->getIntrinsicID());
+          auto *II = dyn_cast<IntrinsicInst>(V);
+          return !II || (isa<IntrinsicInst>(FrontV) &&
+                         II->getIntrinsicID() ==
+                             cast<IntrinsicInst>(FrontV)->getIntrinsicID());
         }))
       return false;
 
     // Check the operator is one that we support. We exclude div/rem in case
     // they hit UB from poison lanes.
-    if (isa<BinaryOperator>(Item[0].first) &&
-        !cast<BinaryOperator>(Item[0].first)->isIntDivRem()) {
-      Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 0));
-      Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 1));
-    } else if (isa<UnaryOperator>(Item[0].first)) {
-      Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 0));
-    } else if (auto *II = dyn_cast<IntrinsicInst>(Item[0].first);
+    if (isa<BinaryOperator>(FrontV) &&
+        !cast<BinaryOperator>(FrontV)->isIntDivRem()) {
+      Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+      Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
+    } else if (isa<UnaryOperator>(FrontV)) {
+      Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+    } else if (auto *II = dyn_cast<IntrinsicInst>(FrontV);
                II && isTriviallyVectorizable(II->getIntrinsicID())) {
       for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
         if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op)) {
-          if (!all_of(drop_begin(Item), [&](InstLane &IL) {
-                return !IL.first ||
-                       (cast<Instruction>(IL.first)->getOperand(Op) ==
-                        cast<Instruction>(Item[0].first)->getOperand(Op));
+          if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
+                Value *FrontV = Item.front().first;
+                Value *V = IL.first;
+                return !V || (cast<Instruction>(V)->getOperand(Op) ==
+                              cast<Instruction>(FrontV)->getOperand(Op));
               }))
             return false;
           continue;
         }
-        Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, Op));
+        Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op));
       }
     } else {
       return false;
@@ -1799,49 +1851,7 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
 
   // If we got this far, we know the shuffles are superfluous and can be
   // removed. Scan through again and generate the new tree of instructions.
-  std::function<Value *(ArrayRef<InstLane>)> Generate =
-      [&](ArrayRef<InstLane> Item) -> Value * {
-    if (IdentityLeafs.contains(Item[0].first) &&
-        all_of(drop_begin(enumerate(Item)), [&](const auto &E) {
-          return !E.value().first || (E.value().first == Item[0].first &&
-                                      E.value().second == (int)E.index());
-        })) {
-      return Item[0].first;
-    }
-    if (SplatLeafs.contains(Item[0].first)) {
-      if (auto ILI = dyn_cast<Instruction>(Item[0].first))
-        Builder.SetInsertPoint(*ILI->getInsertionPointAfterDef());
-      else if (isa<Argument>(Item[0].first))
-        Builder.SetInsertPointPastAllocas(I.getParent()->getParent());
-      SmallVector<int, 16> Mask(Ty->getNumElements(), Item[0].second);
-      return Builder.CreateShuffleVector(Item[0].first, Mask);
-    }
-
-    auto *I = cast<Instruction>(Item[0].first);
-    auto *II = dyn_cast<IntrinsicInst>(I);
-    unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
-    SmallVector<Value *> Ops(NumOps);
-    for (unsigned Idx = 0; Idx < NumOps; Idx++) {
-      if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx)) {
-        Ops[Idx] = II->getOperand(Idx);
-        continue;
-      }
-      Ops[Idx] = Generate(GenerateInstLaneVectorFromOperand(Item, Idx));
-    }
-    Builder.SetInsertPoint(I);
-    Type *DstTy = FixedVectorType::get(I->getType()->getScalarType(),
-                                       Ty->getNumElements());
-    if (auto BI = dyn_cast<BinaryOperator>(I))
-      return Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
-                                 Ops[0], Ops[1]);
-    if (II)
-      return Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
-    assert(isa<UnaryInstruction>(I) &&
-           "Unexpected instruction type in Generate");
-    return Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
-  };
-
-  Value *V = Generate(Start);
+  Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs, Builder);
   replaceValue(I, *V);
   return true;
 }

Copy link
Collaborator

@davemgreen davemgreen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. The variable names part of this seems like an improvement. LGTM with a couple of minor suggestions.

davemgreen added a commit to davemgreen/llvm-project that referenced this pull request May 20, 2024
Other than some additional checks needed for compare predicates and selects
with scalar condition operands, these are relatively simple additions to what
already exists.

I will rebase over llvm#92766, but already had the patch for this version.
davemgreen added a commit to davemgreen/llvm-project that referenced this pull request May 20, 2024
This just adds splat constants, which can be treated like any other splat which
hopefully makes them very simple.

I will rebase over llvm#92766, but already had the patch for this version.
@artagnon artagnon merged commit e3fa7ee into llvm:main May 21, 2024
3 of 4 checks passed
@artagnon artagnon deleted the vc-idshuffle-nfc branch May 21, 2024 07:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants