Skip to content

Commit c6e027a

Browse files
committed
[ctx_prof] Add support for ICP
1 parent f1780fa commit c6e027a

File tree

8 files changed

+364
-32
lines changed

8 files changed

+364
-32
lines changed

llvm/include/llvm/Analysis/CtxProfAnalysis.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ class PGOContextualProfile {
7373
return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex++;
7474
}
7575

76+
using ConstVisitor = function_ref<void(const PGOCtxProfContext &)>;
77+
using Visitor = function_ref<void(PGOCtxProfContext &)>;
78+
79+
void update(Visitor, const Function *F = nullptr);
80+
void visit(ConstVisitor, const Function *F = nullptr) const;
81+
7682
const CtxProfFlatProfile flatten() const;
7783

7884
bool invalidate(Module &, const PreservedAnalyses &PA,
@@ -105,13 +111,18 @@ class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {
105111

106112
class CtxProfAnalysisPrinterPass
107113
: public PassInfoMixin<CtxProfAnalysisPrinterPass> {
108-
raw_ostream &OS;
109-
110114
public:
111-
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS) : OS(OS) {}
115+
enum class PrintMode { Everything, JSON };
116+
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS,
117+
PrintMode Mode = PrintMode::Everything)
118+
: OS(OS), Mode(Mode) {}
112119

113120
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
114121
static bool isRequired() { return true; }
122+
123+
private:
124+
raw_ostream &OS;
125+
const PrintMode Mode;
115126
};
116127

117128
/// Assign a GUID to functions as metadata. GUID calculation takes linkage into

llvm/include/llvm/IR/IntrinsicInst.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1535,6 +1535,7 @@ class InstrProfCntrInstBase : public InstrProfInstBase {
15351535
ConstantInt *getNumCounters() const;
15361536
// The index of the counter that this instruction acts on.
15371537
ConstantInt *getIndex() const;
1538+
void setIndex(uint32_t Idx);
15381539
};
15391540

15401541
/// This represents the llvm.instrprof.cover intrinsic.
@@ -1585,6 +1586,7 @@ class InstrProfCallsite : public InstrProfCntrInstBase {
15851586
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
15861587
}
15871588
Value *getCallee() const;
1589+
void setCallee(Value *);
15881590
};
15891591

15901592
/// This represents the llvm.instrprof.timestamp intrinsic.

llvm/include/llvm/ProfileData/PGOCtxProfReader.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,23 @@ class PGOCtxProfContext final {
5757

5858
GlobalValue::GUID guid() const { return GUID; }
5959
const SmallVectorImpl<uint64_t> &counters() const { return Counters; }
60+
SmallVectorImpl<uint64_t> &counters() { return Counters; }
61+
62+
uint64_t getEntrycount() const { return Counters[0]; }
63+
6064
const CallsiteMapTy &callsites() const { return Callsites; }
6165
CallsiteMapTy &callsites() { return Callsites; }
6266

67+
void ingestContext(uint32_t CSId, PGOCtxProfContext &&Other) {
68+
auto [Iter, _] = callsites().try_emplace(CSId, CallTargetMapTy());
69+
Iter->second.emplace(Other.guid(), std::move(Other));
70+
}
71+
72+
void growCounters(uint32_t Size) {
73+
if (Size >= Counters.size())
74+
Counters.resize(Size);
75+
}
76+
6377
bool hasCallsite(uint32_t I) const {
6478
return Callsites.find(I) != Callsites.end();
6579
}
@@ -68,6 +82,12 @@ class PGOCtxProfContext final {
6882
assert(hasCallsite(I) && "Callsite not found");
6983
return Callsites.find(I)->second;
7084
}
85+
86+
CallTargetMapTy &callsite(uint32_t I) {
87+
assert(hasCallsite(I) && "Callsite not found");
88+
return Callsites.find(I)->second;
89+
}
90+
7191
void getContainedGuids(DenseSet<GlobalValue::GUID> &Guids) const;
7292
};
7393

llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
1515
#define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
1616

