Skip to content

[ctx_prof] Add Inlining support #106154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions llvm/include/llvm/Analysis/CtxProfAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PassManager.h"
#include "llvm/ProfileData/PGOCtxProfReader.h"
#include <optional>

namespace llvm {

Expand Down Expand Up @@ -63,6 +64,16 @@ class PGOContextualProfile {
return getDefinedFunctionGUID(F) != 0;
}

/// Returns the current NextCounterIndex recorded for \p F, i.e. the index the
/// next call to allocateNextCounterIndex(F) would hand out.
/// \p F must be known to this profile (asserted via isFunctionKnown).
uint32_t getNumCounters(const Function &F) const {
  assert(isFunctionKnown(F));
  const auto Guid = getDefinedFunctionGUID(F);
  const auto &Info = FuncInfo.find(Guid)->second;
  return Info.NextCounterIndex;
}

/// Returns the current NextCallsiteIndex recorded for \p F.
/// \p F must be known to this profile (asserted via isFunctionKnown).
uint32_t getNumCallsites(const Function &F) const {
  assert(isFunctionKnown(F));
  const auto Guid = getDefinedFunctionGUID(F);
  const auto &Info = FuncInfo.find(Guid)->second;
  return Info.NextCallsiteIndex;
}

uint32_t allocateNextCounterIndex(const Function &F) {
assert(isFunctionKnown(F));
return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCounterIndex++;
Expand Down Expand Up @@ -91,11 +102,11 @@ class PGOContextualProfile {
};

class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {
StringRef Profile;
const std::optional<StringRef> Profile;

public:
static AnalysisKey Key;
explicit CtxProfAnalysis(StringRef Profile = "");
explicit CtxProfAnalysis(std::optional<StringRef> Profile = std::nullopt);

using Result = PGOContextualProfile;

Expand All @@ -113,9 +124,7 @@ class CtxProfAnalysisPrinterPass
: public PassInfoMixin<CtxProfAnalysisPrinterPass> {
public:
enum class PrintMode { Everything, JSON };
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS,
PrintMode Mode = PrintMode::Everything)
: OS(OS), Mode(Mode) {}
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS);

PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
static bool isRequired() { return true; }
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/IR/IntrinsicInst.h
Original file line number Diff line number Diff line change
Expand Up @@ -1516,6 +1516,8 @@ class InstrProfInstBase : public IntrinsicInst {
return const_cast<Value *>(getArgOperand(0))->stripPointerCasts();
}

// Overwrites call argument 0 of this intrinsic call with \p V.
void setNameValue(Value *V) {
  constexpr unsigned NameArgIdx = 0;
  setArgOperand(NameArgIdx, V);
}

// The hash of the CFG for the instrumented function.
ConstantInt *getHash() const {
return cast<ConstantInt>(const_cast<Value *>(getArgOperand(1)));
Expand Down
7 changes: 7 additions & 0 deletions llvm/include/llvm/ProfileData/PGOCtxProfReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ class PGOCtxProfContext final {
Iter->second.emplace(Other.guid(), std::move(Other));
}

/// Moves the whole call-target map \p Other into this context under callsite
/// id \p CSId. The callsite id must not already have an entry; this is checked
/// with an assertion only.
void ingestAllContexts(uint32_t CSId, CallTargetMapTy &&Other) {
  [[maybe_unused]] const bool Inserted =
      callsites().try_emplace(CSId, std::move(Other)).second;
  assert(Inserted &&
         "CSId was expected to be newly created as result of e.g. inlining");
}

void resizeCounters(uint32_t Size) { Counters.resize(Size); }

bool hasCallsite(uint32_t I) const {
Expand Down
12 changes: 12 additions & 0 deletions llvm/include/llvm/Transforms/Utils/Cloning.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/ValueHandle.h"
Expand Down Expand Up @@ -270,6 +271,17 @@ InlineResult InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
bool InsertLifetime = true,
Function *ForwardVarArgsTo = nullptr);

/// Same as the InlineFunction overload above, but it will also update the
/// contextual profile passed as \p CtxProf. If the contextual profile is
/// invalid (i.e. not loaded because it is not present), it defaults to the
/// behavior of the non-contextual profile updating variant above. This makes
/// it easy to drop-in replace uses of the non-contextual overload: only the
/// extra \p CtxProf argument needs to be supplied.
InlineResult InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
CtxProfAnalysis::Result &CtxProf,
bool MergeAttributes = false,
AAResults *CalleeAAR = nullptr,
bool InsertLifetime = true,
Function *ForwardVarArgsTo = nullptr);

/// Clones a loop \p OrigLoop. Returns the loop and the blocks in \p
/// Blocks.
///
Expand Down
29 changes: 24 additions & 5 deletions llvm/lib/Analysis/CtxProfAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ cl::opt<std::string>
UseCtxProfile("use-ctx-profile", cl::init(""), cl::Hidden,
cl::desc("Use the specified contextual profile file"));

// Hidden flag -ctx-profile-printer-level: the verbosity of the contextual
// profile printer pass. Accepts "everything" or "json"; when the flag is not
// given, the value is PrintMode::JSON.
static cl::opt<CtxProfAnalysisPrinterPass::PrintMode> PrintLevel(
"ctx-profile-printer-level",
cl::init(CtxProfAnalysisPrinterPass::PrintMode::JSON), cl::Hidden,
cl::values(clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::Everything,
"everything", "print everything - most verbose"),
clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::JSON, "json",
"just the json representation of the profile")),
cl::desc("Verbosity level of the contextual profile printer pass."));

