@@ -99,7 +99,7 @@ static cl::opt<unsigned> BackwardDistance(
99
99
cl::desc(" The maximum distance (in bytes) of a backward jump for ExtTSP" ));
100
100
101
101
// The maximum size of a chain created by the algorithm. The size is bounded
102
- // so that the algorithm can efficiently process extremely large instance .
102
+ // so that the algorithm can efficiently process extremely large instances .
103
103
static cl::opt<unsigned >
104
104
MaxChainSize (" ext-tsp-max-chain-size" , cl::ReallyHidden, cl::init(4096 ),
105
105
cl::desc(" The maximum size of a chain to create." ));
@@ -217,8 +217,8 @@ struct NodeT {
217
217
NodeT &operator =(const NodeT &) = delete ;
218
218
NodeT &operator =(NodeT &&) = default ;
219
219
220
- explicit NodeT (size_t Index, uint64_t Size , uint64_t EC )
221
- : Index(Index), Size(Size ), ExecutionCount(EC ) {}
220
+ explicit NodeT (size_t Index, uint64_t Size , uint64_t Count )
221
+ : Index(Index), Size(Size ), ExecutionCount(Count ) {}
222
222
223
223
bool isEntry () const { return Index == 0 ; }
224
224
@@ -477,12 +477,12 @@ void ChainT::mergeEdges(ChainT *Other) {
477
477
478
478
using NodeIter = std::vector<NodeT *>::const_iterator;
479
479
480
- // / A wrapper around three chains of nodes; it is used to avoid extra
481
- // / instantiation of the vectors.
482
- struct MergedChain {
483
- MergedChain (NodeIter Begin1, NodeIter End1, NodeIter Begin2 = NodeIter(),
484
- NodeIter End2 = NodeIter(), NodeIter Begin3 = NodeIter(),
485
- NodeIter End3 = NodeIter())
480
+ // / A wrapper around three concatenated vectors ( chains) of nodes; it is used
481
+ // / to avoid extra instantiation of the vectors.
482
+ struct MergedNodesT {
483
+ MergedNodesT (NodeIter Begin1, NodeIter End1, NodeIter Begin2 = NodeIter(),
484
+ NodeIter End2 = NodeIter(), NodeIter Begin3 = NodeIter(),
485
+ NodeIter End3 = NodeIter())
486
486
: Begin1(Begin1), End1(End1), Begin2(Begin2), End2(End2), Begin3(Begin3),
487
487
End3 (End3) {}
488
488
@@ -507,6 +507,8 @@ struct MergedChain {
507
507
508
508
const NodeT *getFirstNode () const { return *Begin1; }
509
509
510
+ bool empty () const { return Begin1 == End1; }
511
+
510
512
private:
511
513
NodeIter Begin1;
512
514
NodeIter End1;
@@ -516,14 +518,34 @@ struct MergedChain {
516
518
NodeIter End3;
517
519
};
518
520
521
+ // / A wrapper around two concatenated vectors (chains) of jumps.
522
+ struct MergedJumpsT {
523
+ MergedJumpsT (const std::vector<JumpT *> *Jumps1,
524
+ const std::vector<JumpT *> *Jumps2 = nullptr ) {
525
+ assert (!Jumps1->empty () && " cannot merge empty jump list" );
526
+ JumpArray[0 ] = Jumps1;
527
+ JumpArray[1 ] = Jumps2;
528
+ }
529
+
530
+ template <typename F> void forEach (const F &Func) const {
531
+ for (auto Jumps : JumpArray)
532
+ if (Jumps != nullptr )
533
+ for (JumpT *Jump : *Jumps)
534
+ Func (Jump);
535
+ }
536
+
537
+ private:
538
+ std::array<const std::vector<JumpT *> *, 2 > JumpArray{nullptr , nullptr };
539
+ };
540
+
519
541
// / Merge two chains of nodes respecting a given 'type' and 'offset'.
520
542
// /
521
543
// / If MergeType == 0, then the result is a concatenation of two chains.
522
544
// / Otherwise, the first chain is cut into two sub-chains at the offset,
523
545
// / and merged using all possible ways of concatenating three chains.
524
- MergedChain mergeNodes (const std::vector<NodeT *> &X,
525
- const std::vector<NodeT *> &Y, size_t MergeOffset,
526
- MergeTypeT MergeType) {
546
+ MergedNodesT mergeNodes (const std::vector<NodeT *> &X,
547
+ const std::vector<NodeT *> &Y, size_t MergeOffset,
548
+ MergeTypeT MergeType) {
527
549
// Split the first chain, X, into X1 and X2.
528
550
NodeIter BeginX1 = X.begin ();
529
551
NodeIter EndX1 = X.begin () + MergeOffset;
@@ -535,15 +557,15 @@ MergedChain mergeNodes(const std::vector<NodeT *> &X,
535
557
// Construct a new chain from the three existing ones.
536
558
switch (MergeType) {
537
559
case MergeTypeT::X_Y:
538
- return MergedChain (BeginX1, EndX2, BeginY, EndY);
560
+ return MergedNodesT (BeginX1, EndX2, BeginY, EndY);
539
561
case MergeTypeT::Y_X:
540
- return MergedChain (BeginY, EndY, BeginX1, EndX2);
562
+ return MergedNodesT (BeginY, EndY, BeginX1, EndX2);
541
563
case MergeTypeT::X1_Y_X2:
542
- return MergedChain (BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2);
564
+ return MergedNodesT (BeginX1, EndX1, BeginY, EndY, BeginX2, EndX2);
543
565
case MergeTypeT::Y_X2_X1:
544
- return MergedChain (BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1);
566
+ return MergedNodesT (BeginY, EndY, BeginX2, EndX2, BeginX1, EndX1);
545
567
case MergeTypeT::X2_X1_Y:
546
- return MergedChain (BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY);
568
+ return MergedNodesT (BeginX2, EndX2, BeginX1, EndX1, BeginY, EndY);
547
569
}
548
570
llvm_unreachable (" unexpected chain merge type" );
549
571
}
@@ -618,6 +640,7 @@ class ExtTSPImpl {
618
640
AllChains.reserve (NumNodes);
619
641
HotChains.reserve (NumNodes);
620
642
for (NodeT &Node : AllNodes) {
643
+ // Create a chain.
621
644
AllChains.emplace_back (Node.Index , &Node);
622
645
Node.CurChain = &AllChains.back ();
623
646
if (Node.ExecutionCount > 0 )
@@ -630,13 +653,13 @@ class ExtTSPImpl {
630
653
for (JumpT *Jump : PredNode.OutJumps ) {
631
654
NodeT *SuccNode = Jump->Target ;
632
655
ChainEdge *CurEdge = PredNode.CurChain ->getEdge (SuccNode->CurChain );
633
- // this edge is already present in the graph.
656
+ // This edge is already present in the graph.
634
657
if (CurEdge != nullptr ) {
635
658
assert (SuccNode->CurChain ->getEdge (PredNode.CurChain ) != nullptr );
636
659
CurEdge->appendJump (Jump);
637
660
continue ;
638
661
}
639
- // this is a new edge.
662
+ // This is a new edge.
640
663
AllEdges.emplace_back (Jump);
641
664
PredNode.CurChain ->addEdge (SuccNode->CurChain , &AllEdges.back ());
642
665
SuccNode->CurChain ->addEdge (PredNode.CurChain , &AllEdges.back ());
@@ -649,7 +672,7 @@ class ExtTSPImpl {
649
672
// / to B are from A. Such nodes should be adjacent in the optimal ordering;
650
673
// / the method finds and merges such pairs of nodes.
651
674
void mergeForcedPairs () {
652
- // Find fallthroughs based on edge weights .
675
+ // Find forced pairs of blocks .
653
676
for (NodeT &Node : AllNodes) {
654
677
if (SuccNodes[Node.Index ].size () == 1 &&
655
678
PredNodes[SuccNodes[Node.Index ][0 ]].size () == 1 &&
@@ -699,9 +722,7 @@ class ExtTSPImpl {
699
722
// / Deterministically compare pairs of chains.
700
723
auto compareChainPairs = [](const ChainT *A1, const ChainT *B1,
701
724
const ChainT *A2, const ChainT *B2) {
702
- if (A1 != A2)
703
- return A1->Id < A2->Id ;
704
- return B1->Id < B2->Id ;
725
+ return std::make_tuple (A1->Id , B1->Id ) < std::make_tuple (A2->Id , B2->Id );
705
726
};
706
727
707
728
while (HotChains.size () > 1 ) {
@@ -769,24 +790,22 @@ class ExtTSPImpl {
769
790
}
770
791
771
792
// / Compute the Ext-TSP score for a given node order and a list of jumps.
772
- double extTSPScore (const MergedChain &MergedBlocks,
773
- const std::vector<JumpT *> &Jumps) const {
774
- if (Jumps.empty ())
775
- return 0.0 ;
793
+ double extTSPScore (const MergedNodesT &Nodes,
794
+ const MergedJumpsT &Jumps) const {
776
795
uint64_t CurAddr = 0 ;
777
- MergedBlocks .forEach ([&](const NodeT *Node) {
796
+ Nodes .forEach ([&](const NodeT *Node) {
778
797
Node->EstimatedAddr = CurAddr;
779
798
CurAddr += Node->Size ;
780
799
});
781
800
782
801
double Score = 0 ;
783
- for ( JumpT *Jump : Jumps ) {
802
+ Jumps. forEach ([&]( const JumpT *Jump) {
784
803
const NodeT *SrcBlock = Jump->Source ;
785
804
const NodeT *DstBlock = Jump->Target ;
786
805
Score += ::extTSPScore (SrcBlock->EstimatedAddr , SrcBlock->Size ,
787
806
DstBlock->EstimatedAddr , Jump->ExecutionCount ,
788
807
Jump->IsConditional );
789
- }
808
+ });
790
809
return Score;
791
810
}
792
811
@@ -798,17 +817,13 @@ class ExtTSPImpl {
798
817
// / element being the corresponding merging type.
799
818
MergeGainT getBestMergeGain (ChainT *ChainPred, ChainT *ChainSucc,
800
819
ChainEdge *Edge) const {
801
- if (Edge->hasCachedMergeGain (ChainPred, ChainSucc)) {
820
+ if (Edge->hasCachedMergeGain (ChainPred, ChainSucc))
802
821
return Edge->getCachedMergeGain (ChainPred, ChainSucc);
803
- }
804
822
823
+ assert (!Edge->jumps ().empty () && " trying to merge chains w/o jumps" );
805
824
// Precompute jumps between ChainPred and ChainSucc.
806
- auto Jumps = Edge->jumps ();
807
825
ChainEdge *EdgePP = ChainPred->getEdge (ChainPred);
808
- if (EdgePP != nullptr ) {
809
- Jumps.insert (Jumps.end (), EdgePP->jumps ().begin (), EdgePP->jumps ().end ());
810
- }
811
- assert (!Jumps.empty () && " trying to merge chains w/o jumps" );
826
+ MergedJumpsT Jumps (&Edge->jumps (), EdgePP ? &EdgePP->jumps () : nullptr );
812
827
813
828
// This object holds the best chosen gain of merging two chains.
814
829
MergeGainT Gain = MergeGainT ();
@@ -875,19 +890,20 @@ class ExtTSPImpl {
875
890
// /
876
891
// / The two chains are not modified in the method.
877
892
MergeGainT computeMergeGain (const ChainT *ChainPred, const ChainT *ChainSucc,
878
- const std::vector<JumpT *> &Jumps,
879
- size_t MergeOffset, MergeTypeT MergeType) const {
880
- auto MergedBlocks =
893
+ const MergedJumpsT &Jumps, size_t MergeOffset ,
894
+ MergeTypeT MergeType) const {
895
+ MergedNodesT MergedNodes =
881
896
mergeNodes (ChainPred->Nodes , ChainSucc->Nodes , MergeOffset, MergeType);
882
897
883
898
// Do not allow a merge that does not preserve the original entry point.
884
899
if ((ChainPred->isEntry () || ChainSucc->isEntry ()) &&
885
- !MergedBlocks .getFirstNode ()->isEntry ())
900
+ !MergedNodes .getFirstNode ()->isEntry ())
886
901
return MergeGainT ();
887
902
888
903
// The gain for the new chain.
889
- auto NewGainScore = extTSPScore (MergedBlocks, Jumps) - ChainPred->Score ;
890
- return MergeGainT (NewGainScore, MergeOffset, MergeType);
904
+ double NewScore = extTSPScore (MergedNodes, Jumps);
905
+ double CurScore = ChainPred->Score ;
906
+ return MergeGainT (NewScore - CurScore, MergeOffset, MergeType);
891
907
}
892
908
893
909
// / Merge chain From into chain Into, update the list of active chains,
@@ -897,7 +913,7 @@ class ExtTSPImpl {
897
913
assert (Into != From && " a chain cannot be merged with itself" );
898
914
899
915
// Merge the nodes.
900
- MergedChain MergedNodes =
916
+ MergedNodesT MergedNodes =
901
917
mergeNodes (Into->Nodes , From->Nodes , MergeOffset, MergeType);
902
918
Into->merge (From, MergedNodes.getNodes ());
903
919
@@ -908,8 +924,9 @@ class ExtTSPImpl {
908
924
// Update cached ext-tsp score for the new chain.
909
925
ChainEdge *SelfEdge = Into->getEdge (Into);
910
926
if (SelfEdge != nullptr ) {
911
- MergedNodes = MergedChain (Into->Nodes .begin (), Into->Nodes .end ());
912
- Into->Score = extTSPScore (MergedNodes, SelfEdge->jumps ());
927
+ MergedNodes = MergedNodesT (Into->Nodes .begin (), Into->Nodes .end ());
928
+ MergedJumpsT MergedJumps (&SelfEdge->jumps ());
929
+ Into->Score = extTSPScore (MergedNodes, MergedJumps);
913
930
}
914
931
915
932
// Remove the chain from the list of active chains.
@@ -943,7 +960,7 @@ class ExtTSPImpl {
943
960
// Sorting chains by density in the decreasing order.
944
961
std::sort (SortedChains.begin (), SortedChains.end (),
945
962
[&](const ChainT *L, const ChainT *R) {
946
- // Place the entry point is at the beginning of the order.
963
+ // Place the entry point at the beginning of the order.
947
964
if (L->isEntry () != R->isEntry ())
948
965
return L->isEntry ();
949
966
@@ -1163,9 +1180,9 @@ class CDSortImpl {
1163
1180
// / result is a pair with the first element being the gain and the second
1164
1181
// / element being the corresponding merging type.
1165
1182
MergeGainT getBestMergeGain (ChainEdge *Edge) const {
1183
+ assert (!Edge->jumps ().empty () && " trying to merge chains w/o jumps" );
1166
1184
// Precompute jumps between ChainPred and ChainSucc.
1167
- auto Jumps = Edge->jumps ();
1168
- assert (!Jumps.empty () && " trying to merge chains w/o jumps" );
1185
+ MergedJumpsT Jumps (&Edge->jumps ());
1169
1186
ChainT *SrcChain = Edge->srcChain ();
1170
1187
ChainT *DstChain = Edge->dstChain ();
1171
1188
@@ -1204,7 +1221,7 @@ class CDSortImpl {
1204
1221
// /
1205
1222
// / The two chains are not modified in the method.
1206
1223
MergeGainT computeMergeGain (ChainT *ChainPred, ChainT *ChainSucc,
1207
- const std::vector<JumpT *> &Jumps,
1224
+ const MergedJumpsT &Jumps,
1208
1225
MergeTypeT MergeType) const {
1209
1226
// This doesn't depend on the ordering of the nodes
1210
1227
double FreqGain = freqBasedLocalityGain (ChainPred, ChainSucc);
@@ -1255,24 +1272,22 @@ class CDSortImpl {
1255
1272
}
1256
1273
1257
1274
// / Compute the change of the distance locality after merging the chains.
1258
- double distBasedLocalityGain (const MergedChain &MergedBlocks,
1259
- const std::vector<JumpT *> &Jumps) const {
1260
- if (Jumps.empty ())
1261
- return 0.0 ;
1275
+ double distBasedLocalityGain (const MergedNodesT &Nodes,
1276
+ const MergedJumpsT &Jumps) const {
1262
1277
uint64_t CurAddr = 0 ;
1263
- MergedBlocks .forEach ([&](const NodeT *Node) {
1278
+ Nodes .forEach ([&](const NodeT *Node) {
1264
1279
Node->EstimatedAddr = CurAddr;
1265
1280
CurAddr += Node->Size ;
1266
1281
});
1267
1282
1268
1283
double CurScore = 0 ;
1269
1284
double NewScore = 0 ;
1270
- for ( const JumpT *Arc : Jumps ) {
1271
- uint64_t SrcAddr = Arc ->Source ->EstimatedAddr + Arc ->Offset ;
1272
- uint64_t DstAddr = Arc ->Target ->EstimatedAddr ;
1273
- NewScore += distScore (SrcAddr, DstAddr, Arc ->ExecutionCount );
1274
- CurScore += distScore (0 , TotalSize, Arc ->ExecutionCount );
1275
- }
1285
+ Jumps. forEach ([&]( const JumpT *Jump ) {
1286
+ uint64_t SrcAddr = Jump ->Source ->EstimatedAddr + Jump ->Offset ;
1287
+ uint64_t DstAddr = Jump ->Target ->EstimatedAddr ;
1288
+ NewScore += distScore (SrcAddr, DstAddr, Jump ->ExecutionCount );
1289
+ CurScore += distScore (0 , TotalSize, Jump ->ExecutionCount );
1290
+ });
1276
1291
return NewScore - CurScore;
1277
1292
}
1278
1293
@@ -1283,7 +1298,7 @@ class CDSortImpl {
1283
1298
assert (Into != From && " a chain cannot be merged with itself" );
1284
1299
1285
1300
// Merge the nodes.
1286
- MergedChain MergedNodes =
1301
+ MergedNodesT MergedNodes =
1287
1302
mergeNodes (Into->Nodes , From->Nodes , MergeOffset, MergeType);
1288
1303
Into->merge (From, MergedNodes.getNodes ());
1289
1304
0 commit comments