17+
#include "llvm/Analysis/CtxProfAnalysis.h"
1718
namespace llvm {
1819
template <typename T> class ArrayRef;
1920
class Constant;
@@ -56,6 +57,9 @@ CallBase &promoteCall(CallBase &CB, Function *Callee,
5657
CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
5758
MDNode *BranchWeights = nullptr);
5859

60+
CallBase *promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
61+
PGOContextualProfile &CtxProf);
62+
5963
/// This is similar to `promoteCallWithIfThenElse` except that the condition to
6064
/// promote a virtual call is that \p VPtr is the same as any of \p
6165
/// AddressPoints.

llvm/lib/Analysis/CtxProfAnalysis.cpp

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -173,16 +173,22 @@ PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
173173
return PreservedAnalyses::all();
174174
}
175175

176-
OS << "Function Info:\n";
177-
for (const auto &[Guid, FuncInfo] : C.FuncInfo)
178-
OS << Guid << " : " << FuncInfo.Name
179-
<< ". MaxCounterID: " << FuncInfo.NextCounterIndex
180-
<< ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
176+
if (Mode == PrintMode::Everything) {
177+
OS << "Function Info:\n";
178+
for (const auto &[Guid, FuncInfo] : C.FuncInfo)
179+
OS << Guid << " : " << FuncInfo.Name
180+
<< ". MaxCounterID: " << FuncInfo.NextCounterIndex
181+
<< ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
182+
}
181183

182184
const auto JSONed = ::llvm::json::toJSON(C.profiles());
183185

184-
OS << "\nCurrent Profile:\n";
186+
if (Mode == PrintMode::Everything)
187+
OS << "\nCurrent Profile:\n";
185188
OS << formatv("{0:2}", JSONed);
189+
if (Mode == PrintMode::JSON)
190+
return PreservedAnalyses::all();
191+
186192
OS << "\n";
187193
OS << "\nFlat Profile:\n";
188194
auto Flat = C.flatten();
@@ -209,34 +215,49 @@ InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) {
209215
return nullptr;
210216
}
211217

