@@ -125,6 +125,9 @@ class PGOHash {
125
125
BinaryOperatorNE,
126
126
// The preceding values are available since PGO_HASH_V2.
127
127
128
+ // Cilk statements. These values are also available with PGO_HASH_V1.
129
+ CilkForStmt,
130
+
128
131
// Keep this last. It's for the static assert that follows.
129
132
LastHashType
130
133
};
@@ -266,6 +269,7 @@ struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
266
269
DEFINE_NESTABLE_TRAVERSAL (ObjCForCollectionStmt)
267
270
DEFINE_NESTABLE_TRAVERSAL (CXXTryStmt)
268
271
DEFINE_NESTABLE_TRAVERSAL (CXXCatchStmt)
272
+ DEFINE_NESTABLE_TRAVERSAL (CilkForStmt)
269
273
270
274
// / Get version \p HashVersion of the PGO hash for \p S.
271
275
PGOHash::HashType getHashType (PGOHashVersion HashVersion, const Stmt *S) {
@@ -326,6 +330,8 @@ struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
326
330
}
327
331
break ;
328
332
}
333
+ case Stmt::CilkForStmtClass:
334
+ return PGOHash::CilkForStmt;
329
335
}
330
336
331
337
if (HashVersion >= PGO_HASH_V2) {
@@ -743,6 +749,53 @@ struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
743
749
setCount (ParentCount + RHSCount - CurrentCount);
744
750
RecordNextStmtCount = true ;
745
751
}
752
+
753
+ void VisitCilkForStmt (const CilkForStmt *S) {
754
+ RecordStmtCount (S);
755
+ if (S->getInit ())
756
+ Visit (S->getInit ());
757
+ if (S->getLimitStmt ())
758
+ Visit (S->getLimitStmt ());
759
+ if (S->getBeginStmt ())
760
+ Visit (S->getBeginStmt ());
761
+ if (S->getEndStmt ())
762
+ Visit (S->getEndStmt ());
763
+ if (S->getLoopVarDecl ())
764
+ Visit (S->getLoopVarDecl ());
765
+
766
+ uint64_t ParentCount = CurrentCount;
767
+
768
+ BreakContinueStack.push_back (BreakContinue ());
769
+ // Visit the body region first. (This is basically the same as a while
770
+ // loop; see further comments in VisitWhileStmt.)
771
+ uint64_t BodyCount = setCount (PGO.getRegionCount (S));
772
+ CountMap[S->getBody ()] = BodyCount;
773
+ Visit (S->getBody ());
774
+ uint64_t BackedgeCount = CurrentCount;
775
+ BreakContinue BC = BreakContinueStack.pop_back_val ();
776
+
777
+ // The increment is essentially part of the body but it needs to include
778
+ // the count for all the continue statements.
779
+ if (S->getInc ()) {
780
+ uint64_t IncCount = setCount (BackedgeCount + BC.ContinueCount );
781
+ CountMap[S->getInc ()] = IncCount;
782
+ Visit (S->getInc ());
783
+ }
784
+
785
+ // ...then go back and propagate counts through the condition.
786
+ uint64_t CondCount =
787
+ setCount (ParentCount + BackedgeCount + BC.ContinueCount );
788
+ if (S->getInitCond ()) {
789
+ CountMap[S->getInitCond ()] = ParentCount;
790
+ Visit (S->getInitCond ());
791
+ }
792
+ if (S->getCond ()) {
793
+ CountMap[S->getCond ()] = CondCount;
794
+ Visit (S->getCond ());
795
+ }
796
+ setCount (BC.BreakCount + CondCount - BodyCount);
797
+ RecordNextStmtCount = true ;
798
+ }
746
799
};
747
800
} // end anonymous namespace
748
801
0 commit comments