Skip to content

[ctx_prof] Add support for ICP #105469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions llvm/include/llvm/Analysis/CtxProfAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ class PGOContextualProfile {
return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex++;
}

/// Callback type for read-only traversal: receives each visited context by
/// const reference.
using ConstVisitor = function_ref<void(const PGOCtxProfContext &)>;
/// Callback type for mutating traversal: receives each visited context by
/// non-const reference.
using Visitor = function_ref<void(PGOCtxProfContext &)>;

/// Walk the profile's contexts in pre-order, handing each one to the visitor
/// for in-place modification. When \p F is non-null, only contexts whose GUID
/// equals that of \p F are handed to the visitor; sub-contexts are still
/// traversed either way.
void update(Visitor, const Function *F = nullptr);
/// Read-only counterpart of update(): same traversal and the same optional
/// filtering by \p F.
void visit(ConstVisitor, const Function *F = nullptr) const;

const CtxProfFlatProfile flatten() const;

bool invalidate(Module &, const PreservedAnalyses &PA,
Expand Down Expand Up @@ -105,13 +111,18 @@ class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {

class CtxProfAnalysisPrinterPass
: public PassInfoMixin<CtxProfAnalysisPrinterPass> {
raw_ostream &OS;

public:
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS) : OS(OS) {}
enum class PrintMode { Everything, JSON };
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS,
PrintMode Mode = PrintMode::Everything)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the default be the most verbose?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh! I think I understand what you mean. Yes, I think the default should be the most verbose, it's a testing facility.

: OS(OS), Mode(Mode) {}

PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
static bool isRequired() { return true; }

private:
raw_ostream &OS;
const PrintMode Mode;
};

/// Assign a GUID to functions as metadata. GUID calculation takes linkage into
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/IR/IntrinsicInst.h
Original file line number Diff line number Diff line change
Expand Up @@ -1535,6 +1535,7 @@ class InstrProfCntrInstBase : public InstrProfInstBase {
ConstantInt *getNumCounters() const;
// The index of the counter that this instruction acts on.
ConstantInt *getIndex() const;
// Point this instruction at counter \p Idx by replacing its index operand with
// a 32-bit integer constant.
void setIndex(uint32_t Idx);
};

/// This represents the llvm.instrprof.cover intrinsic.
Expand Down Expand Up @@ -1585,6 +1586,7 @@ class InstrProfCallsite : public InstrProfCntrInstBase {
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
}
Value *getCallee() const;
// Replace the callee operand of this intrinsic call with \p Callee.
void setCallee(Value *Callee);
};

/// This represents the llvm.instrprof.timestamp intrinsic.
Expand Down
22 changes: 22 additions & 0 deletions llvm/include/llvm/ProfileData/PGOCtxProfReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,25 @@ class PGOCtxProfContext final {

/// The GUID stored for this context.
GlobalValue::GUID guid() const { return GUID; }
/// Read-only view of this context's counter values.
const SmallVectorImpl<uint64_t> &counters() const { return Counters; }
/// Mutable access to the counter values; writes through the returned
/// reference modify this context directly.
SmallVectorImpl<uint64_t> &counters() { return Counters; }

uint64_t getEntrycount() const {
assert(!Counters.empty() &&
"Functions are expected to have at their entry BB instrumented, so "
"there should always be at least 1 counter.");
return Counters[0];
}

/// Read-only view of the callsite map, which associates a callsite index with
/// the call targets recorded for it.
const CallsiteMapTy &callsites() const { return Callsites; }
/// Mutable access to the callsite map.
CallsiteMapTy &callsites() { return Callsites; }

void ingestContext(uint32_t CSId, PGOCtxProfContext &&Other) {
auto [Iter, _] = callsites().try_emplace(CSId, CallTargetMapTy());
Iter->second.emplace(Other.guid(), std::move(Other));
}

/// Resize the counter vector to hold exactly \p Size entries.
void resizeCounters(uint32_t Size) { Counters.resize(Size); }

/// True iff the callsite map has an entry for callsite index \p I.
bool hasCallsite(uint32_t I) const {
  const auto Pos = Callsites.find(I);
  const bool Missing = (Pos == Callsites.end());
  return !Missing;
}
Expand All @@ -68,6 +84,12 @@ class PGOCtxProfContext final {
assert(hasCallsite(I) && "Callsite not found");
return Callsites.find(I)->second;
}

/// Mutable access to the call targets recorded for callsite index \p I.
/// The callsite must exist (asserted).
CallTargetMapTy &callsite(uint32_t I) {
  assert(hasCallsite(I) && "Callsite not found");
  auto Pos = Callsites.find(I);
  auto &Targets = Pos->second;
  return Targets;
}

void getContainedGuids(DenseSet<GlobalValue::GUID> &Guids) const;
};

Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
#define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H

