Skip to content

[NVPTX] Allow the ctor/dtor lowering pass to emit kernels #71549

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 186 additions & 7 deletions llvm/lib/Target/NVPTX/NVPTXCtorDtorLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//

#include "NVPTXCtorDtorLowering.h"
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTX.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/IR/Constants.h"
Expand All @@ -32,6 +33,11 @@ static cl::opt<std::string>
cl::desc("Override unique ID of ctor/dtor globals."),
cl::init(""), cl::Hidden);

// Hidden command-line switch, enabled by default. When it is false,
// createInitOrFiniKernel stops after the ctor/dtor globals have been created
// and does not emit the kernels that call them.
static cl::opt<bool>
CreateKernels("nvptx-emit-init-fini-kernel",
cl::desc("Emit kernels to call ctor/dtor globals."),
cl::init(true), cl::Hidden);

namespace {

static std::string getHash(StringRef Str) {
Expand All @@ -42,11 +48,163 @@ static std::string getHash(StringRef Str) {
return llvm::utohexstr(Hash.low(), /*LowerCase=*/true);
}

static bool createInitOrFiniGlobls(Module &M, StringRef GlobalName,
bool IsCtor) {
GlobalVariable *GV = M.getGlobalVariable(GlobalName);
if (!GV || !GV->hasInitializer())
return false;
/// Annotate \p GV in the module's "nvvm.annotations" named metadata so that it
/// is treated as a kernel that is only to be called single-threaded.
///
/// Five {GV, name, i32 1} nodes are appended, in this order: "kernel",
/// "maxntidx", "maxntidy", "maxntidz" and "maxclusterrank".
static void addKernelMetadata(Module &M, GlobalValue *GV) {
  llvm::LLVMContext &Context = M.getContext();
  llvm::NamedMDNode *Annotations =
      M.getOrInsertNamedMetadata("nvvm.annotations");

  // All annotations share the same first and last operand; only the string in
  // the middle differs. Metadata is uniqued, so building these once is
  // equivalent to building them per annotation.
  llvm::Metadata *Annotated = llvm::ConstantAsMetadata::get(GV);
  llvm::Metadata *One = llvm::ConstantAsMetadata::get(
      llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), 1));

  // The first entry marks the global as a kernel, the remaining ones pin each
  // launch limit to 1.
  const char *const AnnotationNames[] = {"kernel", "maxntidx", "maxntidy",
                                         "maxntidz", "maxclusterrank"};
  for (const char *Name : AnnotationNames) {
    llvm::Metadata *Operands[] = {Annotated,
                                  llvm::MDString::get(Context, Name), One};
    Annotations->addOperand(llvm::MDNode::get(Context, Operands));
  }
}

/// Declare the body-less `void()` kernel "nvptx$device$init" (when \p IsCtor
/// is true) or "nvptx$device$fini" (otherwise) in \p M with weak_odr linkage
/// and attach the kernel annotations to it.
///
/// \returns the new function, or nullptr if \p M already contains a function
/// with that name.
static Function *createInitOrFiniKernelFunction(Module &M, bool IsCtor) {
  StringRef KernelName = "nvptx$device$fini";
  if (IsCtor)
    KernelName = "nvptx$device$init";

  // Leave an existing function of the same name untouched.
  if (M.getFunction(KernelName) != nullptr)
    return nullptr;

  FunctionType *KernelTy =
      FunctionType::get(Type::getVoidTy(M.getContext()), false);
  Function *Kernel = Function::createWithDefaultAttr(
      KernelTy, GlobalValue::WeakODRLinkage, 0, KernelName, &M);
  addKernelMetadata(M, Kernel);
  return Kernel;
}

// We create the IR required to call each callback in this section. This is
// equivalent to the following code. Normally, the linker would provide us with
// the definitions of the init and fini array sections. The 'nvlink' linker does
// not do this so initializing these values is done by the runtime.
//
// extern "C" void **__init_array_start = nullptr;
// extern "C" void **__init_array_end = nullptr;
// extern "C" void **__fini_array_start = nullptr;
// extern "C" void **__fini_array_end = nullptr;
//
// using InitCallback = void();
// using FiniCallback = void();
//
// void call_init_array_callbacks() {
// for (auto start = __init_array_start; start != __init_array_end; ++start)
// reinterpret_cast<InitCallback *>(*start)();
// }
//
// void call_fini_array_callbacks() {
// size_t fini_array_size = __fini_array_end - __fini_array_start;
// for (size_t i = fini_array_size; i > 0; --i)
// reinterpret_cast<FiniCallback *>(__fini_array_start[i - 1])();
// }
//
// Emits the body of the kernel \p F: a loop that loads each callback pointer
// from the array delimited by the __init_array_{start,end} globals (when
// \p IsCtor is true) or the __fini_array_{start,end} globals (otherwise) and
// calls it with no arguments. The init array is walked front to back, the fini
// array back to front. The delimiting globals are created as weak, protected,
// null-initialized globals in the global address space if the module does not
// have them yet.
//
// \p F is expected to have no basic blocks yet: the PHI node below treats the
// block created here as F's entry block.
static void createInitOrFiniCalls(Function &F, bool IsCtor) {
  Module &M = *F.getParent();
  LLVMContext &C = M.getContext();

  IRBuilder<> IRB(BasicBlock::Create(C, "entry", &F));
  auto *LoopBB = BasicBlock::Create(C, "while.entry", &F);
  auto *ExitBB = BasicBlock::Create(C, "while.end", &F);
  Type *PtrTy = IRB.getPtrTy(llvm::ADDRESS_SPACE_GLOBAL);
  Type *GenericPtrTy = PointerType::get(C, 0);
  Type *Int64Ty = IntegerType::getInt64Ty(C);

  // Look up one of the array bounds, defining it when it is missing.
  auto GetOrCreateBound = [&](StringRef Name) {
    return M.getOrInsertGlobal(Name, GenericPtrTy, [&]() {
      auto *GV = new GlobalVariable(
          M, GenericPtrTy,
          /*isConstant=*/false, GlobalValue::WeakAnyLinkage,
          Constant::getNullValue(GenericPtrTy), Name,
          /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
          /*AddressSpace=*/llvm::ADDRESS_SPACE_GLOBAL);
      GV->setVisibility(GlobalVariable::ProtectedVisibility);
      return GV;
    });
  };
  auto *Begin =
      GetOrCreateBound(IsCtor ? "__init_array_start" : "__fini_array_start");
  auto *End =
      GetOrCreateBound(IsCtor ? "__init_array_end" : "__fini_array_end");

  // The constructor type is supposed to allow using the argument vectors, but
  // for now we just call them with no arguments.
  auto *CallBackTy = FunctionType::get(IRB.getVoidTy(), {});

  // The destructor array must be called in reverse order. Get an expression to
  // the end of the array and iterate backwards in that case.
  Value *BeginVal = IRB.CreateLoad(Begin->getType(), Begin, "begin");
  Value *EndVal = IRB.CreateLoad(Begin->getType(), End, "stop");
  if (!IsCtor) {
    auto *BeginInt = IRB.CreatePtrToInt(BeginVal, Int64Ty);
    auto *EndInt = IRB.CreatePtrToInt(EndVal, Int64Ty);
    auto *SubInst = IRB.CreateSub(EndInt, BeginInt);
    // Byte distance -> element count; the shift by 3 assumes 8-byte entries.
    auto *Offset = IRB.CreateAShr(SubInst, ConstantInt::get(Int64Ty, 3),
                                  "offset", /*IsExact=*/true);
    auto *ValuePtr =
        IRB.CreateGEP(GenericPtrTy, BeginVal, ArrayRef<Value *>({Offset}));
    // Iterate from the last element down to (and including) the first one.
    EndVal = BeginVal;
    BeginVal = IRB.CreateInBoundsGEP(
        GenericPtrTy, ValuePtr,
        ArrayRef<Value *>(ConstantInt::get(Int64Ty, -1)), "start");
  }

  // Skip the loop entirely when there is nothing to call. For destructors the
  // first pointer visited is the last element, which equals the array start
  // when the array holds exactly one entry. The guard therefore has to be
  // "start >= begin"; a strict ">" would silently drop a lone destructor. This
  // also mirrors the "next < begin" exit test of the loop below.
  IRB.CreateCondBr(
      IRB.CreateCmp(IsCtor ? ICmpInst::ICMP_NE : ICmpInst::ICMP_UGE, BeginVal,
                    EndVal),
      LoopBB, ExitBB);

  IRB.SetInsertPoint(LoopBB);
  auto *CallBackPHI = IRB.CreatePHI(PtrTy, 2, "ptr");
  auto *CallBack = IRB.CreateLoad(CallBackTy->getPointerTo(F.getAddressSpace()),
                                  CallBackPHI, "callback");
  IRB.CreateCall(CallBackTy, CallBack);
  // Step forwards for constructors, backwards for destructors.
  auto *NewCallBack =
      IRB.CreateConstGEP1_64(PtrTy, CallBackPHI, IsCtor ? 1 : -1, "next");
  auto *EndCmp = IRB.CreateCmp(IsCtor ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_ULT,
                               NewCallBack, EndVal, "end");
  CallBackPHI->addIncoming(BeginVal, &F.getEntryBlock());
  CallBackPHI->addIncoming(NewCallBack, LoopBB);
  IRB.CreateCondBr(EndCmp, ExitBB, LoopBB);

  IRB.SetInsertPoint(ExitBB);
  IRB.CreateRetVoid();
}

static bool createInitOrFiniGlobals(Module &M, GlobalVariable *GV,
bool IsCtor) {
ConstantArray *GA = dyn_cast<ConstantArray>(GV->getInitializer());
if (!GA || GA->getNumOperands() == 0)
return false;
Expand Down Expand Up @@ -81,14 +239,35 @@ static bool createInitOrFiniGlobls(Module &M, StringRef GlobalName,
appendToUsed(M, {GV});
}

return true;
}

/// Lower the ctor/dtor list stored in the global named \p GlobalName.
///
/// The entries are first turned into globals by createInitOrFiniGlobals. Unless
/// kernel emission is disabled through CreateKernels, the matching init/fini
/// kernel is then created and filled in, and the original list is erased from
/// the module.
///
/// \param M          module to transform.
/// \param GlobalName name of the list to lower, e.g. "llvm.global_ctors".
/// \param IsCtor     true for constructors, false for destructors.
/// \returns true if the module was changed.
static bool createInitOrFiniKernel(Module &M, StringRef GlobalName,
                                   bool IsCtor) {
  GlobalVariable *GV = M.getGlobalVariable(GlobalName);
  if (!GV || !GV->hasInitializer())
    return false;

  if (!createInitOrFiniGlobals(M, GV, IsCtor))
    return false;

  // createInitOrFiniGlobals reported success, so the module has already been
  // changed; every exit below must report that.
  if (!CreateKernels)
    return true;

  // A null result means a function with the kernel's name already exists. No
  // kernel is emitted and the list is kept in that case, but the globals
  // created above are still a modification of the module.
  Function *InitOrFiniKernel = createInitOrFiniKernelFunction(M, IsCtor);
  if (!InitOrFiniKernel)
    return true;

  createInitOrFiniCalls(*InitOrFiniKernel, IsCtor);

  GV->eraseFromParent();
  return true;
}

// Lowers both "llvm.global_ctors" and "llvm.global_dtors" in M and returns
// true if any of the calls below reported a change.
static bool lowerCtorsAndDtors(Module &M) {
bool Modified = false;
// NOTE(review): this listing looks like a rendered diff whose +/- markers were
// lost. The two createInitOrFiniGlobls calls are presumably the removed lines
// and the two createInitOrFiniKernel calls their replacements - confirm
// against the actual file before relying on all four being present.
Modified |= createInitOrFiniGlobls(M, "llvm.global_ctors", /*IsCtor =*/true);
Modified |= createInitOrFiniGlobls(M, "llvm.global_dtors", /*IsCtor =*/false);
Modified |= createInitOrFiniKernel(M, "llvm.global_ctors", /*IsCtor =*/true);
Modified |= createInitOrFiniKernel(M, "llvm.global_dtors", /*IsCtor =*/false);
return Modified;
}

Expand Down
62 changes: 62 additions & 0 deletions llvm/test/CodeGen/NVPTX/lower-ctor-dtor.ll
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals --include-generated-funcs --version 3
; RUN: opt -S -mtriple=nvptx64-- -nvptx-lower-ctor-dtor < %s | FileCheck %s
; RUN: opt -S -mtriple=nvptx64-- -passes=nvptx-lower-ctor-dtor < %s | FileCheck %s
; RUN: opt -S -mtriple=nvptx64-- -passes=nvptx-lower-ctor-dtor \
; RUN: -nvptx-lower-global-ctor-dtor-id=unique_id < %s | FileCheck %s --check-prefix=GLOBAL
; RUN: opt -S -mtriple=nvptx64-- -passes=nvptx-lower-ctor-dtor \
; RUN: -nvptx-emit-init-fini-kernel=false < %s | FileCheck %s --check-prefix=KERNEL

; Make sure we get the same result if we run multiple times
; RUN: opt -S -mtriple=nvptx64-- -passes=nvptx-lower-ctor-dtor,nvptx-lower-ctor-dtor < %s | FileCheck %s
Expand All @@ -10,15 +13,24 @@
@llvm.global_ctors = appending addrspace(1) global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 1, ptr @foo, ptr null }]
@llvm.global_dtors = appending addrspace(1) global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 1, ptr @bar, ptr null }]


