Skip to content

Commit df34906

Browse files
committed
Reapply "[llvm][IR] Extend BranchWeightMetadata to track provenance of weights" llvm#95136
Reverts llvm#95060, and relands llvm#86609, with the unintended code generation changes addressed. This patch implements the changes to LLVM IR discussed in https://discourse.llvm.org/t/rfc-update-branch-weights-metadata-to-allow-tracking-branch-weight-origins/75032 In this patch, we add an optional field to MD_prof metadata nodes for branch weights, which can be used to distinguish weights added from llvm.expect* intrinsics from those added via other methods, e.g. from profiles or inserted by the compiler. One of the major motivations is for use with MisExpect diagnostics, which need to know if branch_weight metadata originates from an llvm.expect intrinsic. Without that information, we end up checking branch weights multiple times in the case of ThinLTO + SampleProfiling, leading to some inaccuracy in how we report MisExpect related diagnostics to users. Since we change the format of MD_prof metadata in a fundamental way, we need to update code handling branch weights in a number of places. We also update the lang ref for branch weights to reflect the change.
1 parent 3f3e85c commit df34906

31 files changed

+179
-92
lines changed

clang/test/CodeGenCXX/attr-likelihood-if-vs-builtin-expect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,5 +221,5 @@ void tu2(int &i) {
221221
}
222222
}
223223

224-
// CHECK: [[BW_LIKELY]] = !{!"branch_weights", i32 2000, i32 1}
225-
// CHECK: [[BW_UNLIKELY]] = !{!"branch_weights", i32 1, i32 2000}
224+
// CHECK: [[BW_LIKELY]] = !{!"branch_weights", !"expected", i32 2000, i32 1}
225+
// CHECK: [[BW_UNLIKELY]] = !{!"branch_weights", !"expected", i32 1, i32 2000}

llvm/docs/BranchWeightMetadata.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,14 @@ Supported Instructions
2828

2929
Metadata is only assigned to the conditional branches. There are two extra
3030
operands for the true and the false branch.
31+
We optionally track if the metadata was added by ``__builtin_expect`` or
32+
``__builtin_expect_with_probability`` with an optional field ``!"expected"``.
3133

3234
.. code-block:: none
3335
3436
!0 = !{
3537
!"branch_weights",
38+
[ !"expected", ]
3639
i32 <TRUE_BRANCH_WEIGHT>,
3740
i32 <FALSE_BRANCH_WEIGHT>
3841
}
@@ -47,6 +50,7 @@ is always case #0).
4750
4851
!0 = !{
4952
!"branch_weights",
53+
[ !"expected", ]
5054
i32 <DEFAULT_BRANCH_WEIGHT>
5155
[ , i32 <CASE_BRANCH_WEIGHT> ... ]
5256
}
@@ -60,6 +64,7 @@ Branch weights are assigned to every destination.
6064
6165
!0 = !{
6266
!"branch_weights",
67+
[ !"expected", ]
6368
i32 <LABEL_BRANCH_WEIGHT>
6469
[ , i32 <LABEL_BRANCH_WEIGHT> ... ]
6570
}
@@ -75,6 +80,7 @@ block and entry counts which may not be accurate with sampling.
7580
7681
!0 = !{
7782
!"branch_weights",
83+
[ !"expected", ]
7884
i32 <CALL_BRANCH_WEIGHT>
7985
}
8086
@@ -95,6 +101,7 @@ is used.
95101
96102
!0 = !{
97103
!"branch_weights",
104+
[ !"expected", ]
98105
i32 <INVOKE_NORMAL_WEIGHT>
99106
[ , i32 <INVOKE_UNWIND_WEIGHT> ]
100107
}

