Skip to content

Commit b90fcaf

Browse files
authored
[CodeLayout][NFC] Using MergedVector to avoid extra vector allocations (#68724)
Using a wrapper (MergedVector) around vectors to avoid extra vector allocations. Plus a few edits in the comments.
1 parent 8da1e3d commit b90fcaf

File tree

1 file changed

+77
-62
lines changed

1 file changed

+77
-62
lines changed

llvm/lib/Transforms/Utils/CodeLayout.cpp

Lines changed: 77 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ static cl::opt<unsigned> BackwardDistance(
9999
cl::desc("The maximum distance (in bytes) of a backward jump for ExtTSP"));
100100

101101
// The maximum size of a chain created by the algorithm. The size is bounded
102-
// so that the algorithm can efficiently process extremely large instance.
102+
// so that the algorithm can efficiently process extremely large instances.
103103
static cl::opt<unsigned>
104104
MaxChainSize("ext-tsp-max-chain-size", cl::ReallyHidden, cl::init(4096),
105105
cl::desc("The maximum size of a chain to create."));
@@ -217,8 +217,8 @@ struct NodeT {
217217
NodeT &operator=(const NodeT &) = delete;
218218
NodeT &operator=(NodeT &&) = default;
219219

220-
explicit NodeT(size_t Index, uint64_t Size, uint64_t EC)
221-
: Index(Index), Size(Size), ExecutionCount(EC) {}
220+
explicit NodeT(size_t Index, uint64_t Size, uint64_t Count)
221+
: Index(Index), Size(Size), ExecutionCount(Count) {}
222222

223223
bool isEntry() const { return Index == 0; }
224224

@@ -477,12 +477,12 @@ void ChainT::mergeEdges(ChainT *Other) {
477477

478478
using NodeIter = std::vector<NodeT *>::const_iterator;
479479

480-
/// A wrapper around three chains of nodes; it is used to avoid extra
481-
/// instantiation of the vectors.
482-
struct MergedChain {
483-
MergedChain(NodeIter Begin1, NodeIter End1, NodeIter Begin2 = NodeIter(),
484-
NodeIter End2 = NodeIter(), NodeIter Begin3 = NodeIter(),
485-
NodeIter End3 = NodeIter())
480+
/// A wrapper around three concatenated vectors (chains) of nodes; it is used
481+
/// to avoid extra instantiation of the vectors.
482+
struct MergedNodesT {
483+
MergedNodesT(NodeIter Begin1, NodeIter End1, NodeIter Begin2 = NodeIter(),
484+
NodeIter End2 = NodeIter(), NodeIter Begin3 = NodeIter(),
485+
NodeIter End3 = NodeIter())
486486
: Begin1(Begin1), End1(End1), Begin2(Begin2), End2(End2), Begin3(Begin3),
487487
End3(End3) {}
488488

@@ -507,6 +507,8 @@ struct MergedChain {
507507

508508
const NodeT *getFirstNode() const { return *Begin1; }
509509

510+
bool empty() const { return Begin1 == End1; }
511+
510512
private:
511513
NodeIter Begin1;
512514
NodeIter End1;
@@ -516,14 +518,34 @@ struct MergedChain {
516518
NodeIter End3;
517519
};
518520

521+
/// A wrapper around two concatenated vectors (chains) of jumps.
522+
struct MergedJumpsT {
523+
MergedJumpsT(const std::vector<JumpT *> *Jumps1,
524+
const std::vector<JumpT *> *Jumps2 = nullptr) {
525+
assert(!Jumps1->empty() && "cannot merge empty jump list");
526+
JumpArray[0] = Jumps1;
527+
JumpArray[1] = Jumps2;
528+
}
529+
530+
template <typename F> void forEach(const F &Func) const {
531+
for (auto Jumps : JumpArray)
532+
if (Jumps != nullptr)
533+
for (JumpT *Jump : *Jumps)
534+
Func(Jump);
535+
}
536+
537+
private:
538+
std::array<const std::vector<JumpT *> *, 2> JumpArray{nullptr, nullptr};
539+
};
540+
519541
/// Merge two chains of nodes respecting a given 'type' and 'offset'.
520542
///
521543
/// If MergeType == 0, then the result is a concatenation of two chains.
522544
/// Otherwise, the first chain is cut into two sub-chains at the offset,
523545
/// and merged using all possible ways of concatenating three chains.
524-
MergedChain mergeNodes(const std::vector<NodeT *> &X,
525-
const std::vector<NodeT *> &Y, size_t MergeOffset,
526-
MergeTypeT MergeType) {
546+
MergedNodesT mergeNodes(const std::vector<NodeT *> &X,
547+
const std::vector<NodeT *> &Y, size_t MergeOffset,
548+
MergeTypeT MergeType) {
527549
// Split the first chain, X, into X1 and X2.
528550
NodeIter BeginX1 = X.begin();
529551
NodeIter EndX1 = X.begin() + MergeOffset;
@@ -535,15 +557,15 @@ MergedChain mergeNodes(const std::vector<NodeT *> &X,
535557
// Construct a new chain from the three existing ones.
536558
switch (MergeType) {
537559
case MergeTypeT::X_Y:
538-
return MergedChain(BeginX1, EndX2, BeginY, EndY);
560+
return MergedNodesT(BeginX1, EndX2, BeginY, EndY);
539561
case MergeTypeT::Y_X:
540-
return MergedChain(BeginY, EndY, BeginX1, EndX2);
562+
return MergedNodesT(BeginY, EndY, BeginX1, EndX2);
541563
case MergeTypeT::X1_Y_X2:
542-
return MergedChain(BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2);
564+
return MergedNodesT(BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2);
543565
case MergeTypeT::Y_X2_X1:
544-
return MergedChain(BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1);
566+
return MergedNodesT(BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1);
545567
case MergeTypeT::X2_X1_Y:
546-
return MergedChain(BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY);
568+
return MergedNodesT(BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY);
547569
}
548570
llvm_unreachable("unexpected chain merge type");
549571
}
@@ -618,6 +640,7 @@ class ExtTSPImpl {
618640
AllChains.reserve(NumNodes);
619641
HotChains.reserve(NumNodes);
620642
for (NodeT &Node : AllNodes) {
643+
// Create a chain.
621644
AllChains.emplace_back(Node.Index, &Node);
622645
Node.CurChain = &AllChains.back();
623646
if (Node.ExecutionCount > 0)
@@ -630,13 +653,13 @@ class ExtTSPImpl {
630653
for (JumpT *Jump : PredNode.OutJumps) {
631654
NodeT *SuccNode = Jump->Target;
632655
ChainEdge *CurEdge = PredNode.CurChain->getEdge(SuccNode->CurChain);
633-
// this edge is already present in the graph.
656+
// This edge is already present in the graph.
634657
if (CurEdge != nullptr) {
635658
assert(SuccNode->CurChain->getEdge(PredNode.CurChain) != nullptr);
636659
CurEdge->appendJump(Jump);
637660
continue;
638661
}
639-
// this is a new edge.
662+
// This is a new edge.
640663
AllEdges.emplace_back(Jump);
641664
PredNode.CurChain->addEdge(SuccNode->CurChain, &AllEdges.back());
642665
SuccNode->CurChain->addEdge(PredNode.CurChain, &AllEdges.back());
@@ -649,7 +672,7 @@ class ExtTSPImpl {
649672
/// to B are from A. Such nodes should be adjacent in the optimal ordering;
650673
/// the method finds and merges such pairs of nodes.
651674
void mergeForcedPairs() {
652-
// Find fallthroughs based on edge weights.
675+
// Find forced pairs of blocks.
653676
for (NodeT &Node : AllNodes) {
654677
if (SuccNodes[Node.Index].size() == 1 &&
655678
PredNodes[SuccNodes[Node.Index][0]].size() == 1 &&
@@ -699,9 +722,7 @@ class ExtTSPImpl {
699722
/// Deterministically compare pairs of chains.
700723
auto compareChainPairs = [](const ChainT *A1, const ChainT *B1,
701724
const ChainT *A2, const ChainT *B2) {
702-
if (A1 != A2)
703-
return A1->Id < A2->Id;
704-
return B1->Id < B2->Id;
725+
return std::make_tuple(A1->Id, B1->Id) < std::make_tuple(A2->Id, B2->Id);
705726
};
706727

707728
while (HotChains.size() > 1) {
@@ -769,24 +790,22 @@ class ExtTSPImpl {
769790
}
770791

771792
/// Compute the Ext-TSP score for a given node order and a list of jumps.
772-
double extTSPScore(const MergedChain &MergedBlocks,
773-
const std::vector<JumpT *> &Jumps) const {
774-
if (Jumps.empty())
775-
return 0.0;
793+
double extTSPScore(const MergedNodesT &Nodes,
794+
const MergedJumpsT &Jumps) const {
776795
uint64_t CurAddr = 0;
777-
MergedBlocks.forEach([&](const NodeT *Node) {
796+
Nodes.forEach([&](const NodeT *Node) {
778797
Node->EstimatedAddr = CurAddr;
779798
CurAddr += Node->Size;
780799
});
781800

782801
double Score = 0;
783-
for (JumpT *Jump : Jumps) {
802+
Jumps.forEach([&](const JumpT *Jump) {
784803
const NodeT *SrcBlock = Jump->Source;
785804
const NodeT *DstBlock = Jump->Target;
786805
Score += ::extTSPScore(SrcBlock->EstimatedAddr, SrcBlock->Size,
787806
DstBlock->EstimatedAddr, Jump->ExecutionCount,
788807
Jump->IsConditional);
789-
}
808+
});
790809
return Score;
791810
}
792811

@@ -798,17 +817,13 @@ class ExtTSPImpl {
798817
/// element being the corresponding merging type.
799818
MergeGainT getBestMergeGain(ChainT *ChainPred, ChainT *ChainSucc,
800819
ChainEdge *Edge) const {
801-
if (Edge->hasCachedMergeGain(ChainPred, ChainSucc)) {
820+
if (Edge->hasCachedMergeGain(ChainPred, ChainSucc))
802821
return Edge->getCachedMergeGain(ChainPred, ChainSucc);
803-
}
804822

823+
assert(!Edge->jumps().empty() && "trying to merge chains w/o jumps");
805824
// Precompute jumps between ChainPred and ChainSucc.
806-
auto Jumps = Edge->jumps();
807825
ChainEdge *EdgePP = ChainPred->getEdge(ChainPred);
808-
if (EdgePP != nullptr) {
809-
Jumps.insert(Jumps.end(), EdgePP->jumps().begin(), EdgePP->jumps().end());
810-
}
811-
assert(!Jumps.empty() && "trying to merge chains w/o jumps");
826+
MergedJumpsT Jumps(&Edge->jumps(), EdgePP ? &EdgePP->jumps() : nullptr);
812827

813828
// This object holds the best chosen gain of merging two chains.
814829
MergeGainT Gain = MergeGainT();
@@ -875,19 +890,20 @@ class ExtTSPImpl {
875890
///
876891
/// The two chains are not modified in the method.
877892
MergeGainT computeMergeGain(const ChainT *ChainPred, const ChainT *ChainSucc,
878-
const std::vector<JumpT *> &Jumps,
879-
size_t MergeOffset, MergeTypeT MergeType) const {
880-
auto MergedBlocks =
893+
const MergedJumpsT &Jumps, size_t MergeOffset,
894+
MergeTypeT MergeType) const {
895+
MergedNodesT MergedNodes =
881896
mergeNodes(ChainPred->Nodes, ChainSucc->Nodes, MergeOffset, MergeType);
882897

883898
// Do not allow a merge that does not preserve the original entry point.
884899
if ((ChainPred->isEntry() || ChainSucc->isEntry()) &&
885-
!MergedBlocks.getFirstNode()->isEntry())
900+
!MergedNodes.getFirstNode()->isEntry())
886901
return MergeGainT();
887902

888903
// The gain for the new chain.
889-
auto NewGainScore = extTSPScore(MergedBlocks, Jumps) - ChainPred->Score;
890-
return MergeGainT(NewGainScore, MergeOffset, MergeType);
904+
double NewScore = extTSPScore(MergedNodes, Jumps);
905+
double CurScore = ChainPred->Score;
906+
return MergeGainT(NewScore - CurScore, MergeOffset, MergeType);
891907
}
892908

893909
/// Merge chain From into chain Into, update the list of active chains,
@@ -897,7 +913,7 @@ class ExtTSPImpl {
897913
assert(Into != From && "a chain cannot be merged with itself");
898914

899915
// Merge the nodes.
900-
MergedChain MergedNodes =
916+
MergedNodesT MergedNodes =
901917
mergeNodes(Into->Nodes, From->Nodes, MergeOffset, MergeType);
902918
Into->merge(From, MergedNodes.getNodes());
903919

@@ -908,8 +924,9 @@ class ExtTSPImpl {
908924
// Update cached ext-tsp score for the new chain.
909925
ChainEdge *SelfEdge = Into->getEdge(Into);
910926
if (SelfEdge != nullptr) {
911-
MergedNodes = MergedChain(Into->Nodes.begin(), Into->Nodes.end());
912-
Into->Score = extTSPScore(MergedNodes, SelfEdge->jumps());
927+
MergedNodes = MergedNodesT(Into->Nodes.begin(), Into->Nodes.end());
928+
MergedJumpsT MergedJumps(&SelfEdge->jumps());
929+
Into->Score = extTSPScore(MergedNodes, MergedJumps);
913930
}
914931

915932
// Remove the chain from the list of active chains.
@@ -943,7 +960,7 @@ class ExtTSPImpl {
943960
// Sorting chains by density in the decreasing order.
944961
std::sort(SortedChains.begin(), SortedChains.end(),
945962
[&](const ChainT *L, const ChainT *R) {
946-
// Place the entry point is at the beginning of the order.
963+
// Place the entry point at the beginning of the order.
947964
if (L->isEntry() != R->isEntry())
948965
return L->isEntry();
949966

@@ -1163,9 +1180,9 @@ class CDSortImpl {
11631180
/// result is a pair with the first element being the gain and the second
11641181
/// element being the corresponding merging type.
11651182
MergeGainT getBestMergeGain(ChainEdge *Edge) const {
1183+
assert(!Edge->jumps().empty() && "trying to merge chains w/o jumps");
11661184
// Precompute jumps between ChainPred and ChainSucc.
1167-
auto Jumps = Edge->jumps();
1168-
assert(!Jumps.empty() && "trying to merge chains w/o jumps");
1185+
MergedJumpsT Jumps(&Edge->jumps());
11691186
ChainT *SrcChain = Edge->srcChain();
11701187
ChainT *DstChain = Edge->dstChain();
11711188

@@ -1204,7 +1221,7 @@ class CDSortImpl {
12041221
///
12051222
/// The two chains are not modified in the method.
12061223
MergeGainT computeMergeGain(ChainT *ChainPred, ChainT *ChainSucc,
1207-
const std::vector<JumpT *> &Jumps,
1224+
const MergedJumpsT &Jumps,
12081225
MergeTypeT MergeType) const {
12091226
// This doesn't depend on the ordering of the nodes
12101227
double FreqGain = freqBasedLocalityGain(ChainPred, ChainSucc);
@@ -1255,24 +1272,22 @@ class CDSortImpl {
12551272
}
12561273

12571274
/// Compute the change of the distance locality after merging the chains.
1258-
double distBasedLocalityGain(const MergedChain &MergedBlocks,
1259-
const std::vector<JumpT *> &Jumps) const {
1260-
if (Jumps.empty())
1261-
return 0.0;
1275+
double distBasedLocalityGain(const MergedNodesT &Nodes,
1276+
const MergedJumpsT &Jumps) const {
12621277
uint64_t CurAddr = 0;
1263-
MergedBlocks.forEach([&](const NodeT *Node) {
1278+
Nodes.forEach([&](const NodeT *Node) {
12641279
Node->EstimatedAddr = CurAddr;
12651280
CurAddr += Node->Size;
12661281
});
12671282

12681283
double CurScore = 0;
12691284
double NewScore = 0;
1270-
for (const JumpT *Arc : Jumps) {
1271-
uint64_t SrcAddr = Arc->Source->EstimatedAddr + Arc->Offset;
1272-
uint64_t DstAddr = Arc->Target->EstimatedAddr;
1273-
NewScore += distScore(SrcAddr, DstAddr, Arc->ExecutionCount);
1274-
CurScore += distScore(0, TotalSize, Arc->ExecutionCount);
1275-
}
1285+
Jumps.forEach([&](const JumpT *Jump) {
1286+
uint64_t SrcAddr = Jump->Source->EstimatedAddr + Jump->Offset;
1287+
uint64_t DstAddr = Jump->Target->EstimatedAddr;
1288+
NewScore += distScore(SrcAddr, DstAddr, Jump->ExecutionCount);
1289+
CurScore += distScore(0, TotalSize, Jump->ExecutionCount);
1290+
});
12761291
return NewScore - CurScore;
12771292
}
12781293

@@ -1283,7 +1298,7 @@ class CDSortImpl {
12831298
assert(Into != From && "a chain cannot be merged with itself");
12841299

12851300
// Merge the nodes.
1286-
MergedChain MergedNodes =
1301+
MergedNodesT MergedNodes =
12871302
mergeNodes(Into->Nodes, From->Nodes, MergeOffset, MergeType);
12881303
Into->merge(From, MergedNodes.getNodes());
12891304

0 commit comments

Comments
 (0)