Skip to content

Commit fc27e19

Browse files
mtrofin authored and ronlieb committed
[ctx_prof] Add support for ICP (llvm#105469)
An overload of `llvm::promoteCallWithIfThenElse` that updates the contextual profile. At a high level, this is very simple: after creating the `if... then (direct call) else (indirect call)` structure, we instrument the new callsites and BBs (the instrumentation will help with tracking for other IPO transformations and, ultimately, with matching counter values before flattening to `MD_prof`).
In more detail:
- move the callsite instrumentation of the indirect call to the `else` BB, before the indirect call
- create a new callsite instrumentation for the direct call
- create instrumentation for both the `then` and `else` BBs - we could instrument just one (MST-style), but we're not running the binary with this instrumentation, and at most this would save some space (fewer counters tracked). For simplicity, we instrument both at this point
- update each context belonging to the caller by updating the counters, and moving the indirect callee to the new, direct callsite ID
Issue llvm#89287
Change-Id: I6cea45269b753c9d4b660f3e8b16f176a7281e13
1 parent 26febbc commit fc27e19

File tree

8 files changed

+344
-33
lines changed

8 files changed

+344
-33
lines changed

llvm/include/llvm/Analysis/CtxProfAnalysis.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ class PGOContextualProfile {
7373
return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex++;
7474
}
7575

76+
using ConstVisitor = function_ref<void(const PGOCtxProfContext &)>;
77+
using Visitor = function_ref<void(PGOCtxProfContext &)>;
78+
79+
void update(Visitor, const Function *F = nullptr);
80+
void visit(ConstVisitor, const Function *F = nullptr) const;
81+
7682
const CtxProfFlatProfile flatten() const;
7783

7884
bool invalidate(Module &, const PreservedAnalyses &PA,
@@ -105,13 +111,18 @@ class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {
105111

106112
class CtxProfAnalysisPrinterPass
107113
: public PassInfoMixin<CtxProfAnalysisPrinterPass> {
108-
raw_ostream &OS;
109-
110114
public:
111-
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS) : OS(OS) {}
115+
enum class PrintMode { Everything, JSON };
116+
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS,
117+
PrintMode Mode = PrintMode::Everything)
118+
: OS(OS), Mode(Mode) {}
112119

113120
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
114121
static bool isRequired() { return true; }
122+
123+
private:
124+
raw_ostream &OS;
125+
const PrintMode Mode;
115126
};
116127

117128
/// Assign a GUID to functions as metadata. GUID calculation takes linkage into

llvm/include/llvm/IR/IntrinsicInst.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,6 +1597,7 @@ class InstrProfCntrInstBase : public InstrProfInstBase {
15971597
ConstantInt *getNumCounters() const;
15981598
// The index of the counter that this instruction acts on.
15991599
ConstantInt *getIndex() const;
1600+
void setIndex(uint32_t Idx);
16001601
};
16011602

16021603
/// This represents the llvm.instrprof.cover intrinsic.
@@ -1647,6 +1648,7 @@ class InstrProfCallsite : public InstrProfCntrInstBase {
16471648
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
16481649
}
16491650
Value *getCallee() const;
1651+
void setCallee(Value *Callee);
16501652
};
16511653

16521654
/// This represents the llvm.instrprof.timestamp intrinsic.