llvm/include/llvm/IR/MDBuilder.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,11 @@ class MDBuilder {
5959
//===------------------------------------------------------------------===//
6060

6161
/// Return metadata containing two branch weights.
62-
MDNode *createBranchWeights(uint32_t TrueWeight, uint32_t FalseWeight);
62+
/// @param TrueWeight the weight of the true branch
63+
/// @param FalseWeight the weight of the false branch
64+
/// @param IsExpected whether these weights come from __builtin_expect*
65+
MDNode *createBranchWeights(uint32_t TrueWeight, uint32_t FalseWeight,
66+
bool IsExpected = false);
6367

6468
/// Return metadata containing two branch weights, with significant bias
6569
/// towards `true` destination.
@@ -70,7 +74,10 @@ class MDBuilder {
7074
MDNode *createUnlikelyBranchWeights();
7175

7276
/// Return metadata containing a number of branch weights.
73-
MDNode *createBranchWeights(ArrayRef<uint32_t> Weights);
77+
/// @param Weights the weights of all the branches
78+
/// @param IsExpected whether these weights come from __builtin_expect*
79+
MDNode *createBranchWeights(ArrayRef<uint32_t> Weights,
80+
bool IsExpected = false);
7481

7582
/// Return metadata specifying that a branch or switch is unpredictable.
7683
MDNode *createUnpredictable();

llvm/include/llvm/IR/ProfDataUtils.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,17 @@ MDNode *getBranchWeightMDNode(const Instruction &I);
5555
/// Nullptr otherwise.
5656
MDNode *getValidBranchWeightMDNode(const Instruction &I);
5757

58+
/// Check if Branch Weight Metadata has an "expected" field from an llvm.expect*
59+
/// intrinsic
60+
bool hasBranchWeightOrigin(const Instruction &I);
61+
62+
/// Check if Branch Weight Metadata has an "expected" field from an llvm.expect*
63+
/// intrinsic
64+
bool hasBranchWeightOrigin(const MDNode *ProfileData);
65+
66+
/// Return the offset to the first branch weight data
67+
unsigned getBranchWeightOffset(const MDNode *ProfileData);
68+
5869
/// Extract branch weights from MD_prof metadata
5970
///
6071
/// \param ProfileData A pointer to an MDNode.
@@ -111,7 +122,11 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);
111122

112123
/// Create a new `branch_weights` metadata node and add or overwrite
113124
/// a `prof` metadata reference to instruction `I`.
114-
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);
125+
/// \param I the Instruction to set branch weights on.
126+
/// \param Weights an array of weights to set on instruction I.
127+
/// \param IsExpected were these weights added from an llvm.expect* intrinsic.
128+
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
129+
bool IsExpected);
115130

116131
/// Scaling the profile data attached to 'I' using the ratio of S/T.
117132
void scaleProfData(Instruction &I, uint64_t S, uint64_t T);

llvm/lib/Bitcode/Reader/BitcodeReader.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
#include "llvm/IR/Module.h"
5858
#include "llvm/IR/ModuleSummaryIndex.h"
5959
#include "llvm/IR/Operator.h"
60+
#include "llvm/IR/ProfDataUtils.h"
6061
#include "llvm/IR/Type.h"
6162
#include "llvm/IR/Value.h"
6263
#include "llvm/IR/Verifier.h"
@@ -6951,8 +6952,10 @@ Error BitcodeReader::materialize(GlobalValue *GV) {
69516952
else
69526953
continue; // ignore and continue.
69536954

6955+
unsigned Offset = getBranchWeightOffset(MD);
6956+
69546957
// If branch weight doesn't match, just strip branch weight.
6955-
if (MD->getNumOperands() != 1 + ExpectedNumOperands)
6958+
if (MD->getNumOperands() != Offset + ExpectedNumOperands)
69566959
I.setMetadata(LLVMContext::MD_prof, nullptr);
69576960
}
69586961
}

llvm/lib/CodeGen/CodeGenPrepare.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8866,7 +8866,8 @@ bool CodeGenPrepare::splitBranchCondition(Function &F, ModifyDT &ModifiedDT) {
88668866
scaleWeights(NewTrueWeight, NewFalseWeight);
88678867
Br1->setMetadata(LLVMContext::MD_prof,
88688868
MDBuilder(Br1->getContext())
8869-
.createBranchWeights(TrueWeight, FalseWeight));
8869+
.createBranchWeights(TrueWeight, FalseWeight,
8870+
hasBranchWeightOrigin(*Br1)));
88708871