; CHECK-NOT: @llvm.global_ctors
; CHECK-NOT: @llvm.global_dtors

; CHECK: @__init_array_object_foo_[[HASH:[0-9a-f]+]]_1 = protected addrspace(4) constant ptr @foo, section ".init_array.1"
; CHECK: @__fini_array_object_bar_[[HASH:[0-9a-f]+]]_1 = protected addrspace(4) constant ptr @bar, section ".fini_array.1"
; CHECK: @llvm.used = appending global [2 x ptr] [ptr addrspacecast (ptr addrspace(4) @__init_array_object_foo_[[HASH]]_1 to ptr), ptr addrspacecast (ptr addrspace(4) @__fini_array_object_bar_[[HASH]]_1 to ptr)], section "llvm.metadata"
; CHECK: @__fini_array_start = weak protected addrspace(1) global ptr null
; CHECK: @__fini_array_end = weak protected addrspace(1) global ptr null

; GLOBAL: @__init_array_object_foo_unique_id_1 = protected addrspace(4) constant ptr @foo, section ".init_array.1"
; GLOBAL: @__fini_array_object_bar_unique_id_1 = protected addrspace(4) constant ptr @bar, section ".fini_array.1"
; GLOBAL: @llvm.used = appending global [2 x ptr] [ptr addrspacecast (ptr addrspace(4) @__init_array_object_foo_unique_id_1 to ptr), ptr addrspacecast (ptr addrspace(4) @__fini_array_object_bar_unique_id_1 to ptr)], section "llvm.metadata"
; GLOBAL: @__fini_array_start = weak protected addrspace(1) global ptr null
; GLOBAL: @__fini_array_end = weak protected addrspace(1) global ptr null