#include "llvm/Analysis/CtxProfAnalysis.h"
namespace llvm {
template <typename T> class ArrayRef;
class Constant;
Expand Down Expand Up @@ -56,6 +57,9 @@ CallBase &promoteCall(CallBase &CB, Function *Callee,
CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
MDNode *BranchWeights = nullptr);

/// Contextual-profile-aware variant of `promoteCallWithIfThenElse`: versions
/// the indirect call \p CB against \p Callee, promotes the new copy to a direct
/// call, adds callsite and basic-block instrumentation for the two resulting
/// blocks, and updates the caller's contexts in \p CtxProf to match.
///
/// Returns the promoted direct call, or nullptr when \p Callee is not known to
/// \p CtxProf or \p CB has no callsite instrumentation. \p CB must be an
/// indirect call.
CallBase *promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
PGOContextualProfile &CtxProf);

/// This is similar to `promoteCallWithIfThenElse` except that the condition to
/// promote a virtual call is that \p VPtr is the same as any of \p
/// AddressPoints.
Expand Down
79 changes: 50 additions & 29 deletions llvm/lib/Analysis/CtxProfAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,22 @@ PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
return PreservedAnalyses::all();
}

OS << "Function Info:\n";
for (const auto &[Guid, FuncInfo] : C.FuncInfo)
OS << Guid << " : " << FuncInfo.Name
<< ". MaxCounterID: " << FuncInfo.NextCounterIndex
<< ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
if (Mode == PrintMode::Everything) {
OS << "Function Info:\n";
for (const auto &[Guid, FuncInfo] : C.FuncInfo)
OS << Guid << " : " << FuncInfo.Name
<< ". MaxCounterID: " << FuncInfo.NextCounterIndex
<< ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
}

const auto JSONed = ::llvm::json::toJSON(C.profiles());

OS << "\nCurrent Profile:\n";
if (Mode == PrintMode::Everything)
OS << "\nCurrent Profile:\n";
OS << formatv("{0:2}", JSONed);
if (Mode == PrintMode::JSON)
return PreservedAnalyses::all();

OS << "\n";
OS << "\nFlat Profile:\n";
auto Flat = C.flatten();
Expand All @@ -209,34 +215,49 @@ InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) {
return nullptr;
}

