11
11
// ===----------------------------------------------------------------------===//
12
12
13
13
#include " NVPTXCtorDtorLowering.h"
14
+ #include " MCTargetDesc/NVPTXBaseInfo.h"
14
15
#include " NVPTX.h"
15
16
#include " llvm/ADT/StringExtras.h"
16
17
#include " llvm/IR/Constants.h"
@@ -32,6 +33,11 @@ static cl::opt<std::string>
32
33
cl::desc (" Override unique ID of ctor/dtor globals." ),
33
34
cl::init(" " ), cl::Hidden);
34
35
36
+ static cl::opt<bool >
37
+ CreateKernels (" nvptx-emit-init-fini-kernel" ,
38
+ cl::desc (" Emit kernels to call ctor/dtor globals." ),
39
+ cl::init(true ), cl::Hidden);
40
+
35
41
namespace {
36
42
37
43
static std::string getHash (StringRef Str) {
@@ -42,11 +48,163 @@ static std::string getHash(StringRef Str) {
42
48
return llvm::utohexstr (Hash.low (), /* LowerCase=*/ true );
43
49
}
44
50
45
- static bool createInitOrFiniGlobls (Module &M, StringRef GlobalName,
46
- bool IsCtor) {
47
- GlobalVariable *GV = M.getGlobalVariable (GlobalName);
48
- if (!GV || !GV->hasInitializer ())
49
- return false ;
51
+ static void addKernelMetadata (Module &M, GlobalValue *GV) {
52
+ llvm::LLVMContext &Ctx = M.getContext ();
53
+
54
+ // Get "nvvm.annotations" metadata node.
55
+ llvm::NamedMDNode *MD = M.getOrInsertNamedMetadata (" nvvm.annotations" );
56
+
57
+ llvm::Metadata *KernelMDVals[] = {
58
+ llvm::ConstantAsMetadata::get (GV), llvm::MDString::get (Ctx, " kernel" ),
59
+ llvm::ConstantAsMetadata::get (
60
+ llvm::ConstantInt::get (llvm::Type::getInt32Ty (Ctx), 1 ))};
61
+
62
+ // This kernel is only to be called single-threaded.
63
+ llvm::Metadata *ThreadXMDVals[] = {
64
+ llvm::ConstantAsMetadata::get (GV), llvm::MDString::get (Ctx, " maxntidx" ),
65
+ llvm::ConstantAsMetadata::get (
66
+ llvm::ConstantInt::get (llvm::Type::getInt32Ty (Ctx), 1 ))};
67
+ llvm::Metadata *ThreadYMDVals[] = {
68
+ llvm::ConstantAsMetadata::get (GV), llvm::MDString::get (Ctx, " maxntidy" ),
69
+ llvm::ConstantAsMetadata::get (
70
+ llvm::ConstantInt::get (llvm::Type::getInt32Ty (Ctx), 1 ))};
71
+ llvm::Metadata *ThreadZMDVals[] = {
72
+ llvm::ConstantAsMetadata::get (GV), llvm::MDString::get (Ctx, " maxntidz" ),
73
+ llvm::ConstantAsMetadata::get (
74
+ llvm::ConstantInt::get (llvm::Type::getInt32Ty (Ctx), 1 ))};
75
+
76
+ llvm::Metadata *BlockMDVals[] = {
77
+ llvm::ConstantAsMetadata::get (GV),
78
+ llvm::MDString::get (Ctx, " maxclusterrank" ),
79
+ llvm::ConstantAsMetadata::get (
80
+ llvm::ConstantInt::get (llvm::Type::getInt32Ty (Ctx), 1 ))};
81
+
82
+ // Append metadata to nvvm.annotations.
83
+ MD->addOperand (llvm::MDNode::get (Ctx, KernelMDVals));
84
+ MD->addOperand (llvm::MDNode::get (Ctx, ThreadXMDVals));
85
+ MD->addOperand (llvm::MDNode::get (Ctx, ThreadYMDVals));
86
+ MD->addOperand (llvm::MDNode::get (Ctx, ThreadZMDVals));
87
+ MD->addOperand (llvm::MDNode::get (Ctx, BlockMDVals));
88
+ }
89
+
90
+ static Function *createInitOrFiniKernelFunction (Module &M, bool IsCtor) {
91
+ StringRef InitOrFiniKernelName =
92
+ IsCtor ? " nvptx$device$init" : " nvptx$device$fini" ;
93
+ if (M.getFunction (InitOrFiniKernelName))
94
+ return nullptr ;
95
+
96
+ Function *InitOrFiniKernel = Function::createWithDefaultAttr (
97
+ FunctionType::get (Type::getVoidTy (M.getContext ()), false ),
98
+ GlobalValue::WeakODRLinkage, 0 , InitOrFiniKernelName, &M);
99
+ addKernelMetadata (M, InitOrFiniKernel);
100
+
101
+ return InitOrFiniKernel;
102
+ }
103
+
104
+ // We create the IR required to call each callback in this section. This is
105
+ // equivalent to the following code. Normally, the linker would provide us with
106
+ // the definitions of the init and fini array sections. The 'nvlink' linker does
107
+ // not do this so initializing these values is done by the runtime.
108
+ //
109
+ // extern "C" void **__init_array_start = nullptr;
110
+ // extern "C" void **__init_array_end = nullptr;
111
+ // extern "C" void **__fini_array_start = nullptr;
112
+ // extern "C" void **__fini_array_end = nullptr;
113
+ //
114
+ // using InitCallback = void();
115
+ // using FiniCallback = void();
116
+ //
117
+ // void call_init_array_callbacks() {
118
+ // for (auto start = __init_array_start; start != __init_array_end; ++start)
119
+ // reinterpret_cast<InitCallback *>(*start)();
120
+ // }
121
+ //
122
+ // void call_init_array_callbacks() {
123
+ // size_t fini_array_size = __fini_array_end - __fini_array_start;
124
+ // for (size_t i = fini_array_size; i > 0; --i)
125
+ // reinterpret_cast<FiniCallback *>(__fini_array_start[i - 1])();
126
+ // }
127
+ static void createInitOrFiniCalls (Function &F, bool IsCtor) {
128
+ Module &M = *F.getParent ();
129
+ LLVMContext &C = M.getContext ();
130
+
131
+ IRBuilder<> IRB (BasicBlock::Create (C, " entry" , &F));
132
+ auto *LoopBB = BasicBlock::Create (C, " while.entry" , &F);
133
+ auto *ExitBB = BasicBlock::Create (C, " while.end" , &F);
134
+ Type *PtrTy = IRB.getPtrTy (llvm::ADDRESS_SPACE_GLOBAL);
135
+
136
+ auto *Begin = M.getOrInsertGlobal (
137
+ IsCtor ? " __init_array_start" : " __fini_array_start" ,
138
+ PointerType::get (C, 0 ), [&]() {
139
+ auto *GV = new GlobalVariable (
140
+ M, PointerType::get (C, 0 ),
141
+ /* isConstant=*/ false , GlobalValue::WeakAnyLinkage,
142
+ Constant::getNullValue (PointerType::get (C, 0 )),
143
+ IsCtor ? " __init_array_start" : " __fini_array_start" ,
144
+ /* InsertBefore=*/ nullptr , GlobalVariable::NotThreadLocal,
145
+ /* AddressSpace=*/ llvm::ADDRESS_SPACE_GLOBAL);
146
+ GV->setVisibility (GlobalVariable::ProtectedVisibility);
147
+ return GV;
148
+ });
149
+ auto *End = M.getOrInsertGlobal (
150
+ IsCtor ? " __init_array_end" : " __fini_array_end" , PointerType::get (C, 0 ),
151
+ [&]() {
152
+ auto *GV = new GlobalVariable (
153
+ M, PointerType::get (C, 0 ),
154
+ /* isConstant=*/ false , GlobalValue::WeakAnyLinkage,
155
+ Constant::getNullValue (PointerType::get (C, 0 )),
156
+ IsCtor ? " __init_array_end" : " __fini_array_end" ,
157
+ /* InsertBefore=*/ nullptr , GlobalVariable::NotThreadLocal,
158
+ /* AddressSpace=*/ llvm::ADDRESS_SPACE_GLOBAL);
159
+ GV->setVisibility (GlobalVariable::ProtectedVisibility);
160
+ return GV;
161
+ });
162
+
163
+ // The constructor type is suppoed to allow using the argument vectors, but
164
+ // for now we just call them with no arguments.
165
+ auto *CallBackTy = FunctionType::get (IRB.getVoidTy (), {});
166
+
167
+ // The destructor array must be called in reverse order. Get an expression to
168
+ // the end of the array and iterate backwards in that case.
169
+ Value *BeginVal = IRB.CreateLoad (Begin->getType (), Begin, " begin" );
170
+ Value *EndVal = IRB.CreateLoad (Begin->getType (), End, " stop" );
171
+ if (!IsCtor) {
172
+ auto *BeginInt = IRB.CreatePtrToInt (BeginVal, IntegerType::getInt64Ty (C));
173
+ auto *EndInt = IRB.CreatePtrToInt (EndVal, IntegerType::getInt64Ty (C));
174
+ auto *SubInst = IRB.CreateSub (EndInt, BeginInt);
175
+ auto *Offset = IRB.CreateAShr (
176
+ SubInst, ConstantInt::get (IntegerType::getInt64Ty (C), 3 ), " offset" ,
177
+ /* IsExact=*/ true );
178
+ auto *ValuePtr = IRB.CreateGEP (PointerType::get (C, 0 ), BeginVal,
179
+ ArrayRef<Value *>({Offset}));
180
+ EndVal = BeginVal;
181
+ BeginVal = IRB.CreateInBoundsGEP (
182
+ PointerType::get (C, 0 ), ValuePtr,
183
+ ArrayRef<Value *>(ConstantInt::get (IntegerType::getInt64Ty (C), -1 )),
184
+ " start" );
185
+ }
186
+ IRB.CreateCondBr (
187
+ IRB.CreateCmp (IsCtor ? ICmpInst::ICMP_NE : ICmpInst::ICMP_UGT, BeginVal,
188
+ EndVal),
189
+ LoopBB, ExitBB);
190
+ IRB.SetInsertPoint (LoopBB);
191
+ auto *CallBackPHI = IRB.CreatePHI (PtrTy, 2 , " ptr" );
192
+ auto *CallBack = IRB.CreateLoad (CallBackTy->getPointerTo (F.getAddressSpace ()),
193
+ CallBackPHI, " callback" );
194
+ IRB.CreateCall (CallBackTy, CallBack);
195
+ auto *NewCallBack =
196
+ IRB.CreateConstGEP1_64 (PtrTy, CallBackPHI, IsCtor ? 1 : -1 , " next" );
197
+ auto *EndCmp = IRB.CreateCmp (IsCtor ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_ULT,
198
+ NewCallBack, EndVal, " end" );
199
+ CallBackPHI->addIncoming (BeginVal, &F.getEntryBlock ());
200
+ CallBackPHI->addIncoming (NewCallBack, LoopBB);
201
+ IRB.CreateCondBr (EndCmp, ExitBB, LoopBB);
202
+ IRB.SetInsertPoint (ExitBB);
203
+ IRB.CreateRetVoid ();
204
+ }
205
+
206
+ static bool createInitOrFiniGlobals (Module &M, GlobalVariable *GV,
207
+ bool IsCtor) {
50
208
ConstantArray *GA = dyn_cast<ConstantArray>(GV->getInitializer ());
51
209
if (!GA || GA->getNumOperands () == 0 )
52
210
return false ;
@@ -81,14 +239,35 @@ static bool createInitOrFiniGlobls(Module &M, StringRef GlobalName,
81
239
appendToUsed (M, {GV});
82
240
}
83
241
242
+ return true ;
243
+ }
244
+
245
+ static bool createInitOrFiniKernel (Module &M, StringRef GlobalName,
246
+ bool IsCtor) {
247
+ GlobalVariable *GV = M.getGlobalVariable (GlobalName);
248
+ if (!GV || !GV->hasInitializer ())
249
+ return false ;
250
+
251
+ if (!createInitOrFiniGlobals (M, GV, IsCtor))
252
+ return false ;
253
+
254
+ if (!CreateKernels)
255
+ return true ;
256
+
257
+ Function *InitOrFiniKernel = createInitOrFiniKernelFunction (M, IsCtor);
258
+ if (!InitOrFiniKernel)
259
+ return false ;
260
+
261
+ createInitOrFiniCalls (*InitOrFiniKernel, IsCtor);
262
+
84
263
GV->eraseFromParent ();
85
264
return true ;
86
265
}
87
266
88
267
static bool lowerCtorsAndDtors (Module &M) {
89
268
bool Modified = false ;
90
- Modified |= createInitOrFiniGlobls (M, " llvm.global_ctors" , /* IsCtor =*/ true );
91
- Modified |= createInitOrFiniGlobls (M, " llvm.global_dtors" , /* IsCtor =*/ false );
269
+ Modified |= createInitOrFiniKernel (M, " llvm.global_ctors" , /* IsCtor =*/ true );
270
+ Modified |= createInitOrFiniKernel (M, " llvm.global_dtors" , /* IsCtor =*/ false );
92
271
return Modified;
93
272
}
94
273
0 commit comments