-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[llvm][IR] Extend BranchWeightMetadata to track provenance of weights #86609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
1e9dbac
ffb5bb8
ee503bf
7217003
7b541e1
e5bd278
af27efe
bc338aa
5a5ae01
6e16264
56c4658
319aaa6
37af7e2
21382a0
6729b47
4749bdc
7760282
947f9e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,6 +55,20 @@ MDNode *getBranchWeightMDNode(const Instruction &I); | |
/// Nullptr otherwise. | ||
MDNode *getValidBranchWeightMDNode(const Instruction &I); | ||
|
||
/// Check if Branch Weight Metadata has an "expected" field from an llvm.expect* | ||
/// intrinsic | ||
bool hasExpectedProvenance(const Instruction &I); | ||
|
||
/// Check if Branch Weight Metadata has an "expected" field from an llvm.expect* | ||
/// intrinsic | ||
bool hasExpectedProvenance(const MDNode *ProfileData); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FWIW: I usually prefer references like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair points ... maybe we should consider that refactoring anyway? I can file a Github issue to track it, and try to set aside some time over the next couple of weeks to look into it. |
||
|
||
/// Return the offset to the first branch weight data | ||
unsigned getBranchWeightOffset(const Instruction &I); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this API really helpful? Don't you have to get your hands on an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a fair point. IIRC I needed this someplace, but maybe I'm mis-remembering and that's a moot point. I'll look at my follow up patches to ensure it isn't required. |
||
|
||
/// Return the offset to the first branch weight data | ||
unsigned getBranchWeightOffset(const MDNode *ProfileData); | ||
|
||
/// Extract branch weights from MD_prof metadata | ||
/// | ||
/// \param ProfileData A pointer to an MDNode. | ||
|
@@ -65,9 +79,14 @@ bool extractBranchWeights(const MDNode *ProfileData, | |
SmallVectorImpl<uint32_t> &Weights); | ||
|
||
/// Faster version of extractBranchWeights() that skips checks and must only | ||
/// be called with "branch_weights" metadata nodes. | ||
void extractFromBranchWeightMD(const MDNode *ProfileData, | ||
SmallVectorImpl<uint32_t> &Weights); | ||
/// be called with "branch_weights" metadata nodes. Supports uint32_t. | ||
void extractFromBranchWeightMD32(const MDNode *ProfileData, | ||
SmallVectorImpl<uint32_t> &Weights); | ||
|
||
/// Faster version of extractBranchWeights() that skips checks and must only | ||
/// be called with "branch_weights" metadata nodes. Supports uint64_t. | ||
void extractFromBranchWeightMD64(const MDNode *ProfileData, | ||
SmallVectorImpl<uint64_t> &Weights); | ||
|
||
/// Extract branch weights attatched to an Instruction | ||
/// | ||
|
@@ -106,7 +125,11 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights); | |
|
||
/// Create a new `branch_weights` metadata node and add or overwrite | ||
/// a `prof` metadata reference to instruction `I`. | ||
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights); | ||
/// \param I the Instruction to set branch weights on. | ||
/// \param Weights an array of weights to set on instruction I. | ||
/// \param IsExpected were these weights added from an llvm.expect* intrinsic. | ||
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights, | ||
bool IsExpected); | ||
|
||
} // namespace llvm | ||
#endif |
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -1210,12 +1210,22 @@ Instruction *Instruction::cloneImpl() const { | |||||||||||||||
|
||||||||||||||||
void Instruction::swapProfMetadata() { | ||||||||||||||||
MDNode *ProfileData = getBranchWeightMDNode(*this); | ||||||||||||||||
if (!ProfileData || ProfileData->getNumOperands() != 3) | ||||||||||||||||
if (!isBranchWeightMD(ProfileData)) | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Extract this and other similar refactoring change into a different patch? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, that's a good point. In my head these were all tied together w/ the change to the metadata layout, but maybe I can restructure ProfdataUtils first, and then update the surrounding code, and after that's done introduce the metadata changes. Thanks for the suggestion. I'll take a pass at that soon. |
||||||||||||||||
return; | ||||||||||||||||
|
||||||||||||||||
// The first operand is the name. Fetch them backwards and build a new one. | ||||||||||||||||
Metadata *Ops[] = {ProfileData->getOperand(0), ProfileData->getOperand(2), | ||||||||||||||||
ProfileData->getOperand(1)}; | ||||||||||||||||
SmallVector<Metadata *, 4> Ops; | ||||||||||||||||
unsigned int FirstIdx = getBranchWeightOffset(ProfileData); | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not a strong opinion, but I think most people just write
Suggested change
|
||||||||||||||||
unsigned int SecondIdx = FirstIdx + 1; | ||||||||||||||||
// If there are more weights past the second, we can't swap them | ||||||||||||||||
if (ProfileData->getNumOperands() > SecondIdx + 1) | ||||||||||||||||
return; | ||||||||||||||||
Ops.push_back(ProfileData->getOperand(0)); | ||||||||||||||||
if (hasExpectedProvenance(ProfileData)) { | ||||||||||||||||
Ops.push_back(ProfileData->getOperand(1)); | ||||||||||||||||
} | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe this (I just have a feeling that leaving it more generic may help in case new sources are added one day):
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a good suggestion. I'll update the patch to reflect. |
||||||||||||||||
// Switch the order of the weights | ||||||||||||||||
Ops.push_back(ProfileData->getOperand(SecondIdx)); | ||||||||||||||||
Ops.push_back(ProfileData->getOperand(FirstIdx)); | ||||||||||||||||
setMetadata(LLVMContext::MD_prof, | ||||||||||||||||
MDNode::get(ProfileData->getContext(), Ops)); | ||||||||||||||||
} | ||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -857,10 +857,12 @@ void CallInst::updateProfWeight(uint64_t S, uint64_t T) { | |
APInt APS(128, S), APT(128, T); | ||
if (ProfDataName->getString().equals("branch_weights") && | ||
ProfileData->getNumOperands() > 0) { | ||
unsigned int Offset = getBranchWeightOffset(ProfileData); | ||
// Using APInt::div may be expensive, but most cases should fit 64 bits. | ||
APInt Val(128, mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1)) | ||
->getValue() | ||
.getZExtValue()); | ||
APInt Val(128, | ||
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Offset)) | ||
->getValue() | ||
.getZExtValue()); | ||
Val *= APS; | ||
Vals.push_back(MDB.createConstant( | ||
ConstantInt::get(Type::getInt32Ty(getContext()), | ||
|
@@ -5196,7 +5198,11 @@ void SwitchInstProfUpdateWrapper::init() { | |
if (!ProfileData) | ||
return; | ||
|
||
if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1) { | ||
// FIXME: This check belongs in ProfDataUtils. Its almost equivalent to | ||
// getValidBranchWeightMDNode(), but the need to use llvm_unreachable | ||
// makes them slightly different. | ||
if (ProfileData->getNumOperands() != | ||
SI.getNumSuccessors() + getBranchWeightOffset(ProfileData)) { | ||
Comment on lines
+5202
to
+5206
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems simple enough to do something about it instead of adding a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another good suggestion. Thank you. |
||
llvm_unreachable("number of prof branch_weights metadata operands does " | ||
"not correspond to number of succesors"); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,9 +40,6 @@ namespace { | |
// We maintain some constants here to ensure that we access the branch weights | ||
// correctly, and can change the behavior in the future if the layout changes | ||
|
||
// The index at which the weights vector starts | ||
constexpr unsigned WeightsIdx = 1; | ||
|
||
// the minimum number of operands for MD_prof nodes with branch weights | ||
constexpr unsigned MinBWOps = 3; | ||
|
||
|
@@ -65,6 +62,27 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { | |
return ProfDataName->getString().equals(Name); | ||
} | ||
|
||
template <typename T, | ||
typename = typename std::enable_if<std::is_arithmetic_v<T>>> | ||
static void extractFromBranchWeightMD(const MDNode *ProfileData, | ||
SmallVectorImpl<T> &Weights) { | ||
assert(isBranchWeightMD(ProfileData) && "wrong metadata"); | ||
|
||
unsigned NOps = ProfileData->getNumOperands(); | ||
unsigned int WeightsIdx = getBranchWeightOffset(ProfileData); | ||
assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); | ||
Weights.resize(NOps - WeightsIdx); | ||
|
||
for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { | ||
ConstantInt *Weight = | ||
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); | ||
assert(Weight && "Malformed branch_weight in MD_prof node"); | ||
assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) && | ||
"Too many bits for MD_prof branch_weight"); | ||
Weights[Idx - WeightsIdx] = Weight->getZExtValue(); | ||
} | ||
} | ||
|
||
} // namespace | ||
|
||
namespace llvm { | ||
|
@@ -86,6 +104,30 @@ bool hasValidBranchWeightMD(const Instruction &I) { | |
return getValidBranchWeightMDNode(I); | ||
} | ||
|
||
bool hasExpectedProvenance(const Instruction &I) { | ||
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); | ||
return hasExpectedProvenance(ProfileData); | ||
} | ||
|
||
bool hasExpectedProvenance(const MDNode *ProfileData) { | ||
if (!isBranchWeightMD(ProfileData)) | ||
return false; | ||
|
||
auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1)); | ||
if (!ProfDataName) | ||
return false; | ||
return ProfDataName->getString().equals("expected"); | ||
} | ||
|
||
unsigned getBranchWeightOffset(const Instruction &I) { | ||
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); | ||
return getBranchWeightOffset(ProfileData); | ||
} | ||
|
||
unsigned getBranchWeightOffset(const MDNode *ProfileData) { | ||
return hasExpectedProvenance(ProfileData) ? 2 : 1; | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about a |
||
|
||
MDNode *getBranchWeightMDNode(const Instruction &I) { | ||
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); | ||
if (!isBranchWeightMD(ProfileData)) | ||
|
@@ -95,27 +137,21 @@ MDNode *getBranchWeightMDNode(const Instruction &I) { | |
|
||
MDNode *getValidBranchWeightMDNode(const Instruction &I) { | ||
auto *ProfileData = getBranchWeightMDNode(I); | ||
if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors()) | ||
auto Offset = getBranchWeightOffset(ProfileData); | ||
if (ProfileData && | ||
ProfileData->getNumOperands() == Offset + I.getNumSuccessors()) | ||
return ProfileData; | ||
return nullptr; | ||
} | ||
|
||
void extractFromBranchWeightMD(const MDNode *ProfileData, | ||
SmallVectorImpl<uint32_t> &Weights) { | ||
assert(isBranchWeightMD(ProfileData) && "wrong metadata"); | ||
|
||
unsigned NOps = ProfileData->getNumOperands(); | ||
assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); | ||
Weights.resize(NOps - WeightsIdx); | ||
void extractFromBranchWeightMD32(const MDNode *ProfileData, | ||
SmallVectorImpl<uint32_t> &Weights) { | ||
extractFromBranchWeightMD(ProfileData, Weights); | ||
} | ||
|
||
for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { | ||
ConstantInt *Weight = | ||
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); | ||
assert(Weight && "Malformed branch_weight in MD_prof node"); | ||
assert(Weight->getValue().getActiveBits() <= 32 && | ||
"Too many bits for uint32_t"); | ||
Weights[Idx - WeightsIdx] = Weight->getZExtValue(); | ||
} | ||
void extractFromBranchWeightMD64(const MDNode *ProfileData, | ||
SmallVectorImpl<uint64_t> &Weights) { | ||
extractFromBranchWeightMD(ProfileData, Weights); | ||
} | ||
|
||
bool extractBranchWeights(const MDNode *ProfileData, | ||
|
@@ -162,7 +198,8 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { | |
return false; | ||
|
||
if (ProfDataName->getString().equals("branch_weights")) { | ||
for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) { | ||
unsigned int Offset = getBranchWeightOffset(ProfileData); | ||
for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) { | ||
auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); | ||
assert(V && "Malformed branch_weight in MD_prof node"); | ||
TotalVal += V->getValue().getZExtValue(); | ||
|
@@ -184,9 +221,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { | |
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); | ||
} | ||
|
||
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) { | ||
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights, | ||
bool IsExpected) { | ||
MDBuilder MDB(I.getContext()); | ||
MDNode *BranchWeights = MDB.createBranchWeights(Weights); | ||
MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected); | ||
I.setMetadata(LLVMContext::MD_prof, BranchWeights); | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -103,6 +103,7 @@ | |
#include "llvm/IR/Module.h" | ||
#include "llvm/IR/ModuleSlotTracker.h" | ||
#include "llvm/IR/PassManager.h" | ||
#include "llvm/IR/ProfDataUtils.h" | ||
#include "llvm/IR/Statepoint.h" | ||
#include "llvm/IR/Type.h" | ||
#include "llvm/IR/Use.h" | ||
|
@@ -4756,8 +4757,10 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) { | |
|
||
// Check consistency of !prof branch_weights metadata. | ||
if (ProfName.equals("branch_weights")) { | ||
unsigned int Offset = getBranchWeightOffset(I); | ||
if (isa<InvokeInst>(&I)) { | ||
Check(MD->getNumOperands() == 2 || MD->getNumOperands() == 3, | ||
Check(MD->getNumOperands() == (1 + Offset) || | ||
MD->getNumOperands() == (2 + Offset), | ||
Comment on lines
+4814
to
+4815
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. More opportunities for a possible |
||
"Wrong number of InvokeInst branch_weights operands", MD); | ||
} else { | ||
unsigned ExpectedNumOperands = 0; | ||
|
@@ -4777,10 +4780,10 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) { | |
CheckFailed("!prof branch_weights are not allowed for this instruction", | ||
MD); | ||
|
||
Check(MD->getNumOperands() == 1 + ExpectedNumOperands, | ||
Check(MD->getNumOperands() == Offset + ExpectedNumOperands, | ||
"Wrong number of operands", MD); | ||
} | ||
for (unsigned i = 1; i < MD->getNumOperands(); ++i) { | ||
for (unsigned i = Offset; i < MD->getNumOperands(); ++i) { | ||
auto &MDO = MD->getOperand(i); | ||
Check(MDO, "second operand should not be null", MD); | ||
Check(mdconst::dyn_extract<ConstantInt>(MDO), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm maybe this should somehow have the name "Weight" in the name, to not introduce confusion with alias-analysis things (at least I immediately think "alias analysis" when I see the word "provenance") or find different term than "provenance"?