212-
static void
213-
preorderVisit(const PGOCtxProfContext::CallTargetMapTy &Profiles,
214-
function_ref<void(const PGOCtxProfContext &)> Visitor) {
215-
std::function<void(const PGOCtxProfContext &)> Traverser =
216-
[&](const auto &Ctx) {
217-
Visitor(Ctx);
218-
for (const auto &[_, SubCtxSet] : Ctx.callsites())
219-
for (const auto &[__, Subctx] : SubCtxSet)
220-
Traverser(Subctx);
221-
};
222-
for (const auto &[_, P] : Profiles)
218+
template <class ProfilesTy, class ProfTy>
219+
static void preorderVisit(ProfilesTy &Profiles,
220+
function_ref<void(ProfTy &)> Visitor,
221+
GlobalValue::GUID Match = 0) {
222+
std::function<void(ProfTy &)> Traverser = [&](auto &Ctx) {
223+
if (!Match || Ctx.guid() == Match)
224+
Visitor(Ctx);
225+
for (auto &[_, SubCtxSet] : Ctx.callsites())
226+
for (auto &[__, Subctx] : SubCtxSet)
227+
Traverser(Subctx);
228+
};
229+
for (auto &[_, P] : Profiles)
223230
Traverser(P);
224231
}
225232

233+
void PGOContextualProfile::update(Visitor V, const Function *F) {
234+
GlobalValue::GUID G = F ? getDefinedFunctionGUID(*F) : 0U;
235+
preorderVisit<PGOCtxProfContext::CallTargetMapTy, PGOCtxProfContext>(
236+
*Profiles, V, G);
237+
}
238+
239+
void PGOContextualProfile::visit(ConstVisitor V, const Function *F) const {
240+
GlobalValue::GUID G = F ? getDefinedFunctionGUID(*F) : 0U;
241+
preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
242+
const PGOCtxProfContext>(*Profiles, V, G);
243+
}
244+
226245
const CtxProfFlatProfile PGOContextualProfile::flatten() const {
227246
assert(Profiles.has_value());
228247
CtxProfFlatProfile Flat;
229-
preorderVisit(*Profiles, [&](const PGOCtxProfContext &Ctx) {
230-
auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
231-
if (Ins) {
232-
llvm::append_range(It->second, Ctx.counters());
233-
return;
234-
}
235-
assert(It->second.size() == Ctx.counters().size() &&
236-
"All contexts corresponding to a function should have the exact "
237-
"same number of counters.");
238-
for (size_t I = 0, E = It->second.size(); I < E; ++I)
239-
It->second[I] += Ctx.counters()[I];
240-
});
248+
preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
249+
const PGOCtxProfContext>(
250+
*Profiles, [&](const PGOCtxProfContext &Ctx) {
251+
auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
252+
if (Ins) {
253+
llvm::append_range(It->second, Ctx.counters());
254+
return;
255+
}
256+
assert(It->second.size() == Ctx.counters().size() &&
257+
"All contexts corresponding to a function should have the exact "
258+
"same number of counters.");
259+
for (size_t I = 0, E = It->second.size(); I < E; ++I)
260+
It->second[I] += Ctx.counters()[I];
261+
});
241262
return Flat;
242263
}

llvm/lib/IR/IntrinsicInst.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,11 @@ ConstantInt *InstrProfCntrInstBase::getIndex() const {
285285
return cast<ConstantInt>(const_cast<Value *>(getArgOperand(3)));
286286
}
287287

288+
void InstrProfCntrInstBase::setIndex(uint32_t Idx) {
289+
assert(isa<InstrProfCntrInstBase>(this));
290+
setArgOperand(3, ConstantInt::get(Type::getInt32Ty(getContext()), Idx));
291+
}
292+
288293
Value *InstrProfIncrementInst::getStep() const {
289294
if (InstrProfIncrementInstStep::classof(this)) {
290295
return const_cast<Value *>(getArgOperand(4));
@@ -300,6 +305,11 @@ Value *InstrProfCallsite::getCallee() const {
300305
return nullptr;
301306
}
302307

308+
void InstrProfCallsite::setCallee(Value *V) {
309+
assert(isa<InstrProfCallsite>(this));
310+
setArgOperand(4, V);
311+
}
312+
303313
std::optional<RoundingMode> ConstrainedFPIntrinsic::getRoundingMode() const {
304314
unsigned NumOperands = arg_size();
305315
Metadata *MD = nullptr;

llvm/lib/Transforms/Utils/CallPromotionUtils.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@
1313

1414
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
1515
#include "llvm/ADT/STLExtras.h"
16+
#include "llvm/Analysis/CtxProfAnalysis.h"
1617
#include "llvm/Analysis/Loads.h"
1718
#include "llvm/Analysis/TypeMetadataUtils.h"
1819
#include "llvm/IR/AttributeMask.h"
1920
#include "llvm/IR/Constant.h"
2021
#include "llvm/IR/IRBuilder.h"
2122
#include "llvm/IR/Instructions.h"
23+
#include "llvm/IR/IntrinsicInst.h"
2224
#include "llvm/IR/Module.h"
25+
#include "llvm/ProfileData/PGOCtxProfReader.h"
2326
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
2427

2528
using namespace llvm;
@@ -572,6 +575,89 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
572575
return promoteCall(NewInst, Callee);
573576
}
574577

578+
CallBase *llvm::promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
579+
PGOContextualProfile &CtxProf) {
580+
assert(CB.isIndirectCall());
581+
if (!CtxProf.isFunctionKnown(Callee))
582+
return nullptr;
583+
auto &Caller = *CB.getParent()->getParent();
584+
auto *CSInstr = CtxProfAnalysis::getCallsiteInstrumentation(CB);
585+
if (!CSInstr)
586+
return nullptr;
587+
const auto CSIndex = CSInstr->getIndex()->getZExtValue();
588+
589+
CallBase &DirectCall = promoteCall(
590+
versionCallSite(CB, &Callee, /*BranchWeights=*/nullptr), &Callee);
591+
CSInstr->moveBefore(&CB);
592+
const auto NewCSID = CtxProf.allocateNextCallsiteIndex(Caller);
593+
auto *NewCSInstr = cast<InstrProfCallsite>(CSInstr->clone());
594+
NewCSInstr->setIndex(NewCSID);
595+
NewCSInstr->setCallee(&Callee);
596+
NewCSInstr->insertBefore(&DirectCall);
597+
auto &DirectBB = *DirectCall.getParent();
598+
auto &IndirectBB = *CB.getParent();
599+
600+
assert((CtxProfAnalysis::getBBInstrumentation(IndirectBB) == nullptr) &&
601+
"The ICP direct BB is new, it shouldn't have instrumentation");
602+
assert((CtxProfAnalysis::getBBInstrumentation(DirectBB) == nullptr) &&
603+
"The ICP indirect BB is new, it shouldn't have instrumentation");
604+
605+
// Make the 2 new BBs have counters.
606+
const uint32_t DirectID = CtxProf.allocateNextCounterIndex(Caller);
607+
const uint32_t IndirectID = CtxProf.allocateNextCounterIndex(Caller);
608+
const uint32_t NewCountersSize = IndirectID + 1;
609+
auto *EntryBBIns =
610+
CtxProfAnalysis::getBBInstrumentation(Caller.getEntryBlock());
611+
auto *DirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
612+
DirectBBIns->setIndex(DirectID);
613+
DirectBBIns->insertInto(&DirectBB, DirectBB.getFirstInsertionPt());
614+
615+
auto *IndirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
616+
IndirectBBIns->setIndex(IndirectID);
617+
IndirectBBIns->insertInto(&IndirectBB, IndirectBB.getFirstInsertionPt());
618+
619+
const GlobalValue::GUID CalleeGUID = AssignGUIDPass::getGUID(Callee);
620+
621+
auto ProfileUpdater = [&](PGOCtxProfContext &Ctx) {
622+
assert(Ctx.guid() == AssignGUIDPass::getGUID(Caller));
623+
assert(NewCountersSize - 2 == Ctx.counters().size());
624+
// Regardless what next, all the ctx-es belonging to a function must have
625+
// the same size counters.
626+
Ctx.growCounters(NewCountersSize);
627+
628+
// Maybe in this context, the indirect callsite wasn't observed at all
629+
if (!Ctx.hasCallsite(CSIndex))
630+
return;
631+
auto &CSData = Ctx.callsite(CSIndex);
632+
auto It = CSData.find(CalleeGUID);
633+
634+
// Maybe we did notice the indirect callsite, but to other targets.
635+
if (It == CSData.end())
636+
return;
637+
638+
assert(CalleeGUID == It->second.guid());
639+
640+
uint32_t DirectCount = It->second.getEntrycount();
641+
uint32_t TotalCount = 0;
642+
for (const auto &[_, V] : CSData)
643+
TotalCount += V.getEntrycount();
644+
assert(TotalCount >= DirectCount);
645+
uint32_t IndirectCount = TotalCount - DirectCount;
646+
// The ICP's effect is as-if the direct BB would have been taken DirectCount
647+
// times, and the indirect BB, IndirectCount times
648+
Ctx.counters()[DirectID] = DirectCount;
649+
Ctx.counters()[IndirectID] = IndirectCount;
650+
651+
// This particular indirect target needs to be moved to this caller under
652+
// the newly-allocated callsite index.
653+
assert(Ctx.callsites().count(NewCSID) == 0);
654+
Ctx.ingestContext(NewCSID, std::move(It->second));
655+
CSData.erase(CalleeGUID);
656+
};
657+
CtxProf.update(ProfileUpdater, &Caller);
658+
return &DirectCall;
659+
}
660+
575661
CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
576662
Function *Callee,
577663
ArrayRef<Constant *> AddressPoints,

0 commit comments

Comments
 (0)