88718872
NewTrueWeight = TrueWeight;
88728873
NewFalseWeight = 2 * FalseWeight;

llvm/lib/IR/Instruction.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,12 +1268,23 @@ Instruction *Instruction::cloneImpl() const {
12681268

12691269
void Instruction::swapProfMetadata() {
12701270
MDNode *ProfileData = getBranchWeightMDNode(*this);
1271-
if (!ProfileData || ProfileData->getNumOperands() != 3)
1271+
if (!ProfileData)
1272+
return;
1273+
unsigned FirstIdx = getBranchWeightOffset(ProfileData);
1274+
if (ProfileData->getNumOperands() != 2 + FirstIdx)
12721275
return;
12731276

1274-
// The first operand is the name. Fetch them backwards and build a new one.
1275-
Metadata *Ops[] = {ProfileData->getOperand(0), ProfileData->getOperand(2),
1276-
ProfileData->getOperand(1)};
1277+
unsigned SecondIdx = FirstIdx + 1;
1278+
SmallVector<Metadata *, 4> Ops;
1279+
// If there are more weights past the second, we can't swap them
1280+
if (ProfileData->getNumOperands() > SecondIdx + 1)
1281+
return;
1282+
for (unsigned Idx = 0; Idx < FirstIdx; ++Idx) {
1283+
Ops.push_back(ProfileData->getOperand(Idx));
1284+
}
1285+
// Switch the order of the weights
1286+
Ops.push_back(ProfileData->getOperand(SecondIdx));
1287+
Ops.push_back(ProfileData->getOperand(FirstIdx));
12771288
setMetadata(LLVMContext::MD_prof,
12781289
MDNode::get(ProfileData->getContext(), Ops));
12791290
}

llvm/lib/IR/Instructions.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5199,7 +5199,11 @@ void SwitchInstProfUpdateWrapper::init() {
51995199
if (!ProfileData)
52005200
return;
52015201

5202-
if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1) {
5202+
// FIXME: This check belongs in ProfDataUtils. It's almost equivalent to
5203+
// getValidBranchWeightMDNode(), but the need to use llvm_unreachable
5204+
// makes them slightly different.
5205+
if (ProfileData->getNumOperands() !=
5206+
SI.getNumSuccessors() + getBranchWeightOffset(ProfileData)) {
52035207
llvm_unreachable("number of prof branch_weights metadata operands does "
52045208
"not correspond to number of succesors");
52055209
}

llvm/lib/IR/MDBuilder.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ MDNode *MDBuilder::createFPMath(float Accuracy) {
3535
}
3636

3737
MDNode *MDBuilder::createBranchWeights(uint32_t TrueWeight,
38-
uint32_t FalseWeight) {
39-
return createBranchWeights({TrueWeight, FalseWeight});
38+
uint32_t FalseWeight, bool IsExpected) {
39+
return createBranchWeights({TrueWeight, FalseWeight}, IsExpected);
4040
}
4141

4242
MDNode *MDBuilder::createLikelyBranchWeights() {
@@ -49,15 +49,19 @@ MDNode *MDBuilder::createUnlikelyBranchWeights() {
4949
return createBranchWeights(1, (1U << 20) - 1);
5050
}
5151

52-
MDNode *MDBuilder::createBranchWeights(ArrayRef<uint32_t> Weights) {
52+
MDNode *MDBuilder::createBranchWeights(ArrayRef<uint32_t> Weights,
53+
bool IsExpected) {
5354
assert(Weights.size() >= 1 && "Need at least one branch weights!");
5455

55-
SmallVector<Metadata *, 4> Vals(Weights.size() + 1);
56+
unsigned int Offset = IsExpected ? 2 : 1;
57+
SmallVector<Metadata *, 4> Vals(Weights.size() + Offset);
5658
Vals[0] = createString("branch_weights");
59+
if (IsExpected)
60+
Vals[1] = createString("expected");
5761

5862
Type *Int32Ty = Type::getInt32Ty(Context);
5963
for (unsigned i = 0, e = Weights.size(); i != e; ++i)
60-
Vals[i + 1] = createConstant(ConstantInt::get(Int32Ty, Weights[i]));
64+
Vals[i + Offset] = createConstant(ConstantInt::get(Int32Ty, Weights[i]));
6165

6266
return MDNode::get(Context, Vals);
6367
}

llvm/lib/IR/Metadata.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,10 +1196,10 @@ MDNode *MDNode::mergeDirectCallProfMetadata(MDNode *A, MDNode *B,
11961196
StringRef AProfName = AMDS->getString();
11971197
StringRef BProfName = BMDS->getString();
11981198
if (AProfName == "branch_weights" && BProfName == "branch_weights") {
1199-
ConstantInt *AInstrWeight =
1200-
mdconst::dyn_extract<ConstantInt>(A->getOperand(1));
1201-
ConstantInt *BInstrWeight =
1202-
mdconst::dyn_extract<ConstantInt>(B->getOperand(1));
1199+
ConstantInt *AInstrWeight = mdconst::dyn_extract<ConstantInt>(
1200+
A->getOperand(getBranchWeightOffset(A)));
1201+
ConstantInt *BInstrWeight = mdconst::dyn_extract<ConstantInt>(
1202+
B->getOperand(getBranchWeightOffset(B)));
12031203
assert(AInstrWeight && BInstrWeight && "verified by LLVM verifier");
12041204
return MDNode::get(Ctx,
12051205
{MDHelper.createString("branch_weights"),

llvm/lib/IR/ProfDataUtils.cpp

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "llvm/IR/LLVMContext.h"
2020
#include "llvm/IR/MDBuilder.h"
2121
#include "llvm/IR/Metadata.h"
22+
#include "llvm/IR/ProfDataUtils.h"
2223
#include "llvm/Support/BranchProbability.h"
2324
#include "llvm/Support/CommandLine.h"
2425

@@ -40,9 +41,6 @@ namespace {
4041
// We maintain some constants here to ensure that we access the branch weights
4142
// correctly, and can change the behavior in the future if the layout changes
4243

43-
// The index at which the weights vector starts
44-
constexpr unsigned WeightsIdx = 1;
45-
4644
// the minimum number of operands for MD_prof nodes with branch weights
4745
constexpr unsigned MinBWOps = 3;
4846

@@ -75,15 +73,16 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData,
7573
assert(isBranchWeightMD(ProfileData) && "wrong metadata");
7674

7775
unsigned NOps = ProfileData->getNumOperands();
76+
unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
7877
assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
7978
Weights.resize(NOps - WeightsIdx);
8079

8180
for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
8281
ConstantInt *Weight =
8382
mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
8483
assert(Weight && "Malformed branch_weight in MD_prof node");
85-
assert(Weight->getValue().getActiveBits() <= 32 &&
86-
"Too many bits for uint32_t");
84+
assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
85+
"Too many bits for MD_prof branch_weight");
8786
Weights[Idx - WeightsIdx] = Weight->getZExtValue();
8887
}
8988
}
@@ -123,6 +122,26 @@ bool hasValidBranchWeightMD(const Instruction &I) {
123122
return getValidBranchWeightMDNode(I);
124123
}
125124

125+
bool hasBranchWeightOrigin(const Instruction &I) {
126+
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
127+
return hasBranchWeightOrigin(ProfileData);
128+
}
129+
130+
bool hasBranchWeightOrigin(const MDNode *ProfileData) {
131+
if (!isBranchWeightMD(ProfileData))
132+
return false;
133+
auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
134+
// NOTE: if we ever have more types of branch weight provenance,
135+
// we need to check the string value is "expected". For now, we
136+
// supply a more generic API, and avoid the spurious comparisons.
137+
assert(ProfDataName == nullptr || ProfDataName->getString() == "expected");
138+
return ProfDataName != nullptr;
139+
}
140+
141+
unsigned getBranchWeightOffset(const MDNode *ProfileData) {
142+
return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
143+
}
144+
126145
MDNode *getBranchWeightMDNode(const Instruction &I) {
127146
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
128147
if (!isBranchWeightMD(ProfileData))
@@ -132,7 +151,9 @@ MDNode *getBranchWeightMDNode(const Instruction &I) {
132151

133152
MDNode *getValidBranchWeightMDNode(const Instruction &I) {
134153
auto *ProfileData = getBranchWeightMDNode(I);
135-
if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors())
154+
auto Offset = getBranchWeightOffset(ProfileData);
155+
if (ProfileData &&
156+
ProfileData->getNumOperands() == Offset + I.getNumSuccessors())
136157
return ProfileData;
137158
return nullptr;
138159
}
@@ -191,7 +212,8 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
191212
return false;
192213

193214
if (ProfDataName->getString() == "branch_weights") {
194-
for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) {
215+
unsigned Offset = getBranchWeightOffset(ProfileData);
216+
for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
195217
auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
196218
assert(V && "Malformed branch_weight in MD_prof node");
197219
TotalVal += V->getValue().getZExtValue();
@@ -212,9 +234,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
212234
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
213235
}
214236

215-
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
237+
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
238+
bool IsExpected) {
216239
MDBuilder MDB(I.getContext());
217-
MDNode *BranchWeights = MDB.createBranchWeights(Weights);
240+
MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
218241
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
219242
}
220243

@@ -241,9 +264,11 @@ void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
241264
if (ProfDataName->getString() == "branch_weights" &&
242265
ProfileData->getNumOperands() > 0) {
243266
// Using APInt::div may be expensive, but most cases should fit 64 bits.
244-
APInt Val(128, mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1))
245-
->getValue()
246-
.getZExtValue());
267+
APInt Val(128,
268+
mdconst::dyn_extract<ConstantInt>(
269+
ProfileData->getOperand(getBranchWeightOffset(ProfileData)))
270+
->getValue()
271+
.getZExtValue());
247272
Val *= APS;
248273
Vals.push_back(MDB.createConstant(ConstantInt::get(
249274
Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX))));

llvm/lib/IR/Verifier.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
#include "llvm/IR/Module.h"
105105
#include "llvm/IR/ModuleSlotTracker.h"
106106
#include "llvm/IR/PassManager.h"
107+
#include "llvm/IR/ProfDataUtils.h"
107108
#include "llvm/IR/Statepoint.h"
108109
#include "llvm/IR/Type.h"
109110
#include "llvm/IR/Use.h"
@@ -4808,8 +4809,10 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
48084809

48094810
// Check consistency of !prof branch_weights metadata.
48104811
if (ProfName == "branch_weights") {
4812+
unsigned int Offset = getBranchWeightOffset(MD);
48114813
if (isa<InvokeInst>(&I)) {
4812-
Check(MD->getNumOperands() == 2 || MD->getNumOperands() == 3,
4814+
Check(MD->getNumOperands() == (1 + Offset) ||
4815+
MD->getNumOperands() == (2 + Offset),
48134816
"Wrong number of InvokeInst branch_weights operands", MD);
48144817
} else {
48154818
unsigned ExpectedNumOperands = 0;
@@ -4829,10 +4832,10 @@ void Verifier::visitProfMetadata(Instruction &I, MDNode *MD) {
48294832
CheckFailed("!prof branch_weights are not allowed for this instruction",
48304833
MD);
48314834

4832-
Check(MD->getNumOperands() == 1 + ExpectedNumOperands,
4835+
Check(MD->getNumOperands() == Offset + ExpectedNumOperands,
48334836
"Wrong number of operands", MD);
48344837
}
4835-
for (unsigned i = 1; i < MD->getNumOperands(); ++i) {
4838+
for (unsigned i = Offset; i < MD->getNumOperands(); ++i) {
48364839
auto &MDO = MD->getOperand(i);
48374840
Check(MDO, "second operand should not be null", MD);
48384841
Check(mdconst::dyn_extract<ConstantInt>(MDO),

0 commit comments

Comments
 (0)