Skip to content

[llvm][IR] Extend BranchWeightMetadata to track provenance of weights #86609

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -221,5 +221,5 @@ void tu2(int &i) {
}
}

// CHECK: [[BW_LIKELY]] = !{!"branch_weights", i32 2000, i32 1}
// CHECK: [[BW_UNLIKELY]] = !{!"branch_weights", i32 1, i32 2000}
// CHECK: [[BW_LIKELY]] = !{!"branch_weights", !"expected", i32 2000, i32 1}
// CHECK: [[BW_UNLIKELY]] = !{!"branch_weights", !"expected", i32 1, i32 2000}
7 changes: 7 additions & 0 deletions llvm/docs/BranchWeightMetadata.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@ Supported Instructions

Metadata is only assigned to the conditional branches. There are two extra
operands for the true and the false branch.
We optionally track if the metadata was added by ``__builtin_expect`` or
``__builtin_expect_with_probability`` with an optional field ``!"expected"``.

.. code-block:: none

!0 = !{
!"branch_weights",
[ !"expected", ]
i32 <TRUE_BRANCH_WEIGHT>,
i32 <FALSE_BRANCH_WEIGHT>
}
Expand All @@ -47,6 +50,7 @@ is always case #0).

!0 = !{
!"branch_weights",
[ !"expected", ]
i32 <DEFAULT_BRANCH_WEIGHT>
[ , i32 <CASE_BRANCH_WEIGHT> ... ]
}
Expand All @@ -60,6 +64,7 @@ Branch weights are assigned to every destination.

!0 = !{
!"branch_weights",
[ !"expected", ]
i32 <LABEL_BRANCH_WEIGHT>
[ , i32 <LABEL_BRANCH_WEIGHT> ... ]
}
Expand All @@ -75,6 +80,7 @@ block and entry counts which may not be accurate with sampling.

!0 = !{
!"branch_weights",
[ !"expected", ]
i32 <CALL_BRANCH_WEIGHT>
}

Expand All @@ -95,6 +101,7 @@ is used.

!0 = !{
!"branch_weights",
[ !"expected", ]
i32 <INVOKE_NORMAL_WEIGHT>
[ , i32 <INVOKE_UNWIND_WEIGHT> ]
}
Expand Down
11 changes: 9 additions & 2 deletions llvm/include/llvm/IR/MDBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,17 @@ class MDBuilder {
//===------------------------------------------------------------------===//

/// Return metadata containing two branch weights.
MDNode *createBranchWeights(uint32_t TrueWeight, uint32_t FalseWeight);
/// @param TrueWeight the weight of the true branch
/// @param FalseWeight the weight of the false branch
/// @param Do these weights come from __builtin_expect*
MDNode *createBranchWeights(uint32_t TrueWeight, uint32_t FalseWeight,
bool IsExpected = false);

/// Return metadata containing a number of branch weights.
MDNode *createBranchWeights(ArrayRef<uint32_t> Weights);
/// @param Weights the weights of all the branches
/// @param Do these weights come from __builtin_expect*
MDNode *createBranchWeights(ArrayRef<uint32_t> Weights,
bool IsExpected = false);

/// Return metadata specifying that a branch or switch is unpredictable.
MDNode *createUnpredictable();
Expand Down
31 changes: 27 additions & 4 deletions llvm/include/llvm/IR/ProfDataUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@ MDNode *getBranchWeightMDNode(const Instruction &I);
/// Nullptr otherwise.
MDNode *getValidBranchWeightMDNode(const Instruction &I);

/// Check if Branch Weight Metadata has an "expected" field from an llvm.expect*
/// intrinsic
bool hasExpectedProvenance(const Instruction &I);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm maybe this should somehow have the name "Weight" in the name, to not introduce confusion with alias-analysis things (at least I immediately think "alias analysis" when I see the word "provenance") or find different term than "provenance"?


/// Check if Branch Weight Metadata has an "expected" field from an llvm.expect*
/// intrinsic
bool hasExpectedProvenance(const MDNode *ProfileData);
Copy link
Contributor

@MatzeB MatzeB Apr 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW: I usually prefer references like const MDNode &ProfileData to indicate that an argument mustn't be nullptr. Though admittedly that remark comes too late given we already have other const MDNode * APIs in this header and consistency is also worth something... So I guess I'm fine either way...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair points ... maybe we should consider that refactoring anyway? I can file a Github issue to track it, and try to set aside some time over the next couple of weeks to look into it.


/// Return the offset to the first branch weight data
unsigned getBranchWeightOffset(const Instruction &I);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this API really helpful? Don't you have to get your hands on an MDNode anyway to do something useful with that offset?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a fair point. IIRC I needed this someplace, but maybe I'm mis-remembering and that's a moot point. I'll look at my follow up patches to ensure it isn't required.


/// Return the offset to the first branch weight data
unsigned getBranchWeightOffset(const MDNode *ProfileData);

/// Extract branch weights from MD_prof metadata
///
/// \param ProfileData A pointer to an MDNode.
Expand All @@ -65,9 +79,14 @@ bool extractBranchWeights(const MDNode *ProfileData,
SmallVectorImpl<uint32_t> &Weights);

/// Faster version of extractBranchWeights() that skips checks and must only
/// be called with "branch_weights" metadata nodes.
void extractFromBranchWeightMD(const MDNode *ProfileData,
SmallVectorImpl<uint32_t> &Weights);
/// be called with "branch_weights" metadata nodes. Supports uint32_t.
void extractFromBranchWeightMD32(const MDNode *ProfileData,
SmallVectorImpl<uint32_t> &Weights);

/// Faster version of extractBranchWeights() that skips checks and must only
/// be called with "branch_weights" metadata nodes. Supports uint64_t.
void extractFromBranchWeightMD64(const MDNode *ProfileData,
SmallVectorImpl<uint64_t> &Weights);

/// Extract branch weights attatched to an Instruction
///
Expand Down Expand Up @@ -106,7 +125,11 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);