; KERNEL: @__init_array_object_foo_[[HASH:[0-9a-f]+]]_1 = protected addrspace(4) constant ptr @foo, section ".init_array.1"
; KERNEL: @__fini_array_object_bar_[[HASH:[0-9a-f]+]]_1 = protected addrspace(4) constant ptr @bar, section ".fini_array.1"

; VISIBILITY: .visible .const .align 8 .u64 __init_array_object_foo_[[HASH:[0-9a-f]+]]_1 = foo;
; VISIBILITY: .visible .const .align 8 .u64 __fini_array_object_bar_[[HASH:[0-9a-f]+]]_1 = bar;
Expand All @@ -30,3 +42,53 @@ define internal void @foo() {
define internal void @bar() {
ret void
}

; CHECK-LABEL: define weak_odr void @"nvptx$device$init"() {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[BEGIN:%.*]] = load ptr addrspace(1), ptr addrspace(1) @__init_array_start, align 8
; CHECK-NEXT: [[STOP:%.*]] = load ptr addrspace(1), ptr addrspace(1) @__init_array_end, align 8
; CHECK-NEXT: [[TMP0:%.*]] = icmp ne ptr addrspace(1) [[BEGIN]], [[STOP]]
; CHECK-NEXT: br i1 [[TMP0]], label [[WHILE_ENTRY:%.*]], label [[WHILE_END:%.*]]
; CHECK: while.entry:
; CHECK-NEXT: [[PTR:%.*]] = phi ptr addrspace(1) [ [[BEGIN]], [[ENTRY:%.*]] ], [ [[NEXT:%.*]], [[WHILE_ENTRY]] ]
; CHECK-NEXT: [[CALLBACK:%.*]] = load ptr, ptr addrspace(1) [[PTR]], align 8
; CHECK-NEXT: call void [[CALLBACK]]()
; CHECK-NEXT: [[NEXT]] = getelementptr ptr addrspace(1), ptr addrspace(1) [[PTR]], i64 1
; CHECK-NEXT: [[END:%.*]] = icmp eq ptr addrspace(1) [[NEXT]], [[STOP]]
; CHECK-NEXT: br i1 [[END]], label [[WHILE_END]], label [[WHILE_ENTRY]]
; CHECK: while.end:
; CHECK-NEXT: ret void
;
;
; CHECK-LABEL: define weak_odr void @"nvptx$device$fini"() {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[BEGIN:%.*]] = load ptr addrspace(1), ptr addrspace(1) @__fini_array_start, align 8
; CHECK-NEXT: [[STOP:%.*]] = load ptr addrspace(1), ptr addrspace(1) @__fini_array_end, align 8
; CHECK-NEXT: [[TMP0:%.*]] = ptrtoint ptr addrspace(1) [[BEGIN]] to i64
; CHECK-NEXT: [[TMP1:%.*]] = ptrtoint ptr addrspace(1) [[STOP]] to i64
; CHECK-NEXT: [[TMP2:%.*]] = sub i64 [[TMP1]], [[TMP0]]
; CHECK-NEXT: [[OFFSET:%.*]] = ashr exact i64 [[TMP2]], 3
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr ptr, ptr addrspace(1) [[BEGIN]], i64 [[OFFSET]]
; CHECK-NEXT: [[START:%.*]] = getelementptr inbounds ptr, ptr addrspace(1) [[TMP3]], i64 -1
; CHECK-NEXT: [[TMP4:%.*]] = icmp ugt ptr addrspace(1) [[START]], [[BEGIN]]
; CHECK-NEXT: br i1 [[TMP4]], label [[WHILE_ENTRY:%.*]], label [[WHILE_END:%.*]]
; CHECK: while.entry:
; CHECK-NEXT: [[PTR:%.*]] = phi ptr addrspace(1) [ [[START]], [[ENTRY:%.*]] ], [ [[NEXT:%.*]], [[WHILE_ENTRY]] ]
; CHECK-NEXT: [[CALLBACK:%.*]] = load ptr, ptr addrspace(1) [[PTR]], align 8
; CHECK-NEXT: call void [[CALLBACK]]()
; CHECK-NEXT: [[NEXT]] = getelementptr ptr addrspace(1), ptr addrspace(1) [[PTR]], i64 -1
; CHECK-NEXT: [[END:%.*]] = icmp ult ptr addrspace(1) [[NEXT]], [[BEGIN]]
; CHECK-NEXT: br i1 [[END]], label [[WHILE_END]], label [[WHILE_ENTRY]]
; CHECK: while.end:
; CHECK-NEXT: ret void

; CHECK: [[META0:![0-9]+]] = !{ptr @"nvptx$device$init", !"kernel", i32 1}
; CHECK: [[META1:![0-9]+]] = !{ptr @"nvptx$device$init", !"maxntidx", i32 1}
; CHECK: [[META2:![0-9]+]] = !{ptr @"nvptx$device$init", !"maxntidy", i32 1}
; CHECK: [[META3:![0-9]+]] = !{ptr @"nvptx$device$init", !"maxntidz", i32 1}
; CHECK: [[META4:![0-9]+]] = !{ptr @"nvptx$device$init", !"maxclusterrank", i32 1}
; CHECK: [[META5:![0-9]+]] = !{ptr @"nvptx$device$fini", !"kernel", i32 1}
; CHECK: [[META6:![0-9]+]] = !{ptr @"nvptx$device$fini", !"maxntidx", i32 1}
; CHECK: [[META7:![0-9]+]] = !{ptr @"nvptx$device$fini", !"maxntidy", i32 1}
; CHECK: [[META8:![0-9]+]] = !{ptr @"nvptx$device$fini", !"maxntidz", i32 1}
; CHECK: [[META9:![0-9]+]] = !{ptr @"nvptx$device$fini", !"maxclusterrank", i32 1}