Skip to content

Commit 3209766

Browse files
authored
[ctx_prof] Add Inlining support (llvm#106154)
Add an overload of `InlineFunction` that updates the contextual profile. If there is no contextual profile, this overload is equivalent to the non-contextual profile variant.

Post-inlining, the update mainly consists of:
- making the PGO instrumentation of the callee "the caller's": the owner function (the "name" parameter of the instrumentation instructions) becomes the caller, and new index values are allocated for each of the callee's indices (this happens for both increment and callsite instrumentation instructions)
- in the contextual profile:
  - each context corresponding to the caller has its counters updated to incorporate the counters inherited from the callee at the inlined callsite. Counter values are copied as-is because no scaling is required since the profile is contextual.
  - the contexts of the callee (at the inlined callsite) are moved to the caller.
  - the callee context at the inlined callsite is deleted.
1 parent b24a304 commit 3209766

File tree

12 files changed

+431
-18
lines changed

12 files changed

+431
-18
lines changed

llvm/include/llvm/Analysis/CtxProfAnalysis.h

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "llvm/IR/IntrinsicInst.h"
1616
#include "llvm/IR/PassManager.h"
1717
#include "llvm/ProfileData/PGOCtxProfReader.h"
18+
#include <optional>
1819

1920
namespace llvm {
2021

@@ -63,6 +64,16 @@ class PGOContextualProfile {
6364
return getDefinedFunctionGUID(F) != 0;
6465
}
6566

67+
uint32_t getNumCounters(const Function &F) const {
68+
assert(isFunctionKnown(F));
69+
return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCounterIndex;
70+
}
71+
72+
uint32_t getNumCallsites(const Function &F) const {
73+
assert(isFunctionKnown(F));
74+
return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex;
75+
}
76+
6677
uint32_t allocateNextCounterIndex(const Function &F) {
6778
assert(isFunctionKnown(F));
6879
return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCounterIndex++;
@@ -91,11 +102,11 @@ class PGOContextualProfile {
91102
};
92103

93104
class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {
94-
StringRef Profile;
105+
const std::optional<StringRef> Profile;
95106

96107
public:
97108
static AnalysisKey Key;
98-
explicit CtxProfAnalysis(StringRef Profile = "");
109+
explicit CtxProfAnalysis(std::optional<StringRef> Profile = std::nullopt);
99110

100111
using Result = PGOContextualProfile;
101112

@@ -113,9 +124,7 @@ class CtxProfAnalysisPrinterPass
113124
: public PassInfoMixin<CtxProfAnalysisPrinterPass> {
114125
public:
115126
enum class PrintMode { Everything, JSON };
116-
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS,
117-
PrintMode Mode = PrintMode::Everything)
118-
: OS(OS), Mode(Mode) {}
127+
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS);
119128

120129
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
121130
static bool isRequired() { return true; }

llvm/include/llvm/IR/IntrinsicInst.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,6 +1516,8 @@ class InstrProfInstBase : public IntrinsicInst {
15161516
return const_cast<Value *>(getArgOperand(0))->stripPointerCasts();
15171517
}
15181518

1519+
void setNameValue(Value *V) { setArgOperand(0, V); }
1520+
15191521
// The hash of the CFG for the instrumented function.
15201522
ConstantInt *getHash() const {
15211523
return cast<ConstantInt>(const_cast<Value *>(getArgOperand(1)));

llvm/include/llvm/ProfileData/PGOCtxProfReader.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,13 @@ class PGOCtxProfContext final {
7474
Iter->second.emplace(Other.guid(), std::move(Other));
7575
}
7676

77+
void ingestAllContexts(uint32_t CSId, CallTargetMapTy &&Other) {
78+
auto [_, Inserted] = callsites().try_emplace(CSId, std::move(Other));
79+
(void)Inserted;
80+
assert(Inserted &&
81+
"CSId was expected to be newly created as result of e.g. inlining");
82+
}
83+
7784
void resizeCounters(uint32_t Size) { Counters.resize(Size); }
7885

7986
bool hasCallsite(uint32_t I) const {

llvm/include/llvm/Transforms/Utils/Cloning.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/ADT/SmallVector.h"
2121
#include "llvm/ADT/Twine.h"
2222
#include "llvm/Analysis/AssumptionCache.h"
23+
#include "llvm/Analysis/CtxProfAnalysis.h"
2324
#include "llvm/Analysis/InlineCost.h"
2425
#include "llvm/IR/BasicBlock.h"
2526
#include "llvm/IR/ValueHandle.h"
@@ -270,6 +271,17 @@ InlineResult InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
270271
bool InsertLifetime = true,
271272
Function *ForwardVarArgsTo = nullptr);
272273

274+
/// Same as above, but it will update the contextual profile. If the contextual
275+
/// profile is invalid (i.e. not loaded because it is not present), it defaults
276+
/// to the behavior of the non-contextual profile updating variant above. This
277+
/// makes it easy to drop-in replace uses of the non-contextual overload.
278+
InlineResult InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
279+
CtxProfAnalysis::Result &CtxProf,
280+
bool MergeAttributes = false,
281+
AAResults *CalleeAAR = nullptr,
282+
bool InsertLifetime = true,
283+
Function *ForwardVarArgsTo = nullptr);
284+
273285
/// Clones a loop \p OrigLoop. Returns the loop and the blocks in \p
274286
/// Blocks.
275287
///

llvm/lib/Analysis/CtxProfAnalysis.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ cl::opt<std::string>
2929
UseCtxProfile("use-ctx-profile", cl::init(""), cl::Hidden,
3030
cl::desc("Use the specified contextual profile file"));
3131

32+
static cl::opt<CtxProfAnalysisPrinterPass::PrintMode> PrintLevel(
33+
"ctx-profile-printer-level",
34+
cl::init(CtxProfAnalysisPrinterPass::PrintMode::JSON), cl::Hidden,
35+
cl::values(clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::Everything,
36+
"everything", "print everything - most verbose"),
37+
clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::JSON, "json",
38+
"just the json representation of the profile")),
39+
cl::desc("Verbosity level of the contextual profile printer pass."));
40+
3241
namespace llvm {
3342
namespace json {
3443
Value toJSON(const PGOCtxProfContext &P) {
@@ -96,12 +105,20 @@ GlobalValue::GUID AssignGUIDPass::getGUID(const Function &F) {
96105
}
97106
AnalysisKey CtxProfAnalysis::Key;
98107

99-
CtxProfAnalysis::CtxProfAnalysis(StringRef Profile)
100-
: Profile(Profile.empty() ? UseCtxProfile : Profile) {}
108+
CtxProfAnalysis::CtxProfAnalysis(std::optional<StringRef> Profile)
109+
: Profile([&]() -> std::optional<StringRef> {
110+
if (Profile)
111+
return *Profile;
112+
if (UseCtxProfile.getNumOccurrences())
113+
return UseCtxProfile;
114+
return std::nullopt;
115+
}()) {}
101116

102117
PGOContextualProfile CtxProfAnalysis::run(Module &M,
103118
ModuleAnalysisManager &MAM) {
104-
ErrorOr<std::unique_ptr<MemoryBuffer>> MB = MemoryBuffer::getFile(Profile);
119+
if (!Profile)
120+
return {};
121+
ErrorOr<std::unique_ptr<MemoryBuffer>> MB = MemoryBuffer::getFile(*Profile);
105122
if (auto EC = MB.getError()) {
106123
M.getContext().emitError("could not open contextual profile file: " +
107124
EC.message());
@@ -150,7 +167,6 @@ PGOContextualProfile CtxProfAnalysis::run(Module &M,
150167
// If we made it this far, the Result is valid - which we mark by setting
151168
// .Profiles.
152169
// Trim first the roots that aren't in this module.
153-
DenseSet<GlobalValue::GUID> ProfiledGUIDs;
154170
for (auto &[RootGuid, _] : llvm::make_early_inc_range(*MaybeCtx))
155171
if (!Result.FuncInfo.contains(RootGuid))
156172
MaybeCtx->erase(RootGuid);
@@ -165,11 +181,14 @@ PGOContextualProfile::getDefinedFunctionGUID(const Function &F) const {
165181
return 0;
166182
}
167183

184+
CtxProfAnalysisPrinterPass::CtxProfAnalysisPrinterPass(raw_ostream &OS)
185+
: OS(OS), Mode(PrintLevel) {}
186+
168187
PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
169188
ModuleAnalysisManager &MAM) {
170189
CtxProfAnalysis::Result &C = MAM.getResult<CtxProfAnalysis>(M);
171190
if (!C) {
172-
M.getContext().emitError("Invalid CtxProfAnalysis");
191+
OS << "No contextual profile was provided.\n";
173192
return PreservedAnalyses::all();
174193
}
175194

llvm/lib/Transforms/IPO/ModuleInliner.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/Analysis/AliasAnalysis.h"
2121
#include "llvm/Analysis/AssumptionCache.h"
2222
#include "llvm/Analysis/BlockFrequencyInfo.h"
23+
#include "llvm/Analysis/CtxProfAnalysis.h"
2324
#include "llvm/Analysis/InlineAdvisor.h"
2425
#include "llvm/Analysis/InlineCost.h"
2526
#include "llvm/Analysis/InlineOrder.h"
@@ -113,6 +114,8 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M,
113114
return PreservedAnalyses::all();
114115
}
115116

117+
auto &CtxProf = MAM.getResult<CtxProfAnalysis>(M);
118+
116119
bool Changed = false;
117120

118121
ProfileSummaryInfo *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(M);
@@ -213,7 +216,7 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M,
213216
&FAM.getResult<BlockFrequencyAnalysis>(Callee));
214217

215218
InlineResult IR =
216-
InlineFunction(*CB, IFI, /*MergeAttributes=*/true,
219+
InlineFunction(*CB, IFI, CtxProf, /*MergeAttributes=*/true,
217220
&FAM.getResult<AAManager>(*CB->getCaller()));
218221
if (!IR.isSuccess()) {
219222
Advice->recordUnsuccessfulInlining(IR);

0 commit comments

Comments
 (0)