/// Create a new `branch_weights` metadata node and add or overwrite
/// a `prof` metadata reference to instruction `I`.
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);
/// \param I the Instruction to set branch weights on.
/// \param Weights an array of weights to set on instruction I.
/// \param IsExpected were these weights added from an llvm.expect* intrinsic.
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
bool IsExpected);

} // namespace llvm
#endif
3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/CodeGenPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8843,7 +8843,8 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, ModifyDT &ModifiedDT) {
scaleWeights(NewTrueWeight, NewFalseWeight);
Br1->setMetadata(LLVMContext::MD_prof,
MDBuilder(Br1->getContext())
.createBranchWeights(TrueWeight, FalseWeight));
.createBranchWeights(TrueWeight, FalseWeight,
hasExpectedProvenance(*Br1)));

NewTrueWeight = TrueWeight;
NewFalseWeight = 2 * FalseWeight;
Expand Down
18 changes: 14 additions & 4 deletions llvm/lib/IR/Instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1210,12 +1210,22 @@ Instruction *Instruction::cloneImpl() const {

void Instruction::swapProfMetadata() {
MDNode *ProfileData = getBranchWeightMDNode(*this);
if (!ProfileData || ProfileData->getNumOperands() != 3)
if (!isBranchWeightMD(ProfileData))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extract this and other similar refactoring change into a different patch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, that's a good point. In my head these were all tied together w/ the change to the metadata layout, but maybe I can restructure ProfdataUtils first, and then update the surrounding code, and after that's done introduce the metadata changes. Thanks for the suggestion. I'll take a pass at that soon.

return;

// The first operand is the name. Fetch them backwards and build a new one.
Metadata *Ops[] = {ProfileData->getOperand(0), ProfileData->getOperand(2),
ProfileData->getOperand(1)};
SmallVector<Metadata *, 4> Ops;
unsigned int FirstIdx = getBranchWeightOffset(ProfileData);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a strong opinion, but I think most people just write unsigned in LLVM codebase...

Suggested change
unsigned int FirstIdx = getBranchWeightOffset(ProfileData);
unsigned FirstIdx = getBranchWeightOffset(ProfileData);

unsigned int SecondIdx = FirstIdx + 1;
// If there are more weights past the second, we can't swap them
if (ProfileData->getNumOperands() > SecondIdx + 1)
return;
Ops.push_back(ProfileData->getOperand(0));
if (hasExpectedProvenance(ProfileData)) {
Ops.push_back(ProfileData->getOperand(1));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this (I just have a feeling that leaving it more generic may help in case new sources are added one day):

Suggested change
Ops.push_back(ProfileData->getOperand(0));
if (hasExpectedProvenance(ProfileData)) {
Ops.push_back(ProfileData->getOperand(1));
}
for (unsigned I = 0; I < FirstIdx; I++) {
Ops.push_back(ProfileData->getOperand(I));
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good suggestion. I'll update the patch to reflect.

// Switch the order of the weights
Ops.push_back(ProfileData->getOperand(SecondIdx));
Ops.push_back(ProfileData->getOperand(FirstIdx));
setMetadata(LLVMContext::MD_prof,
MDNode::get(ProfileData->getContext(), Ops));
}
Expand Down
14 changes: 10 additions & 4 deletions llvm/lib/IR/Instructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -857,10 +857,12 @@ void CallInst::updateProfWeight(uint64_t S, uint64_t T) {
APInt APS(128, S), APT(128, T);
if (ProfDataName->getString().equals("branch_weights") &&
ProfileData->getNumOperands() > 0) {
unsigned int Offset = getBranchWeightOffset(ProfileData);
// Using APInt::div may be expensive, but most cases should fit 64 bits.
APInt Val(128, mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1))
->getValue()
.getZExtValue());
APInt Val(128,
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Offset))
->getValue()
.getZExtValue());
Val *= APS;
Vals.push_back(MDB.createConstant(
ConstantInt::get(Type::getInt32Ty(getContext()),
Expand Down Expand Up @@ -5196,7 +5198,11 @@ void SwitchInstProfUpdateWrapper::init() {
if (!ProfileData)
return;

if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1) {
// FIXME: This check belongs in ProfDataUtils. Its almost equivalent to
// getValidBranchWeightMDNode(), but the need to use llvm_unreachable
// makes them slightly different.
if (ProfileData->getNumOperands() !=
SI.getNumSuccessors() + getBranchWeightOffset(ProfileData)) {
Comment on lines +5202 to +5206
Copy link
Contributor

@MatzeB MatzeB Apr 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems simple enough to do something about it instead of adding a FIXME? Could for example add a getNumBranchWeights(<profile_data>) API so this can become getNumBranchWeights(ProfileData) != SI.getNumSuccessors()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another good suggestion. Thank you.

llvm_unreachable("number of prof branch_weights metadata operands does "
"not correspond to number of succesors");
}
Expand Down
14 changes: 9 additions & 5 deletions llvm/lib/IR/MDBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,23 @@ MDNode *MDBuilder::createFPMath(float Accuracy) {
}

MDNode *MDBuilder::createBranchWeights(uint32_t TrueWeight,
uint32_t FalseWeight) {
return createBranchWeights({TrueWeight, FalseWeight});
uint32_t FalseWeight, bool IsExpected) {
return createBranchWeights({TrueWeight, FalseWeight}, IsExpected);
}

MDNode *MDBuilder::createBranchWeights(ArrayRef<uint32_t> Weights) {
MDNode *MDBuilder::createBranchWeights(ArrayRef<uint32_t> Weights,
bool IsExpected) {
assert(Weights.size() >= 1 && "Need at least one branch weights!");

SmallVector<Metadata *, 4> Vals(Weights.size() + 1);
unsigned int Offset = IsExpected ? 2 : 1;
SmallVector<Metadata *, 4> Vals(Weights.size() + Offset);
Vals[0] = createString("branch_weights");
if (IsExpected)
Vals[1] = createString("expected");

Type *Int32Ty = Type::getInt32Ty(Context);
for (unsigned i = 0, e = Weights.size(); i != e; ++i)
Vals[i + 1] = createConstant(ConstantInt::get(Int32Ty, Weights[i]));
Vals[i + Offset] = createConstant(ConstantInt::get(Int32Ty, Weights[i]));

return MDNode::get(Context, Vals);
}
Expand Down
82 changes: 60 additions & 22 deletions llvm/lib/IR/ProfDataUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ namespace {
// We maintain some constants here to ensure that we access the branch weights
// correctly, and can change the behavior in the future if the layout changes

// The index at which the weights vector starts
constexpr unsigned WeightsIdx = 1;

// the minimum number of operands for MD_prof nodes with branch weights
constexpr unsigned MinBWOps = 3;

Expand All @@ -65,6 +62,27 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
return ProfDataName->getString().equals(Name);
}

template <typename T,
typename = typename std::enable_if<std::is_arithmetic_v<T>>>
static void extractFromBranchWeightMD(const MDNode *ProfileData,
SmallVectorImpl<T> &Weights) {
assert(isBranchWeightMD(ProfileData) && "wrong metadata");

unsigned NOps = ProfileData->getNumOperands();
unsigned int WeightsIdx = getBranchWeightOffset(ProfileData);
assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
Weights.resize(NOps - WeightsIdx);

for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
ConstantInt *Weight =
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
assert(Weight && "Malformed branch_weight in MD_prof node");
assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
"Too many bits for MD_prof branch_weight");
Weights[Idx - WeightsIdx] = Weight->getZExtValue();
}
}

} // namespace

namespace llvm {
Expand All @@ -86,6 +104,30 @@ bool hasValidBranchWeightMD(const Instruction &I) {
return getValidBranchWeightMDNode(I);
}

bool hasExpectedProvenance(const Instruction &I) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
return hasExpectedProvenance(ProfileData);
}

bool hasExpectedProvenance(const MDNode *ProfileData) {
if (!isBranchWeightMD(ProfileData))
return false;

auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
if (!ProfDataName)
return false;
return ProfDataName->getString().equals("expected");
}

unsigned getBranchWeightOffset(const Instruction &I) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
return getBranchWeightOffset(ProfileData);
}

unsigned getBranchWeightOffset(const MDNode *ProfileData) {
return hasExpectedProvenance(ProfileData) ? 2 : 1;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about a hasBranchWeightProvenance() API instead that just checks whether there is a string? That way you would get the same effect today but can skip the string comparison (and maybe get better behavior if there is a string that isn't actually "expected")


MDNode *getBranchWeightMDNode(const Instruction &I) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
if (!isBranchWeightMD(ProfileData))
Expand All @@ -95,27 +137,21 @@ MDNode *getBranchWeightMDNode(const Instruction &I) {

MDNode *getValidBranchWeightMDNode(const Instruction &I) {
auto *ProfileData = getBranchWeightMDNode(I);
if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors())
auto Offset = getBranchWeightOffset(ProfileData);
if (ProfileData &&
ProfileData->getNumOperands() == Offset + I.getNumSuccessors())
return ProfileData;
return nullptr;
}

void extractFromBranchWeightMD(const MDNode *ProfileData,
SmallVectorImpl<uint32_t> &Weights) {
assert(isBranchWeightMD(ProfileData) && "wrong metadata");

unsigned NOps = ProfileData->getNumOperands();
assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
Weights.resize(NOps - WeightsIdx);
void extractFromBranchWeightMD32(const MDNode *ProfileData,
SmallVectorImpl<uint32_t> &Weights) {
extractFromBranchWeightMD(ProfileData, Weights);
}

for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
ConstantInt *Weight =
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
assert(Weight && "Malformed branch_weight in MD_prof node");
assert(Weight->getValue().getActiveBits() <= 32 &&
"Too many bits for uint32_t");
Weights[Idx - WeightsIdx] = Weight->getZExtValue();
}
void extractFromBranchWeightMD64(const MDNode *ProfileData,
SmallVectorImpl<uint64_t> &Weights) {
extractFromBranchWeightMD(ProfileData, Weights);
}

bool extractBranchWeights(const MDNode *ProfileData,
Expand Down Expand Up @@ -162,7 +198,8 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
return false;

if (ProfDataName->getString().equals("branch_weights")) {
for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) {
unsigned int Offset = getBranchWeightOffset(ProfileData);
for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
assert(V && "Malformed branch_weight in MD_prof node");
TotalVal += V->getValue().getZExtValue();
Expand All @@ -184,9 +221,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
}

void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
bool IsExpected) {
MDBuilder MDB(I.getContext());
MDNode *BranchWeights = MDB.createBranchWeights(Weights);
MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
}

Expand Down
9 changes: 6 additions & 3 deletions llvm/lib/IR/Verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/ModuleSlotTracker.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Statepoint.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
Expand Down Expand Up @@ -4756,8 +4757,10 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {

// Check consistency of !prof branch_weights metadata.
if (ProfName.equals("branch_weights")) {
unsigned int Offset = getBranchWeightOffset(I);
if (isa<InvokeInst>(&I)) {
Check(MD->getNumOperands() == 2 || MD->getNumOperands() == 3,
Check(MD->getNumOperands() == (1 + Offset) ||
MD->getNumOperands() == (2 + Offset),
Comment on lines +4814 to +4815
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More opportunities for a possible getNumBranchWeights(...) API...

"Wrong number of InvokeInst branch_weights operands", MD);
} else {
unsigned ExpectedNumOperands = 0;
Expand All @@ -4777,10 +4780,10 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
CheckFailed("!prof branch_weights are not allowed for this instruction",
MD);

Check(MD->getNumOperands() == 1 + ExpectedNumOperands,
Check(MD->getNumOperands() == Offset + ExpectedNumOperands,
"Wrong number of operands", MD);
}
for (unsigned i = 1; i < MD->getNumOperands(); ++i) {
for (unsigned i = Offset; i < MD->getNumOperands(); ++i) {
auto &MDO = MD->getOperand(i);
Check(MDO, "second operand should not be null", MD);
Check(mdconst::dyn_extract<ConstantInt>(MDO),
Expand Down
Loading