llvm/include/llvm/ProfileData/PGOCtxProfReader.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,25 @@ class PGOCtxProfContext final {
5757

5858
GlobalValue::GUID guid() const { return GUID; }
5959
const SmallVectorImpl<uint64_t> &counters() const { return Counters; }
60+
SmallVectorImpl<uint64_t> &counters() { return Counters; }
61+
62+
/// Return the first counter of this context. Per the assertion below, that
/// counter is expected to belong to the function's entry BB, so a context
/// must always carry at least one counter.
uint64_t getEntrycount() const {
  assert(Counters.size() >= 1 &&
         "Functions are expected to have at their entry BB instrumented, so "
         "there should always be at least 1 counter.");
  return Counters.front();
}
68+
6069
const CallsiteMapTy &callsites() const { return Callsites; }
6170
CallsiteMapTy &callsites() { return Callsites; }
6271

72+
void ingestContext(uint32_t CSId, PGOCtxProfContext &&Other) {
73+
auto [Iter, _] = callsites().try_emplace(CSId, CallTargetMapTy());
74+
Iter->second.emplace(Other.guid(), std::move(Other));
75+
}
76+
77+
/// Resize this context's counter vector to hold exactly \p Size counters.
void resizeCounters(uint32_t Size) { Counters.resize(Size); }
78+
6379
bool hasCallsite(uint32_t I) const {
6480
return Callsites.find(I) != Callsites.end();
6581
}
@@ -68,6 +84,12 @@ class PGOCtxProfContext final {
6884
assert(hasCallsite(I) && "Callsite not found");
6985
return Callsites.find(I)->second;
7086
}
87+
88+
/// Mutable access to the call targets recorded for callsite \p I. The
/// callsite must exist (see hasCallsite); this is asserted.
CallTargetMapTy &callsite(uint32_t I) {
  auto Found = Callsites.find(I);
  assert(Found != Callsites.end() && "Callsite not found");
  return Found->second;
}
92+
7193
void getContainedGuids(DenseSet<GlobalValue::GUID> &Guids) const;
7294
};
7395

llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
1515
#define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
1616

17+
#include "llvm/Analysis/CtxProfAnalysis.h"
1718
namespace llvm {
1819
template <typename T> class ArrayRef;
1920
class Constant;
@@ -56,6 +57,9 @@ CallBase &promoteCall(CallBase &CB, Function *Callee,
5657
CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
5758
MDNode *BranchWeights = nullptr);
5859

60+
/// Overload of `promoteCallWithIfThenElse` that also keeps the contextual
/// profile \p CtxProf up to date: the indirect call \p CB is versioned into a
/// direct call to \p Callee plus the original indirect call, the new callsite
/// and basic blocks are instrumented, and every context of the caller in
/// \p CtxProf is updated accordingly.
///
/// \returns the new direct call, or nullptr when the promotion is not
/// performed (e.g. \p Callee is not known to \p CtxProf, or \p CB has no
/// callsite instrumentation).
CallBase *promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
61+
PGOContextualProfile &CtxProf);
62+
5963
/// This is similar to `promoteCallWithIfThenElse` except that the condition to
6064
/// promote a virtual call is that \p VPtr is the same as any of \p
6165
/// AddressPoints.

llvm/lib/Analysis/CtxProfAnalysis.cpp

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -176,16 +176,22 @@ PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
176176
return PreservedAnalyses::all();
177177
}
178178

179-
OS << "Function Info:\n";
180-
for (const auto &[Guid, FuncInfo] : C.FuncInfo)
181-
OS << Guid << " : " << FuncInfo.Name
182-
<< ". MaxCounterID: " << FuncInfo.NextCounterIndex
183-
<< ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
179+
if (Mode == PrintMode::Everything) {
180+
OS << "Function Info:\n";
181+
for (const auto &[Guid, FuncInfo] : C.FuncInfo)
182+
OS << Guid << " : " << FuncInfo.Name
183+
<< ". MaxCounterID: " << FuncInfo.NextCounterIndex
184+
<< ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
185+
}
184186

185187
const auto JSONed = ::llvm::json::toJSON(C.profiles());
186188

187-
OS << "\nCurrent Profile:\n";
189+
if (Mode == PrintMode::Everything)
190+
OS << "\nCurrent Profile:\n";
188191
OS << formatv("{0:2}", JSONed);
192+
if (Mode == PrintMode::JSON)
193+
return PreservedAnalyses::all();
194+
189195
OS << "\n";
190196
OS << "\nFlat Profile:\n";
191197
auto Flat = C.flatten();
@@ -212,34 +218,49 @@ InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) {
212218
return nullptr;
213219
}
214220

