Skip to content

Commit 044c0ce

Browse files
committed
[ctx_prof] Flattened profile lowering pass
1 parent 9fef09f commit 044c0ce

File tree

7 files changed

+442
-3
lines changed

7 files changed

+442
-3
lines changed

llvm/include/llvm/ProfileData/ProfileCommon.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,13 @@ class ProfileSummaryBuilder {
7979
class InstrProfSummaryBuilder final : public ProfileSummaryBuilder {
8080
uint64_t MaxInternalBlockCount = 0;
8181

82-
inline void addEntryCount(uint64_t Count);
83-
inline void addInternalCount(uint64_t Count);
84-
8582
public:
8683
InstrProfSummaryBuilder(std::vector<uint32_t> Cutoffs)
8784
: ProfileSummaryBuilder(std::move(Cutoffs)) {}
8885

86+
void addEntryCount(uint64_t Count);
87+
void addInternalCount(uint64_t Count);
88+
8989
void addRecord(const InstrProfRecord &);
9090
std::unique_ptr<ProfileSummary> getSummary();
9191
};
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===-- PGOCtxProfFlattening.h - Contextual Instr. Flattening ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares the PGOCtxProfFlattening class.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
#ifndef LLVM_TRANSFORMS_INSTRUMENTATION_PGOCTXPROFFLATTENING_H
13+
#define LLVM_TRANSFORMS_INSTRUMENTATION_PGOCTXPROFFLATTENING_H
14+
15+
#include "llvm/IR/PassManager.h"
16+
namespace llvm {
17+
18+
/// Module pass registered as "ctx-prof-flatten". Per the implementation file,
/// it flattens the contextual profile, lowers it to MD_prof (function entry
/// counts and branch weights) and removes the contextual instrumentation.
class PGOCtxProfFlattening : public PassInfoMixin<PGOCtxProfFlattening> {
19+
public:
20+
explicit PGOCtxProfFlattening() = default;
21+
/// Runs the pass over \p M. Returns PreservedAnalyses::all() when no
/// contextual profile is available from \p MAM, PreservedAnalyses::none()
/// otherwise.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
22+
};
23+
} // namespace llvm
24+
#endif

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@
197197
#include "llvm/Transforms/Instrumentation/MemProfiler.h"
198198
#include "llvm/Transforms/Instrumentation/MemorySanitizer.h"
199199
#include "llvm/Transforms/Instrumentation/NumericalStabilitySanitizer.h"
200+
#include "llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h"
200201
#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
201202
#include "llvm/Transforms/Instrumentation/PGOForceFunctionAttrs.h"
202203
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"

llvm/lib/Passes/PassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ MODULE_PASS("coro-early", CoroEarlyPass())
5858
MODULE_PASS("cross-dso-cfi", CrossDSOCFIPass())
5959
MODULE_PASS("ctx-instr-gen",
6060
PGOInstrumentationGen(PGOInstrumentationType::CTXPROF))
61+
MODULE_PASS("ctx-prof-flatten", PGOCtxProfFlattening())
6162
MODULE_PASS("deadargelim", DeadArgumentEliminationPass())
6263
MODULE_PASS("debugify", NewPMDebugifyPass())
6364
MODULE_PASS("dfsan", DataFlowSanitizerPass())

llvm/lib/Transforms/Instrumentation/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ add_llvm_component_library(LLVMInstrumentation
1515
InstrProfiling.cpp
1616
KCFI.cpp
1717
LowerAllowCheckPass.cpp
18+
PGOCtxProfFlattening.cpp
1819
PGOCtxProfLowering.cpp
1920
PGOForceFunctionAttrs.cpp
2021
PGOInstrumentation.cpp
Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
//===- PGOCtxProfFlattening.cpp - Contextual Instr. Flattening ------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Flattens the contextual profile and lowers it to MD_prof.
10+
// This should happen after all IPO (which is assumed to have maintained the
11+
// contextual profile) happened. Flattening consists of summing the values at
12+
// the same index of the counters belonging to all the contexts of a function.
13+
// The lowering consists of materializing the counter values to function
14+
// entrypoint counts and branch probabilities.
15+
//
16+
// This pass also removes contextual instrumentation, which has been kept around
17+
// to facilitate its functionality.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
21+
#include "llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h"
22+
#include "llvm/ADT/STLExtras.h"
23+
#include "llvm/Analysis/CtxProfAnalysis.h"
24+
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
25+
#include "llvm/Analysis/ProfileSummaryInfo.h"
26+
#include "llvm/CodeGen/MachineBasicBlock.h"
27+
#include "llvm/IR/Analysis.h"
28+
#include "llvm/IR/CFG.h"
29+
#include "llvm/IR/Dominators.h"
30+
#include "llvm/IR/IntrinsicInst.h"
31+
#include "llvm/IR/Module.h"
32+
#include "llvm/IR/PassManager.h"
33+
#include "llvm/IR/ProfileSummary.h"
34+
#include "llvm/ProfileData/ProfileCommon.h"
35+
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
36+
#include "llvm/Transforms/Scalar/DCE.h"
37+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
38+
39+
using namespace llvm;
40+
41+
namespace {
42+
43+
class ProfileAnnotator final {
44+
class BBInfo;
45+
struct EdgeInfo {
46+
BBInfo *const Src;
47+
BBInfo *const Dest;
48+
std::optional<uint64_t> Count;
49+
50+
explicit EdgeInfo(BBInfo &Src, BBInfo &Dest) : Src(&Src), Dest(&Dest) {}
51+
};
52+
53+
class BBInfo {
54+
std::optional<uint64_t> Count;
55+
SmallVector<EdgeInfo *> OutEdges;
56+
SmallVector<EdgeInfo *> InEdges;
57+
size_t UnknownCountOutEdges = 0;
58+
size_t UnknownCountInEdges = 0;
59+
60+
uint64_t getEdgeSum(const SmallVector<EdgeInfo *> &Edges,
61+
bool AssumeAllKnown) const {
62+
uint64_t Sum = 0;
63+
for (const auto *E : Edges)
64+
if (E)
65+
Sum += AssumeAllKnown ? *E->Count : E->Count.value_or(0U);
66+
return Sum;
67+
}
68+
69+
void takeCountFrom(const SmallVector<EdgeInfo *> &Edges) {
70+
assert(!Count.has_value());
71+
Count = getEdgeSum(Edges, true);
72+
}
73+
74+
void setSingleUnknownEdgeCount(SmallVector<EdgeInfo *> &Edges) {
75+
uint64_t KnownSum = getEdgeSum(Edges, false);
76+
uint64_t EdgeVal = *Count > KnownSum ? *Count - KnownSum : 0U;
77+
EdgeInfo *E = nullptr;
78+
for (auto *I : Edges)
79+
if (I && !I->Count.has_value()) {
80+
E = I;
81+
#ifdef NDEBUG
82+
break;
83+
#else
84+
assert((!E || E == I) &&
85+
"Expected exactly one edge to have an unknown count, "
86+
"found a second one");
87+
continue;
88+
#endif
89+
}
90+
assert(E && "Expected exactly one edge to have an unknown count");
91+
assert(!E->Count.has_value());
92+
E->Count = EdgeVal;
93+
assert(E->Src->UnknownCountOutEdges > 0);
94+
assert(E->Dest->UnknownCountInEdges > 0);
95+
--E->Src->UnknownCountOutEdges;
96+
--E->Dest->UnknownCountInEdges;
97+
}
98+
99+
public:
100+
BBInfo(size_t NumInEdges, size_t NumOutEdges, std::optional<uint64_t> Count)
101+
: Count(Count) {
102+
InEdges.reserve(NumInEdges);
103+
OutEdges.resize(NumOutEdges);
104+
}
105+
106+
bool tryTakeCountFromKnownOutEdges(const BasicBlock &BB) {
107+
if (!succ_empty(&BB) && !UnknownCountOutEdges) {
108+
takeCountFrom(OutEdges);
109+
return true;
110+
}
111+
return false;
112+
}
113+
114+
bool tryTakeCountFromKnownInEdges(const BasicBlock &BB) {
115+
if (!BB.isEntryBlock() && !UnknownCountInEdges) {
116+
takeCountFrom(InEdges);
117+
return true;
118+
}
119+
return false;
120+
}
121+
122+
void addInEdge(EdgeInfo *Info) {
123+
InEdges.push_back(Info);
124+
++UnknownCountInEdges;
125+
}
126+
127+
void addOutEdge(size_t Index, EdgeInfo *Info) {
128+
OutEdges[Index] = Info;
129+
++UnknownCountOutEdges;
130+
}
131+
132+
bool hasCount() const { return Count.has_value(); }
133+
134+
bool trySetSingleUnknownInEdgeCount() {
135+
if (UnknownCountInEdges == 1) {
136+
setSingleUnknownEdgeCount(InEdges);
137+
return true;
138+
}
139+
return false;
140+
}
141+
142+
bool trySetSingleUnknownOutEdgeCount() {
143+
if (UnknownCountOutEdges == 1) {
144+
setSingleUnknownEdgeCount(OutEdges);
145+
return true;
146+
}
147+
return false;
148+
}
149+
size_t getNumOutEdges() const { return OutEdges.size(); }
150+
151+
uint64_t getEdgeCount(size_t Index) const {
152+
if (auto *E = OutEdges[Index])
153+
return *E->Count;
154+
return 0U;
155+
}
156+
};
157+
158+
Function &F;
159+
const SmallVectorImpl<uint64_t> &Counters;
160+
// To be accessed through getBBInfo() after construction.
161+
std::map<const BasicBlock *, BBInfo> BBInfos;
162+
std::vector<EdgeInfo> EdgeInfos;
163+
InstrProfSummaryBuilder &PB;
164+
165+
// This is an adaptation of PGOUseFunc::populateCounters.
166+
// FIXME(mtrofin): look into factoring the code to share one implementation.
167+
void propagateCounterValues(const SmallVectorImpl<uint64_t> &Counters) {
168+
bool KeepGoing = true;
169+
while (KeepGoing) {
170+
KeepGoing = false;
171+
for (const auto &BB : reverse(F)) {
172+
auto &Info = getBBInfo(BB);
173+
if (!Info.hasCount())
174+
KeepGoing |= Info.tryTakeCountFromKnownOutEdges(BB) ||
175+
Info.tryTakeCountFromKnownInEdges(BB);
176+
if (Info.hasCount()) {
177+
KeepGoing |= Info.trySetSingleUnknownOutEdgeCount();
178+
KeepGoing |= Info.trySetSingleUnknownInEdgeCount();
179+
}
180+
}
181+
}
182+
}
183+
// The only criteria for exclusion is faux suspend -> exit edges in presplit
184+
// coroutines. The API serves for readability, currently.
185+
bool shouldExcludeEdge(const BasicBlock &Src, const BasicBlock &Dest) const {
186+
return llvm::isPresplitCoroSuspendExitEdge(Src, Dest);
187+
}
188+
189+
BBInfo &getBBInfo(const BasicBlock &BB) { return BBInfos.find(&BB)->second; }
190+
191+
public:
192+
ProfileAnnotator(Function &F, const SmallVectorImpl<uint64_t> &Counters,
193+
InstrProfSummaryBuilder &PB)
194+
: F(F), Counters(Counters), PB(PB) {
195+
assert(!F.isDeclaration());
196+
assert(!Counters.empty());
197+
size_t NrEdges = 0;
198+
for (const auto &BB : F) {
199+
std::optional<uint64_t> Count;
200+
if (auto *Ins = CtxProfAnalysis::getBBInstrumentation(
201+
const_cast<BasicBlock &>(BB)))
202+
Count = Counters[Ins->getIndex()->getZExtValue()];
203+
auto [It, Ins] =
204+
BBInfos.insert({&BB, {pred_size(&BB), succ_size(&BB), Count}});
205+
(void)Ins;
206+
assert(Ins);
207+
NrEdges += llvm::count_if(successors(&BB), [&](const auto *Succ) {
208+
return !shouldExcludeEdge(BB, *Succ);
209+
});
210+
}
211+
// Pre-allocate the vector, we want references to its contents to be stable.
212+
EdgeInfos.reserve(NrEdges);
213+
for (const auto &BB : F) {
214+
auto &Info = getBBInfo(BB);
215+
for (auto I = 0U; I < BB.getTerminator()->getNumSuccessors(); ++I) {
216+
const auto *Succ = BB.getTerminator()->getSuccessor(I);
217+
if (!shouldExcludeEdge(BB, *Succ)) {
218+
assert(EdgeInfos.size() < NrEdges);
219+
auto &EI = EdgeInfos.emplace_back(getBBInfo(BB), getBBInfo(*Succ));
220+
Info.addOutEdge(I, &EI);
221+
getBBInfo(*Succ).addInEdge(&EI);
222+
}
223+
}
224+
}
225+
}
226+
227+
/// Assign branch weights and function entry count. Also update the PSI
228+
/// builder.
229+
void assignProfileData() {
230+
assert(!Counters.empty());
231+
propagateCounterValues(Counters);
232+
F.setEntryCount(Counters[0]);
233+
PB.addEntryCount(Counters[0]);
234+
235+
for (auto &BB : F) {
236+
if (succ_size(&BB) < 2)
237+
continue;
238+
auto *Term = BB.getTerminator();
239+
SmallVector<uint64_t, 2> EdgeCounts(Term->getNumSuccessors(), 0);
240+
uint64_t MaxCount = 0;
241+
const auto &BBInfo = getBBInfo(BB);
242+
for (unsigned SuccIdx = 0, Size = BBInfo.getNumOutEdges(); SuccIdx < Size;
243+
++SuccIdx) {
244+
uint64_t EdgeCount = BBInfo.getEdgeCount(SuccIdx);
245+
if (EdgeCount > MaxCount)
246+
MaxCount = EdgeCount;
247+
EdgeCounts[SuccIdx] = EdgeCount;
248+
PB.addInternalCount(EdgeCount);
249+
}
250+
251+
if (MaxCount == 0)
252+
F.getContext().emitError(
253+
"[ctx-prof] Encountered a BB with more than one successor, where "
254+
"all outgoing edges have a 0 count. This occurs in non-exiting "
255+
"functions (message pumps, usually) which are not supported in the "
256+
"contextual profiling case");
257+
setProfMetadata(F.getParent(), Term, EdgeCounts, MaxCount);
258+
}
259+
}
260+
};
261+
262+
bool areAllBBsReachable(const Function &F, FunctionAnalysisManager &FAM) {
263+
auto &DT = FAM.getResult<DominatorTreeAnalysis>(const_cast<Function &>(F));
264+
for (const auto &BB : F)
265+
if (!DT.isReachableFromEntry(&BB))
266+
return false;
267+
return true;
268+
}
269+
270+
/// Drops the MD_prof metadata from the terminator of every block in \p F and
/// sets the function entry count to 0.
void clearColdFunctionProfile(Function &F) {
  for (BasicBlock &Block : F) {
    Instruction *Terminator = Block.getTerminator();
    Terminator->setMetadata(LLVMContext::MD_prof, nullptr);
  }
  F.setEntryCount(0U);
}
275+
276+
/// Erases every InstrProfCntrInstBase instruction found in \p F.
void removeInstrumentation(Function &F) {
  for (BasicBlock &Block : F) {
    // Early-increment iteration keeps the traversal valid while erasing.
    for (Instruction &Inst : llvm::make_early_inc_range(Block)) {
      if (!isa<InstrProfCntrInstBase>(Inst))
        continue;
      Inst.eraseFromParent();
    }
  }
}
282+
283+
} // namespace
284+
285+
/// Flattens the contextual profile of \p M and lowers it onto the IR.
///
/// Returns PreservedAnalyses::all() without touching the module when no
/// contextual profile is available. Otherwise, for every function definition:
///  - a function with unreachable blocks is reported as an error and skipped
///    (its instrumentation is left in place);
///  - a function absent from the flattened profile is treated as cold: its
///    branch weights are cleared and its entry count set to 0;
///  - any other function gets its entry count and branch weights from the
///    flattened counters;
/// and the contextual instrumentation is then removed. Finally the module's
/// profile summary is replaced and the cached ProfileSummaryInfo refreshed.
PreservedAnalyses PGOCtxProfFlattening::run(Module &M,
                                            ModuleAnalysisManager &MAM) {
  auto &CtxProf = MAM.getResult<CtxProfAnalysis>(M);
  if (!CtxProf)
    return PreservedAnalyses::all();

  const auto FlattenedProfile = CtxProf.flatten();

  // The function analysis manager is the same for all functions of M; fetch
  // it once rather than on every loop iteration.
  FunctionAnalysisManager &FAM =
      MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();

  InstrProfSummaryBuilder PB(ProfileSummaryBuilder::DefaultCutoffs);
  for (auto &F : M) {
    if (F.isDeclaration())
      continue;

    if (!areAllBBsReachable(F, FAM)) {
      M.getContext().emitError(
          "[ctx-prof] Function has unreacheable basic blocks: " + F.getName());
      continue;
    }

    const auto &FlatProfile =
        FlattenedProfile.lookup(AssignGUIDPass::getGUID(F));
    // If this function didn't appear in the contextual profile, it's cold.
    if (FlatProfile.empty())
      clearColdFunctionProfile(F);
    else {
      ProfileAnnotator S(F, FlatProfile, PB);
      S.assignProfileData();
    }
    removeInstrumentation(F);
  }

  auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(M);

  M.setProfileSummary(PB.getSummary()->getMD(M.getContext()),
                      ProfileSummary::Kind::PSK_Instr);
  PSI.refresh();
  return PreservedAnalyses::none();
}

0 commit comments

Comments
 (0)