12
12
#include " llvm/Analysis/CtxProfAnalysis.h"
13
13
#include " llvm/Analysis/OptimizationRemarkEmitter.h"
14
14
#include " llvm/IR/Analysis.h"
15
+ #include " llvm/IR/Constants.h"
15
16
#include " llvm/IR/DiagnosticInfo.h"
16
17
#include " llvm/IR/GlobalValue.h"
17
18
#include " llvm/IR/IRBuilder.h"
19
+ #include " llvm/IR/InstrTypes.h"
18
20
#include " llvm/IR/Instructions.h"
19
21
#include " llvm/IR/IntrinsicInst.h"
20
22
#include " llvm/IR/Module.h"
@@ -55,14 +57,15 @@ class CtxInstrumentationLowerer final {
55
57
Module &M;
56
58
ModuleAnalysisManager &MAM;
57
59
Type *ContextNodeTy = nullptr ;
58
- Type *FunctionDataTy = nullptr ;
60
+ StructType *FunctionDataTy = nullptr ;
59
61
60
62
DenseSet<const Function *> ContextRootSet;
61
63
Function *StartCtx = nullptr ;
62
64
Function *GetCtx = nullptr ;
63
65
Function *ReleaseCtx = nullptr ;
64
66
GlobalVariable *ExpectedCalleeTLS = nullptr ;
65
67
GlobalVariable *CallsiteInfoTLS = nullptr ;
68
+ Constant *CannotBeRootInitializer = nullptr ;
66
69
67
70
public:
68
71
CtxInstrumentationLowerer (Module &M, ModuleAnalysisManager &MAM);
@@ -117,82 +120,101 @@ CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
117
120
118
121
#define _PTRDECL (_, __ ) PointerTy,
119
122
#define _VOLATILE_PTRDECL (_, __ ) PointerTy,
123
+ #define _CONTEXT_ROOT PointerTy,
120
124
#define _MUTEXDECL (_ ) SanitizerMutexType,
121
125
122
126
FunctionDataTy = StructType::get (
123
127
M.getContext (),
124
- {CTXPROF_FUNCTION_DATA (_PTRDECL, _VOLATILE_PTRDECL, _MUTEXDECL)});
128
+ {CTXPROF_FUNCTION_DATA (_PTRDECL, _CONTEXT_ROOT, _VOLATILE_PTRDECL, _MUTEXDECL)});
125
129
#undef _PTRDECL
130
+ #undef _CONTEXT_ROOT
126
131
#undef _VOLATILE_PTRDECL
127
132
#undef _MUTEXDECL
128
133
129
- // The Context header.
130
- ContextNodeTy = StructType::get (M.getContext (), {
131
- I64Ty, /* Guid*/
132
- PointerTy, /* Next*/
133
- I32Ty, /* NumCounters*/
134
- I32Ty, /* NumCallsites*/
135
- });
136
-
137
- // Define a global for each entrypoint. We'll reuse the entrypoint's name as
138
- // prefix. We assume the entrypoint names to be unique.
139
- for (const auto &Fname : ContextRoots) {
140
- if (const auto *F = M.getFunction (Fname)) {
141
- if (F->isDeclaration ())
142
- continue ;
143
- ContextRootSet.insert (F);
144
- for (const auto &BB : *F)
145
- for (const auto &I : BB)
146
- if (const auto *CB = dyn_cast<CallBase>(&I))
147
- if (CB->isMustTailCall ()) {
148
- M.getContext ().emitError (
149
- " The function " + Fname +
150
- " was indicated as a context root, but it features musttail "
151
- " calls, which is not supported." );
152
- }
153
- }
154
- }
134
+ #define _PTRDECL (_, __ ) Constant::getNullValue(PointerTy),
135
+ #define _VOLATILE_PTRDECL (_, __ ) _PTRDECL(_, __)
136
+ #define _MUTEXDECL (_ ) Constant::getNullValue(SanitizerMutexType),
137
+ #define _CONTEXT_ROOT \
138
+ Constant::getIntegerValue ( \
139
+ PointerTy, \
140
+ APInt (M.getDataLayout ().getPointerTypeSizeInBits (PointerTy), 1U )),
141
+ CannotBeRootInitializer = ConstantStruct::get (
142
+ FunctionDataTy, {CTXPROF_FUNCTION_DATA (_PTRDECL, _CONTEXT_ROOT,
143
+ _VOLATILE_PTRDECL, _MUTEXDECL)});
144
+ #undef _PTRDECL
145
+ #undef _CONTEXT_ROOT
146
+ #undef _VOLATILE_PTRDECL
147
+ #undef _MUTEXDECL
148
+
149
+ // The Context header.
150
+ ContextNodeTy =
151
+ StructType::get (M.getContext (), {
152
+ I64Ty, /* Guid*/
153
+ PointerTy, /* Next*/
154
+ I32Ty, /* NumCounters*/
155
+ I32Ty, /* NumCallsites*/
156
+ });
155
157
156
- // Declare the functions we will call.
157
- StartCtx = cast<Function>(
158
- M.getOrInsertFunction (
159
- CompilerRtAPINames::StartCtx,
160
- FunctionType::get (PointerTy,
161
- {PointerTy, /* FunctionData*/
162
- I64Ty, /* Guid*/ I32Ty,
163
- /* NumCounters*/ I32Ty /* NumCallsites*/ },
164
- false ))
165
- .getCallee ());
166
- GetCtx = cast<Function>(
167
- M.getOrInsertFunction (CompilerRtAPINames::GetCtx,
168
- FunctionType::get (PointerTy,
169
- {PointerTy, /* FunctionData*/
170
- PointerTy, /* Callee*/
171
- I64Ty, /* Guid*/
172
- I32Ty, /* NumCounters*/
173
- I32Ty}, /* NumCallsites*/
174
- false ))
175
- .getCallee ());
176
- ReleaseCtx = cast<Function>(
177
- M.getOrInsertFunction (CompilerRtAPINames::ReleaseCtx,
178
- FunctionType::get (Type::getVoidTy (M.getContext ()),
179
- {
180
- PointerTy, /* FunctionData*/
181
- },
182
- false ))
183
- .getCallee ());
184
-
185
- // Declare the TLSes we will need to use.
186
- CallsiteInfoTLS =
187
- new GlobalVariable (M, PointerTy, false , GlobalValue::ExternalLinkage,
188
- nullptr , CompilerRtAPINames::CallsiteTLS);
189
- CallsiteInfoTLS->setThreadLocal (true );
190
- CallsiteInfoTLS->setVisibility (llvm::GlobalValue::HiddenVisibility);
191
- ExpectedCalleeTLS =
192
- new GlobalVariable (M, PointerTy, false , GlobalValue::ExternalLinkage,
193
- nullptr , CompilerRtAPINames::ExpectedCalleeTLS);
194
- ExpectedCalleeTLS->setThreadLocal (true );
195
- ExpectedCalleeTLS->setVisibility (llvm::GlobalValue::HiddenVisibility);
158
+ // Define a global for each entrypoint. We'll reuse the entrypoint's name
159
+ // as prefix. We assume the entrypoint names to be unique.
160
+ for (const auto &Fname : ContextRoots) {
161
+ if (const auto *F = M.getFunction (Fname)) {
162
+ if (F->isDeclaration ())
163
+ continue ;
164
+ ContextRootSet.insert (F);
165
+ for (const auto &BB : *F)
166
+ for (const auto &I : BB)
167
+ if (const auto *CB = dyn_cast<CallBase>(&I))
168
+ if (CB->isMustTailCall ()) {
169
+ M.getContext ().emitError (" The function " + Fname +
170
+ " was indicated as a context root, "
171
+ " but it features musttail "
172
+ " calls, which is not supported." );
173
+ }
174
+ }
175
+ }
176
+
177
+ // Declare the functions we will call.
178
+ StartCtx = cast<Function>(
179
+ M.getOrInsertFunction (
180
+ CompilerRtAPINames::StartCtx,
181
+ FunctionType::get (PointerTy,
182
+ {PointerTy, /* FunctionData*/
183
+ I64Ty, /* Guid*/ I32Ty,
184
+ /* NumCounters*/ I32Ty /* NumCallsites*/ },
185
+ false ))
186
+ .getCallee ());
187
+ GetCtx = cast<Function>(
188
+ M.getOrInsertFunction (CompilerRtAPINames::GetCtx,
189
+ FunctionType::get (PointerTy,
190
+ {PointerTy, /* FunctionData*/
191
+ PointerTy, /* Callee*/
192
+ I64Ty, /* Guid*/
193
+ I32Ty, /* NumCounters*/
194
+ I32Ty}, /* NumCallsites*/
195
+ false ))
196
+ .getCallee ());
197
+ ReleaseCtx =
198
+ cast<Function>(M.getOrInsertFunction (
199
+ CompilerRtAPINames::ReleaseCtx,
200
+ FunctionType::get (Type::getVoidTy (M.getContext ()),
201
+ {
202
+ PointerTy, /* FunctionData*/
203
+ },
204
+ false ))
205
+ .getCallee ());
206
+
207
+ // Declare the TLSes we will need to use.
208
+ CallsiteInfoTLS =
209
+ new GlobalVariable (M, PointerTy, false , GlobalValue::ExternalLinkage,
210
+ nullptr , CompilerRtAPINames::CallsiteTLS);
211
+ CallsiteInfoTLS->setThreadLocal (true );
212
+ CallsiteInfoTLS->setVisibility (llvm::GlobalValue::HiddenVisibility);
213
+ ExpectedCalleeTLS =
214
+ new GlobalVariable (M, PointerTy, false , GlobalValue::ExternalLinkage,
215
+ nullptr , CompilerRtAPINames::ExpectedCalleeTLS);
216
+ ExpectedCalleeTLS->setThreadLocal (true );
217
+ ExpectedCalleeTLS->setVisibility (llvm::GlobalValue::HiddenVisibility);
196
218
}
197
219
198
220
PreservedAnalyses PGOCtxProfLoweringPass::run (Module &M,
@@ -240,6 +262,14 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
240
262
return false ;
241
263
}();
242
264
265
+ if (HasMusttail && ContextRootSet.contains (&F)) {
266
+ F.getContext ().emitError (
267
+ " [ctx_prof] A function with musttail calls was explicitly requested as "
268
+ " root. That is not supported because we cannot instrument a return "
269
+ " instruction to release the context: " +
270
+ F.getName ());
271
+ return false ;
272
+ }
243
273
auto &Head = F.getEntryBlock ();
244
274
for (auto &I : Head) {
245
275
// Find the increment intrinsic in the entry basic block.
@@ -263,9 +293,14 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
263
293
// regular function)
264
294
// Don't set a name, they end up taking a lot of space and we don't need
265
295
// them.
296
+
297
+ // Zero-initialize the FunctionData, except for functions that have
298
+ // musttail calls. There, we set the CtxRoot field to 1, which will be
299
+ // treated as a "can't be set as root" marker.
266
300
TheRootFuctionData = new GlobalVariable (
267
301
M, FunctionDataTy, false , GlobalVariable::InternalLinkage,
268
- Constant::getNullValue (FunctionDataTy));
302
+ HasMusttail ? CannotBeRootInitializer
303
+ : Constant::getNullValue (FunctionDataTy));
269
304
270
305
if (ContextRootSet.contains (&F)) {
271
306
Context = Builder.CreateCall (
@@ -366,10 +401,6 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
366
401
}
367
402
}
368
403
}
369
- // FIXME: This would happen if the entrypoint tailcalls. A way to fix would be
370
- // to disallow this, (so this then stays as an error), another is to detect
371
- // that and then do a wrapper or disallow the tail call. This only affects
372
- // instrumentation, when we want to detect the call graph.
373
404
if (!HasMusttail && !ContextWasReleased)
374
405
F.getContext ().emitError (
375
406
" [ctx_prof] A function that doesn't have musttail calls was "
0 commit comments