215-
static void
216-
preorderVisit(const PGOCtxProfContext::CallTargetMapTy &Profiles,
217-
function_ref<void(const PGOCtxProfContext &)> Visitor) {
218-
std::function<void(const PGOCtxProfContext &)> Traverser =
219-
[&](const auto &Ctx) {
220-
Visitor(Ctx);
221-
for (const auto &[_, SubCtxSet] : Ctx.callsites())
222-
for (const auto &[__, Subctx] : SubCtxSet)
223-
Traverser(Subctx);
224-
};
225-
for (const auto &[_, P] : Profiles)
221+
template <class ProfilesTy, class ProfTy>
222+
static void preorderVisit(ProfilesTy &Profiles,
223+
function_ref<void(ProfTy &)> Visitor,
224+
GlobalValue::GUID Match = 0) {
225+
std::function<void(ProfTy &)> Traverser = [&](auto &Ctx) {
226+
if (!Match || Ctx.guid() == Match)
227+
Visitor(Ctx);
228+
for (auto &[_, SubCtxSet] : Ctx.callsites())
229+
for (auto &[__, Subctx] : SubCtxSet)
230+
Traverser(Subctx);
231+
};
232+
for (auto &[_, P] : Profiles)
226233
Traverser(P);
227234
}
228235

236+
/// Visit the contexts of the loaded profile in pre-order, giving \p V mutable
/// access to each visited context.
///
/// \param V callback invoked on each selected context.
/// \param F when non-null, only contexts whose GUID matches that of \p F are
///          passed to \p V; when null, every context is visited.
void PGOContextualProfile::update(Visitor V, const Function *F) {
  // Same precondition flatten() checks: there must be a profile to walk.
  assert(Profiles.has_value());
  // A GUID of 0 tells preorderVisit not to filter.
  const GlobalValue::GUID G = F ? getDefinedFunctionGUID(*F) : 0U;
  preorderVisit<PGOCtxProfContext::CallTargetMapTy, PGOCtxProfContext>(
      *Profiles, V, G);
}
241+
242+
/// Visit the contexts of the loaded profile in pre-order, giving \p V
/// read-only access to each visited context.
///
/// \param V callback invoked on each selected context.
/// \param F when non-null, only contexts whose GUID matches that of \p F are
///          passed to \p V; when null, every context is visited.
void PGOContextualProfile::visit(ConstVisitor V, const Function *F) const {
  // Same precondition flatten() checks: there must be a profile to walk.
  assert(Profiles.has_value());
  // A GUID of 0 tells preorderVisit not to filter.
  const GlobalValue::GUID G = F ? getDefinedFunctionGUID(*F) : 0U;
  preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
                const PGOCtxProfContext>(*Profiles, V, G);
}
247+
229248
const CtxProfFlatProfile PGOContextualProfile::flatten() const {
230249
assert(Profiles.has_value());
231250
CtxProfFlatProfile Flat;
232-
preorderVisit(*Profiles, [&](const PGOCtxProfContext &Ctx) {
233-
auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
234-
if (Ins) {
235-
llvm::append_range(It->second, Ctx.counters());
236-
return;
237-
}
238-
assert(It->second.size() == Ctx.counters().size() &&
239-
"All contexts corresponding to a function should have the exact "
240-
"same number of counters.");
241-
for (size_t I = 0, E = It->second.size(); I < E; ++I)
242-
It->second[I] += Ctx.counters()[I];
243-
});
251+
preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
252+
const PGOCtxProfContext>(
253+
*Profiles, [&](const PGOCtxProfContext &Ctx) {
254+
auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
255+
if (Ins) {
256+
llvm::append_range(It->second, Ctx.counters());
257+
return;
258+
}
259+
assert(It->second.size() == Ctx.counters().size() &&
260+
"All contexts corresponding to a function should have the exact "
261+
"same number of counters.");
262+
for (size_t I = 0, E = It->second.size(); I < E; ++I)
263+
It->second[I] += Ctx.counters()[I];
264+
});
244265
return Flat;
245266
}

llvm/lib/IR/IntrinsicInst.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,11 @@ ConstantInt *InstrProfCntrInstBase::getIndex() const {
286286
return cast<ConstantInt>(const_cast<Value *>(getArgOperand(3)));
287287
}
288288

289+
/// Replace the counter index of this instruction (argument operand 3, the one
/// getIndex() reads) with the 32-bit constant \p Idx.
void InstrProfCntrInstBase::setIndex(uint32_t Idx) {
  assert(isa<InstrProfCntrInstBase>(this));
  auto *Int32Ty = Type::getInt32Ty(getContext());
  auto *NewIndex = ConstantInt::get(Int32Ty, Idx);
  setArgOperand(3, NewIndex);
}
293+
289294
Value *InstrProfIncrementInst::getStep() const {
290295
if (InstrProfIncrementInstStep::classof(this)) {
291296
return const_cast<Value *>(getArgOperand(4));
@@ -301,6 +306,11 @@ Value *InstrProfCallsite::getCallee() const {
301306
return nullptr;
302307
}
303308

309+
/// Replace the callee recorded by this callsite instrumentation with
/// \p Callee. The callee is stored as argument operand 4.
void InstrProfCallsite::setCallee(Value *Callee) {
  assert(isa<InstrProfCallsite>(this));
  constexpr unsigned CalleeArgNo = 4;
  setArgOperand(CalleeArgNo, Callee);
}
313+
304314
std::optional<RoundingMode> ConstrainedFPIntrinsic::getRoundingMode() const {
305315
unsigned NumOperands = arg_size();
306316
Metadata *MD = nullptr;

llvm/lib/Transforms/Utils/CallPromotionUtils.cpp

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
15-
#include "llvm/ADT/STLExtras.h"
15+
#include "llvm/Analysis/CtxProfAnalysis.h"
1616
#include "llvm/Analysis/Loads.h"
1717
#include "llvm/Analysis/TypeMetadataUtils.h"
1818
#include "llvm/IR/AttributeMask.h"
1919
#include "llvm/IR/Constant.h"
2020
#include "llvm/IR/IRBuilder.h"
2121
#include "llvm/IR/Instructions.h"
22+
#include "llvm/IR/IntrinsicInst.h"
2223
#include "llvm/IR/Module.h"
24+
#include "llvm/ProfileData/PGOCtxProfReader.h"
2325
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
2426

2527
using namespace llvm;
@@ -572,6 +574,88 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
572574
return promoteCall(NewInst, Callee);
573575
}
574576

577+
/// Promote the indirect call \p CB to a guarded direct call to \p Callee and
/// keep the contextual profile \p CtxProf consistent with the new IR.
///
/// The call is versioned into a "direct" BB (calling \p Callee) and an
/// "indirect" BB (the original call). Then:
///  - the existing callsite instrumentation is moved right before the indirect
///    call, and a clone with a freshly allocated callsite ID and \p Callee as
///    callee is placed before the direct call;
///  - both new BBs get counter instrumentation with freshly allocated IDs;
///  - every context of the caller is resized to hold the 2 new counters and, if
///    it observed \p Callee at the original callsite, the counters are filled
///    in and the callee's sub-context is moved under the new callsite ID.
///
/// \returns the direct call, or nullptr - leaving the IR and the profile
/// untouched - if \p Callee is unknown to \p CtxProf, if \p CB has no callsite
/// instrumentation, or if the caller's entry BB has no instrumentation to use
/// as a template for the new BB counters.
CallBase *llvm::promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
                                          PGOContextualProfile &CtxProf) {
  assert(CB.isIndirectCall());
  if (!CtxProf.isFunctionKnown(Callee))
    return nullptr;
  auto &Caller = *CB.getFunction();
  auto *CSInstr = CtxProfAnalysis::getCallsiteInstrumentation(CB);
  if (!CSInstr)
    return nullptr;
  // The entry BB instrumentation is cloned below for the new BBs. Look it up
  // before changing anything, so that we can bail out cleanly if it's missing
  // rather than dereference a null pointer after the IR was already rewritten.
  auto *EntryBBIns =
      CtxProfAnalysis::getBBInstrumentation(Caller.getEntryBlock());
  if (!EntryBBIns)
    return nullptr;
  const uint64_t CSIndex = CSInstr->getIndex()->getZExtValue();

  CallBase &DirectCall = promoteCall(
      versionCallSite(CB, &Callee, /*BranchWeights=*/nullptr), &Callee);
  // Keep the original callsite instrumentation with the indirect call.
  CSInstr->moveBefore(&CB);
  // The direct call gets its own callsite instrumentation, under a new ID.
  const auto NewCSID = CtxProf.allocateNextCallsiteIndex(Caller);
  auto *NewCSInstr = cast<InstrProfCallsite>(CSInstr->clone());
  NewCSInstr->setIndex(NewCSID);
  NewCSInstr->setCallee(&Callee);
  NewCSInstr->insertBefore(&DirectCall);
  auto &DirectBB = *DirectCall.getParent();
  auto &IndirectBB = *CB.getParent();

  assert((CtxProfAnalysis::getBBInstrumentation(IndirectBB) == nullptr) &&
         "The ICP indirect BB is new, it shouldn't have instrumentation");
  assert((CtxProfAnalysis::getBBInstrumentation(DirectBB) == nullptr) &&
         "The ICP direct BB is new, it shouldn't have instrumentation");

  // Allocate counters for the new basic blocks.
  const uint32_t DirectID = CtxProf.allocateNextCounterIndex(Caller);
  const uint32_t IndirectID = CtxProf.allocateNextCounterIndex(Caller);
  auto *DirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
  DirectBBIns->setIndex(DirectID);
  DirectBBIns->insertInto(&DirectBB, DirectBB.getFirstInsertionPt());

  auto *IndirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
  IndirectBBIns->setIndex(IndirectID);
  IndirectBBIns->insertInto(&IndirectBB, IndirectBB.getFirstInsertionPt());

  const GlobalValue::GUID CalleeGUID = AssignGUIDPass::getGUID(Callee);
  const uint32_t NewCountersSize = IndirectID + 1;

  auto ProfileUpdater = [&](PGOCtxProfContext &Ctx) {
    assert(Ctx.guid() == AssignGUIDPass::getGUID(Caller));
    assert(NewCountersSize - 2 == Ctx.counters().size());
    // All the ctx-es belonging to a function must have the same size counters.
    Ctx.resizeCounters(NewCountersSize);

    // Maybe in this context, the indirect callsite wasn't observed at all
    if (!Ctx.hasCallsite(CSIndex))
      return;
    auto &CSData = Ctx.callsite(CSIndex);
    auto It = CSData.find(CalleeGUID);

    // Maybe we did notice the indirect callsite, but to other targets.
    if (It == CSData.end())
      return;

    assert(CalleeGUID == It->second.guid());

    // Counters are 64-bit (getEntrycount() returns uint64_t); do the
    // arithmetic at that width so large counts are neither truncated nor
    // wrapped when summed.
    const uint64_t DirectCount = It->second.getEntrycount();
    uint64_t TotalCount = 0;
    for (const auto &[_, V] : CSData)
      TotalCount += V.getEntrycount();
    assert(TotalCount >= DirectCount);
    const uint64_t IndirectCount = TotalCount - DirectCount;
    // The ICP's effect is as-if the direct BB would have been taken DirectCount
    // times, and the indirect BB, IndirectCount times
    Ctx.counters()[DirectID] = DirectCount;
    Ctx.counters()[IndirectID] = IndirectCount;

    // This particular indirect target needs to be moved to this caller under
    // the newly-allocated callsite index.
    assert(Ctx.callsites().count(NewCSID) == 0);
    Ctx.ingestContext(NewCSID, std::move(It->second));
    CSData.erase(CalleeGUID);
  };
  CtxProf.update(ProfileUpdater, &Caller);
  return &DirectCall;
}
658+
575659
CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
576660
Function *Callee,
577661
ArrayRef<Constant *> AddressPoints,

0 commit comments

Comments
 (0)