static void
preorderVisit(const PGOCtxProfContext::CallTargetMapTy &Profiles,
function_ref<void(const PGOCtxProfContext &)> Visitor) {
std::function<void(const PGOCtxProfContext &)> Traverser =
[&](const auto &Ctx) {
Visitor(Ctx);
for (const auto &[_, SubCtxSet] : Ctx.callsites())
for (const auto &[__, Subctx] : SubCtxSet)
Traverser(Subctx);
};
for (const auto &[_, P] : Profiles)
/// Pre-order walk over a forest of contexts.
///
/// \p Profiles is the map of root contexts. Each context is offered to
/// \p Visitor before its sub-contexts are walked, so the visitor may change
/// the callsites of the context it is handed (when ProfTy is non-const).
/// A non-zero \p Match restricts the visitor to contexts with that GUID; the
/// walk itself still descends through every context.
template <class ProfilesTy, class ProfTy>
static void preorderVisit(ProfilesTy &Profiles,
                          function_ref<void(ProfTy &)> Visitor,
                          GlobalValue::GUID Match = 0) {
  const bool VisitAll = !Match;
  std::function<void(ProfTy &)> Walk = [&](ProfTy &Node) {
    if (VisitAll || Node.guid() == Match)
      Visitor(Node);
    for (auto &CallsiteEntry : Node.callsites())
      for (auto &TargetEntry : CallsiteEntry.second)
        Walk(TargetEntry.second);
  };
  for (auto &RootEntry : Profiles)
    Walk(RootEntry.second);
}

/// Visits the contexts of the loaded profile in pre-order and lets \p V modify
/// them in place. With a non-null \p F only the contexts whose GUID matches
/// that of \p F are passed to \p V; with a null \p F every context is.
///
/// A profile must be present: like flatten(), this dereferences Profiles.
void PGOContextualProfile::update(Visitor V, const Function *F) {
  assert(Profiles.has_value() && "update() requires a loaded profile");
  // A GUID of 0 means "no filter" to preorderVisit.
  const GlobalValue::GUID G = F ? getDefinedFunctionGUID(*F) : 0U;
  preorderVisit<PGOCtxProfContext::CallTargetMapTy, PGOCtxProfContext>(
      *Profiles, V, G);
}

/// Visits the contexts of the loaded profile in pre-order without modifying
/// them. With a non-null \p F only the contexts whose GUID matches that of
/// \p F are passed to \p V; with a null \p F every context is.
///
/// A profile must be present: like flatten(), this dereferences Profiles.
void PGOContextualProfile::visit(ConstVisitor V, const Function *F) const {
  assert(Profiles.has_value() && "visit() requires a loaded profile");
  // A GUID of 0 means "no filter" to preorderVisit.
  const GlobalValue::GUID G = F ? getDefinedFunctionGUID(*F) : 0U;
  preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
                const PGOCtxProfContext>(*Profiles, V, G);
}

const CtxProfFlatProfile PGOContextualProfile::flatten() const {
assert(Profiles.has_value());
CtxProfFlatProfile Flat;
preorderVisit(*Profiles, [&](const PGOCtxProfContext &Ctx) {
auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
if (Ins) {
llvm::append_range(It->second, Ctx.counters());
return;
}
assert(It->second.size() == Ctx.counters().size() &&
"All contexts corresponding to a function should have the exact "
"same number of counters.");
for (size_t I = 0, E = It->second.size(); I < E; ++I)
It->second[I] += Ctx.counters()[I];
});
preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
const PGOCtxProfContext>(
*Profiles, [&](const PGOCtxProfContext &Ctx) {
auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
if (Ins) {
llvm::append_range(It->second, Ctx.counters());
return;
}
assert(It->second.size() == Ctx.counters().size() &&
"All contexts corresponding to a function should have the exact "
"same number of counters.");
for (size_t I = 0, E = It->second.size(); I < E; ++I)
It->second[I] += Ctx.counters()[I];
});
return Flat;
}
10 changes: 10 additions & 0 deletions llvm/lib/IR/IntrinsicInst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ ConstantInt *InstrProfCntrInstBase::getIndex() const {
return cast<ConstantInt>(const_cast<Value *>(getArgOperand(3)));
}

// Overwrites the index operand (argument 3, the one getIndex() reads) with an
// i32 constant holding Idx.
void InstrProfCntrInstBase::setIndex(uint32_t Idx) {
  assert(isa<InstrProfCntrInstBase>(this));
  auto *Int32Ty = Type::getInt32Ty(getContext());
  auto *NewIndex = ConstantInt::get(Int32Ty, Idx);
  setArgOperand(3, NewIndex);
}

Value *InstrProfIncrementInst::getStep() const {
if (InstrProfIncrementInstStep::classof(this)) {
return const_cast<Value *>(getArgOperand(4));
Expand All @@ -300,6 +305,11 @@ Value *InstrProfCallsite::getCallee() const {
return nullptr;
}

// Overwrites the callee operand (argument 4) of this callsite intrinsic.
void InstrProfCallsite::setCallee(Value *Callee) {
  assert(isa<InstrProfCallsite>(this));
  constexpr unsigned CalleeArgNo = 4;
  setArgOperand(CalleeArgNo, Callee);
}

std::optional<RoundingMode> ConstrainedFPIntrinsic::getRoundingMode() const {
unsigned NumOperands = arg_size();
Metadata *MD = nullptr;
Expand Down
86 changes: 85 additions & 1 deletion llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/ProfileData/PGOCtxProfReader.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
Expand Down Expand Up @@ -572,6 +574,88 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
return promoteCall(NewInst, Callee);
}

CallBase *llvm::promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
PGOContextualProfile &CtxProf) {
assert(CB.isIndirectCall());
if (!CtxProf.isFunctionKnown(Callee))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be useful to have some statistics on how many calls were promoted / dropped?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems to belong in the pass that will exercise this, rather.

return nullptr;
auto &Caller = *CB.getFunction();
auto *CSInstr = CtxProfAnalysis::getCallsiteInstrumentation(CB);
if (!CSInstr)
return nullptr;
const uint64_t CSIndex = CSInstr->getIndex()->getZExtValue();

CallBase &DirectCall = promoteCall(
versionCallSite(CB, &Callee, /*BranchWeights=*/nullptr), &Callee);
CSInstr->moveBefore(&CB);
const auto NewCSID = CtxProf.allocateNextCallsiteIndex(Caller);
auto *NewCSInstr = cast<InstrProfCallsite>(CSInstr->clone());
NewCSInstr->setIndex(NewCSID);
NewCSInstr->setCallee(&Callee);
NewCSInstr->insertBefore(&DirectCall);
auto &DirectBB = *DirectCall.getParent();
auto &IndirectBB = *CB.getParent();

// Neither of the two blocks is expected to carry counter instrumentation yet.
// Each message names the block its condition actually checks.
assert((CtxProfAnalysis::getBBInstrumentation(IndirectBB) == nullptr) &&
       "The ICP indirect BB is new, it shouldn't have instrumentation");
assert((CtxProfAnalysis::getBBInstrumentation(DirectBB) == nullptr) &&
       "The ICP direct BB is new, it shouldn't have instrumentation");

// Allocate counters for the new basic blocks. The direct block's counter is
// allocated first, then the indirect block's.
const uint32_t DirectID = CtxProf.allocateNextCounterIndex(Caller);
const uint32_t IndirectID = CtxProf.allocateNextCounterIndex(Caller);
// The caller's entry-block instrumentation serves as the template that is
// cloned for both new blocks; it is dereferenced below, so it must exist.
auto *EntryBBIns =
    CtxProfAnalysis::getBBInstrumentation(Caller.getEntryBlock());
assert(EntryBBIns &&
       "The caller's entry BB is expected to be instrumented");
auto *DirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
DirectBBIns->setIndex(DirectID);
DirectBBIns->insertInto(&DirectBB, DirectBB.getFirstInsertionPt());

auto *IndirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
IndirectBBIns->setIndex(IndirectID);
IndirectBBIns->insertInto(&IndirectBB, IndirectBB.getFirstInsertionPt());

const GlobalValue::GUID CalleeGUID = AssignGUIDPass::getGUID(Callee);
const uint32_t NewCountersSize = IndirectID + 1;

// Rewrites one context of the caller so that it reflects the promotion:
// grows the counter vector by the two new counters, fills them in from the
// observed call targets, and re-homes the promoted target's sub-context under
// the newly allocated callsite index.
auto ProfileUpdater = [&](PGOCtxProfContext &Ctx) {
  assert(Ctx.guid() == AssignGUIDPass::getGUID(Caller));
  assert(NewCountersSize - 2 == Ctx.counters().size());
  // All the ctx-es belonging to a function must have the same size counters.
  Ctx.resizeCounters(NewCountersSize);

  // Maybe in this context, the indirect callsite wasn't observed at all.
  if (!Ctx.hasCallsite(CSIndex))
    return;
  auto &CSData = Ctx.callsite(CSIndex);
  auto It = CSData.find(CalleeGUID);

  // Maybe we did notice the indirect callsite, but to other targets.
  if (It == CSData.end())
    return;

  assert(CalleeGUID == It->second.guid());

  // Entry counts are 64-bit values. Keep the arithmetic in 64 bits so that
  // large counts are not truncated and the sum cannot wrap at 2^32.
  const uint64_t DirectCount = It->second.getEntrycount();
  uint64_t TotalCount = 0;
  for (const auto &[_, V] : CSData)
    TotalCount += V.getEntrycount();
  assert(TotalCount >= DirectCount);
  const uint64_t IndirectCount = TotalCount - DirectCount;
  // The ICP's effect is as-if the direct BB would have been taken DirectCount
  // times, and the indirect BB, IndirectCount times.
  Ctx.counters()[DirectID] = DirectCount;
  Ctx.counters()[IndirectID] = IndirectCount;

  // This particular indirect target needs to be moved to this caller under
  // the newly-allocated callsite index.
  assert(Ctx.callsites().count(NewCSID) == 0);
  Ctx.ingestContext(NewCSID, std::move(It->second));
  CSData.erase(CalleeGUID);
};
CtxProf.update(ProfileUpdater, &Caller);
return &DirectCall;
}

CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
Function *Callee,
ArrayRef<Constant *> AddressPoints,
Expand Down
Loading
Loading