20
20
#include "llvm/ADT/APInt.h"
21
21
#include "llvm/ADT/ArrayRef.h"
22
22
#include "llvm/ADT/DenseMap.h"
23
+ #include "llvm/ADT/IntervalMap.h"
23
24
#include "llvm/ADT/None.h"
24
25
#include "llvm/ADT/Optional.h"
25
26
#include "llvm/ADT/STLExtras.h"
@@ -490,6 +491,10 @@ namespace {
490
491
/// returns false.
491
492
bool findBetterNeighborChains(StoreSDNode *St);
492
493
494
+ // Helper for findBetterNeighborChains. Walk up store chain add additional
495
+ // chained stores that do not overlap and can be parallelized.
496
+ bool parallelizeChainedStores(StoreSDNode *St);
497
+
493
498
/// Holds a pointer to an LSBaseSDNode as well as information on where it
494
499
/// is located in a sequence of memory operations connected by a chain.
495
500
struct MemOpLink {
@@ -18905,6 +18910,11 @@ SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
18905
18910
return DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, Aliases);
18906
18911
}
18907
18912
18913
+ // TODO: Replace with with std::monostate when we move to C++17.
18914
+ struct UnitT { } Unit;
18915
+ bool operator==(const UnitT &, const UnitT &) { return true; }
18916
+ bool operator!=(const UnitT &, const UnitT &) { return false; }
18917
+
18908
18918
// This function tries to collect a bunch of potentially interesting
18909
18919
// nodes to improve the chains of, all at once. This might seem
18910
18920
// redundant, as this function gets called when visiting every store
@@ -18917,13 +18927,22 @@ SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
18917
18927
// the nodes that will eventually be candidates, and then not be able
18918
18928
// to go from a partially-merged state to the desired final
18919
18929
// fully-merged state.
18920
- bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
18921
- if (OptLevel == CodeGenOpt::None)
18922
- return false;
18930
+
18931
+ bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
18932
+ SmallVector<StoreSDNode *, 8> ChainedStores;
18933
+ StoreSDNode *STChain = St;
18934
+ // Intervals records which offsets from BaseIndex have been covered. In
18935
+ // the common case, every store writes to the immediately previous address
18936
+ // space and thus merged with the previous interval at insertion time.
18937
+
18938
+ using IMap =
18939
+ llvm::IntervalMap<int64_t, UnitT, 8, IntervalMapHalfOpenInfo<int64_t>>;
18940
+ IMap::Allocator A;
18941
+ IMap Intervals(A);
18923
18942
18924
18943
// This holds the base pointer, index, and the offset in bytes from the base
18925
18944
// pointer.
18926
- BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
18945
+ const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
18927
18946
18928
18947
// We must have a base and an offset.
18929
18948
if (!BasePtr.getBase().getNode())
@@ -18933,76 +18952,114 @@ bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
18933
18952
if (BasePtr.getBase().isUndef())
18934
18953
return false;
18935
18954
18936
- SmallVector<StoreSDNode *, 8> ChainedStores;
18937
- ChainedStores.push_back(St );
18955
+ // Add ST's interval.
18956
+ Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8, Unit );
18938
18957
18939
- // Walk up the chain and look for nodes with offsets from the same
18940
- // base pointer. Stop when reaching an instruction with a different kind
18941
- // or instruction which has a different base pointer.
18942
- StoreSDNode *Index = St;
18943
- while (Index) {
18958
+ while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
18944
18959
// If the chain has more than one use, then we can't reorder the mem ops.
18945
- if (Index != St && ! SDValue(Index , 0)->hasOneUse())
18960
+ if (! SDValue(Chain , 0)->hasOneUse())
18946
18961
break;
18947
-
18948
- if (Index->isVolatile() || Index->isIndexed())
18962
+ if (Chain->isVolatile() || Chain->isIndexed())
18949
18963
break;
18950
18964
18951
18965
// Find the base pointer and offset for this memory node.
18952
- BaseIndexOffset Ptr = BaseIndexOffset::match(Index, DAG);
18953
-
18966
+ const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
18954
18967
// Check that the base pointer is the same as the original one.
18955
- if (!BasePtr.equalBaseIndex(Ptr, DAG))
18968
+ int64_t Offset;
18969
+ if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
18956
18970
break;
18971
+ int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
18972
+ // Make sure we don't overlap with other intervals by checking the ones to
18973
+ // the left or right before inserting.
18974
+ auto I = Intervals.find(Offset);
18975
+ // If there's a next interval, we should end before it.
18976
+ if (I != Intervals.end() && I.start() < (Offset + Length))
18977
+ break;
18978
+ // If there's a previous interval, we should start after it.
18979
+ if (I != Intervals.begin() && (--I).stop() <= Offset)
18980
+ break;
18981
+ Intervals.insert(Offset, Offset + Length, Unit);
18957
18982
18958
- // Walk up the chain to find the next store node, ignoring any
18959
- // intermediate loads. Any other kind of node will halt the loop.
18960
- SDNode *NextInChain = Index->getChain().getNode();
18961
- while (true) {
18962
- if (StoreSDNode *STn = dyn_cast<StoreSDNode>(NextInChain)) {
18963
- // We found a store node. Use it for the next iteration.
18964
- if (STn->isVolatile() || STn->isIndexed()) {
18965
- Index = nullptr;
18966
- break;
18967
- }
18968
- ChainedStores.push_back(STn);
18969
- Index = STn;
18970
- break;
18971
- } else if (LoadSDNode *Ldn = dyn_cast<LoadSDNode>(NextInChain)) {
18972
- NextInChain = Ldn->getChain().getNode();
18973
- continue;
18974
- } else {
18975
- Index = nullptr;
18976
- break;
18977
- }
18978
- }// end while
18983
+ ChainedStores.push_back(Chain);
18984
+ STChain = Chain;
18979
18985
}
18980
18986
18981
- // At this point, ChainedStores lists all of the Store nodes
18982
- // reachable by iterating up through chain nodes matching the above
18983
- // conditions. For each such store identified, try to find an
18984
- // earlier chain to attach the store to which won't violate the
18985
- // required ordering.
18986
- bool MadeChangeToSt = false;
18987
- SmallVector<std::pair<StoreSDNode *, SDValue>, 8> BetterChains;
18987
+ // If we didn't find a chained store, exit.
18988
+ if (ChainedStores.size() == 0)
18989
+ return false;
18990
+
18991
+ // Improve all chained stores (St and ChainedStores members) starting from
18992
+ // where the store chain ended and return single TokenFactor.
18993
+ SDValue NewChain = STChain->getChain();
18994
+ SmallVector<SDValue, 8> TFOps;
18995
+ for (unsigned I = ChainedStores.size(); I;) {
18996
+ StoreSDNode *S = ChainedStores[--I];
18997
+ SDValue BetterChain = FindBetterChain(S, NewChain);
18998
+ S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
18999
+ S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
19000
+ TFOps.push_back(SDValue(S, 0));
19001
+ ChainedStores[I] = S;
19002
+ }
19003
+
19004
+ // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
19005
+ SDValue BetterChain = FindBetterChain(St, NewChain);
19006
+ SDValue NewST;
19007
+ if (St->isTruncatingStore())
19008
+ NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
19009
+ St->getBasePtr(), St->getMemoryVT(),
19010
+ St->getMemOperand());
19011
+ else
19012
+ NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
19013
+ St->getBasePtr(), St->getMemOperand());
18988
19014
18989
- for (StoreSDNode *ChainedStore : ChainedStores) {
18990
- SDValue Chain = ChainedStore->getChain();
18991
- SDValue BetterChain = FindBetterChain(ChainedStore, Chain);
19015
+ TFOps.push_back(NewST);
18992
19016
18993
- if (Chain != BetterChain) {
18994
- if (ChainedStore == St)
18995
- MadeChangeToSt = true;
18996
- BetterChains.push_back(std::make_pair(ChainedStore, BetterChain));
18997
- }
18998
- }
19017
+ // If we improved every element of TFOps, then we've lost the dependence on
19018
+ // NewChain to successors of St and we need to add it back to TFOps. Do so at
19019
+ // the beginning to keep relative order consistent with FindBetterChains.
19020
+ auto hasImprovedChain = [&](SDValue ST) -> bool {
19021
+ return ST->getOperand(0) != NewChain;
19022
+ };
19023
+ bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
19024
+ if (AddNewChain)
19025
+ TFOps.insert(TFOps.begin(), NewChain);
19026
+
19027
+ SDValue TF = DAG.getNode(ISD::TokenFactor, SDLoc(STChain), MVT::Other, TFOps);
19028
+ CombineTo(St, TF);
19029
+
19030
+ AddToWorklist(STChain);
19031
+ // Add TF operands worklist in reverse order.
19032
+ for (auto I = TF->getNumOperands(); I;)
19033
+ AddToWorklist(TF->getOperand(--I).getNode());
19034
+ AddToWorklist(TF.getNode());
19035
+ return true;
19036
+ }
19037
+
19038
+ bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
19039
+ if (OptLevel == CodeGenOpt::None)
19040
+ return false;
19041
+
19042
+ const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
18999
19043
19000
- // Do all replacements after finding the replacements to make to avoid making
19001
- // the chains more complicated by introducing new TokenFactors.
19002
- for (auto Replacement : BetterChains)
19003
- replaceStoreChain(Replacement.first, Replacement.second);
19044
+ // We must have a base and an offset.
19045
+ if (!BasePtr.getBase().getNode())
19046
+ return false;
19047
+
19048
+ // Do not handle stores to undef base pointers.
19049
+ if (BasePtr.getBase().isUndef())
19050
+ return false;
19051
+
19052
+ // Directly improve a chain of disjoint stores starting at St.
19053
+ if (parallelizeChainedStores(St))
19054
+ return true;
19004
19055
19005
- return MadeChangeToSt;
19056
+ // Improve St's Chain..
19057
+ SDValue BetterChain = FindBetterChain(St, St->getChain());
19058
+ if (St->getChain() != BetterChain) {
19059
+ replaceStoreChain(St, BetterChain);
19060
+ return true;
19061
+ }
19062
+ return false;
19006
19063
}
19007
19064
19008
19065
/// This is the entry point for the file.
0 commit comments