namespace llvm {
namespace json {
Value toJSON(const PGOCtxProfContext &P) {
Expand Down Expand Up @@ -96,12 +105,20 @@ GlobalValue::GUID AssignGUIDPass::getGUID(const Function &F) {
}
AnalysisKey CtxProfAnalysis::Key;

CtxProfAnalysis::CtxProfAnalysis(StringRef Profile)
: Profile(Profile.empty() ? UseCtxProfile : Profile) {}
// Chooses the profile path for this analysis:
//  1. an explicitly supplied \p Profile always wins;
//  2. otherwise the -use-ctx-profile flag is used, but only if it was actually
//     passed on the command line;
//  3. otherwise no profile is set (std::nullopt).
CtxProfAnalysis::CtxProfAnalysis(std::optional<StringRef> Profile)
    : Profile([&]() -> std::optional<StringRef> {
        // Inside this initializer `Profile` names the constructor parameter,
        // not the member being initialized.
        if (Profile.has_value())
          return Profile;
        const bool FlagGiven = UseCtxProfile.getNumOccurrences() != 0;
        if (!FlagGiven)
          return std::nullopt;
        return StringRef(UseCtxProfile);
      }()) {}

PGOContextualProfile CtxProfAnalysis::run(Module &M,
ModuleAnalysisManager &MAM) {
ErrorOr<std::unique_ptr<MemoryBuffer>> MB = MemoryBuffer::getFile(Profile);
if (!Profile)
return {};
ErrorOr<std::unique_ptr<MemoryBuffer>> MB = MemoryBuffer::getFile(*Profile);
if (auto EC = MB.getError()) {
M.getContext().emitError("could not open contextual profile file: " +
EC.message());
Expand Down Expand Up @@ -150,7 +167,6 @@ PGOContextualProfile CtxProfAnalysis::run(Module &M,
// If we made it this far, the Result is valid - which we mark by setting
// .Profiles.
// First, trim the roots that aren't in this module.
DenseSet<GlobalValue::GUID> ProfiledGUIDs;
for (auto &[RootGuid, _] : llvm::make_early_inc_range(*MaybeCtx))
if (!Result.FuncInfo.contains(RootGuid))
MaybeCtx->erase(RootGuid);
Expand All @@ -165,11 +181,14 @@ PGOContextualProfile::getDefinedFunctionGUID(const Function &F) const {
return 0;
}

// Binds the pass to the output stream \p OS. The print mode is not a
// constructor argument: it is read from the -ctx-profile-printer-level option
// (PrintLevel) at construction time.
CtxProfAnalysisPrinterPass::CtxProfAnalysisPrinterPass(raw_ostream &OS)
    : OS(OS), Mode(PrintLevel.getValue()) {}

PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
ModuleAnalysisManager &MAM) {
CtxProfAnalysis::Result &C = MAM.getResult<CtxProfAnalysis>(M);
if (!C) {
M.getContext().emitError("Invalid CtxProfAnalysis");
OS << "No contextual profile was provided.\n";
return PreservedAnalyses::all();
}

Expand Down
5 changes: 4 additions & 1 deletion llvm/lib/Transforms/IPO/ModuleInliner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/Analysis/InlineAdvisor.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/InlineOrder.h"
Expand Down Expand Up @@ -113,6 +114,8 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M,
return PreservedAnalyses::all();
}

auto &CtxProf = MAM.getResult<CtxProfAnalysis>(M);

bool Changed = false;

ProfileSummaryInfo *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(M);
Expand Down Expand Up @@ -213,7 +216,7 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M,
&FAM.getResult<BlockFrequencyAnalysis>(Callee));

InlineResult IR =
InlineFunction(*CB, IFI, /*MergeAttributes=*/true,
InlineFunction(*CB, IFI, CtxProf, /*MergeAttributes=*/true,
&FAM.getResult<AAManager>(*CB->getCaller()));
if (!IR.isSuccess()) {
Advice->recordUnsuccessfulInlining(IR);
Expand Down
Loading
Loading