Skip to content

Commit 8bd49ca

Browse files
committed
[NVPTX] Implement variadic functions using IR lowering
Summary: This patch implements support for variadic functions for NVPTX targets. The implementation here mainly follows what was done to implement it for AMDGPU in #93362. We change the NVPTX codegen to lower all variadic arguments to functions by-value. This creates a flattened set of arguments that the IR lowering pass converts into a struct with the proper alignment. The behavior of this function was determined by iteratively checking what the NVCC copmiler generates for its output. See examples like https://godbolt.org/z/KavfTGY93. I have noted the main methods that NVIDIA uses to lower variadic functions. 1. All arguments are passed in a pointer to aggregate. 2. The minimum alignment for a plain argument is 4 bytes. 3. Alignment is dictated by the underlying type 4. Structs are flattened and do not have their alignment changed. 5. NVPTX never passes any arguments indirectly, even very large ones. This patch passes the tests in the `libc` project currently, including support for `sprintf`.
1 parent 3de162f commit 8bd49ca

File tree

9 files changed

+916
-31
lines changed

9 files changed

+916
-31
lines changed

clang/lib/Basic/Targets/NVPTX.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXTargetInfo : public TargetInfo {
116116
}
117117

118118
BuiltinVaListKind getBuiltinVaListKind() const override {
119-
// FIXME: implement
120-
return TargetInfo::CharPtrBuiltinVaList;
119+
return TargetInfo::VoidPtrBuiltinVaList;
121120
}
122121

123122
bool isValidCPUName(StringRef Name) const override {

clang/lib/CodeGen/Targets/NVPTX.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,12 @@ ABIArgInfo NVPTXABIInfo::classifyArgumentType(QualType Ty) const {
203203
void NVPTXABIInfo::computeInfo(CGFunctionInfo &FI) const {
204204
if (!getCXXABI().classifyReturnType(FI))
205205
FI.getReturnInfo() = classifyReturnType(FI.getReturnType());
206+
207+
unsigned ArgumentsCount = 0;
206208
for (auto &I : FI.arguments())
207-
I.info = classifyArgumentType(I.type);
209+
I.info = ArgumentsCount++ < FI.getNumRequiredArgs()
210+
? classifyArgumentType(I.type)
211+
: ABIArgInfo::getDirect();
208212

209213
// Always honor user-specified calling convention.
210214
if (FI.getCallingConvention() != llvm::CallingConv::C)
@@ -215,7 +219,10 @@ void NVPTXABIInfo::computeInfo(CGFunctionInfo &FI) const {
215219

216220
RValue NVPTXABIInfo::EmitVAArg(CodeGenFunction &CGF, Address VAListAddr,
217221
QualType Ty, AggValueSlot Slot) const {
218-
llvm_unreachable("NVPTX does not support varargs");
222+
return emitVoidPtrVAArg(CGF, VAListAddr, Ty, /*IsIndirect=*/false,
223+
getContext().getTypeInfoInChars(Ty),
224+
CharUnits::fromQuantity(4),
225+
/*AllowHigherAlign=*/true, Slot);
219226
}
220227

221228
void NVPTXTargetCodeGenInfo::setTargetAttributes(

clang/test/CodeGen/variadic-nvptx.c

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
2+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -emit-llvm -o - %s | FileCheck %s
3+
4+
extern void varargs_simple(int, ...);
5+
6+
// CHECK-LABEL: define dso_local void @foo(
7+
// CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
8+
// CHECK-NEXT: [[ENTRY:.*:]]
9+
// CHECK-NEXT: [[C:%.*]] = alloca i8, align 1
10+
// CHECK-NEXT: [[S:%.*]] = alloca i16, align 2
11+
// CHECK-NEXT: [[I:%.*]] = alloca i32, align 4
12+
// CHECK-NEXT: [[L:%.*]] = alloca i64, align 8
13+
// CHECK-NEXT: [[F:%.*]] = alloca float, align 4
14+
// CHECK-NEXT: [[D:%.*]] = alloca double, align 8
15+
// CHECK-NEXT: [[A:%.*]] = alloca [[STRUCT_ANON:%.*]], align 4
16+
// CHECK-NEXT: [[V:%.*]] = alloca <4 x i32>, align 16
17+
// CHECK-NEXT: store i8 1, ptr [[C]], align 1
18+
// CHECK-NEXT: store i16 1, ptr [[S]], align 2
19+
// CHECK-NEXT: store i32 1, ptr [[I]], align 4
20+
// CHECK-NEXT: store i64 1, ptr [[L]], align 8
21+
// CHECK-NEXT: store float 1.000000e+00, ptr [[F]], align 4
22+
// CHECK-NEXT: store double 1.000000e+00, ptr [[D]], align 8
23+
// CHECK-NEXT: [[TMP0:%.*]] = load i8, ptr [[C]], align 1
24+
// CHECK-NEXT: [[CONV:%.*]] = sext i8 [[TMP0]] to i32
25+
// CHECK-NEXT: [[TMP1:%.*]] = load i16, ptr [[S]], align 2
26+
// CHECK-NEXT: [[CONV1:%.*]] = sext i16 [[TMP1]] to i32
27+
// CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[I]], align 4
28+
// CHECK-NEXT: [[TMP3:%.*]] = load i64, ptr [[L]], align 8
29+
// CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[F]], align 4
30+
// CHECK-NEXT: [[CONV2:%.*]] = fpext float [[TMP4]] to double
31+
// CHECK-NEXT: [[TMP5:%.*]] = load double, ptr [[D]], align 8
32+
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, i32 noundef [[CONV]], i32 noundef [[CONV1]], i32 noundef [[TMP2]], i64 noundef [[TMP3]], double noundef [[CONV2]], double noundef [[TMP5]])
33+
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[A]], ptr align 4 @__const.foo.a, i64 12, i1 false)
34+
// CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[STRUCT_ANON]], ptr [[A]], i32 0, i32 0
35+
// CHECK-NEXT: [[TMP7:%.*]] = load i32, ptr [[TMP6]], align 4
36+
// CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[STRUCT_ANON]], ptr [[A]], i32 0, i32 1
37+
// CHECK-NEXT: [[TMP9:%.*]] = load i8, ptr [[TMP8]], align 4
38+
// CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [[STRUCT_ANON]], ptr [[A]], i32 0, i32 2
39+
// CHECK-NEXT: [[TMP11:%.*]] = load i32, ptr [[TMP10]], align 4
40+
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, i32 [[TMP7]], i8 [[TMP9]], i32 [[TMP11]])
41+
// CHECK-NEXT: store <4 x i32> <i32 1, i32 1, i32 1, i32 1>, ptr [[V]], align 16
42+
// CHECK-NEXT: [[TMP12:%.*]] = load <4 x i32>, ptr [[V]], align 16
43+
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, <4 x i32> noundef [[TMP12]])
44+
// CHECK-NEXT: ret void
45+
//
46+
void foo() {
47+
char c = '\x1';
48+
short s = 1;
49+
int i = 1;
50+
long l = 1;
51+
float f = 1.f;
52+
double d = 1.;
53+
varargs_simple(0, c, s, i, l, f, d);
54+
55+
struct {int x; char c; int y;} a = {1, '\x1', 1};
56+
varargs_simple(0, a);
57+
58+
typedef int __attribute__((ext_vector_type(4))) int4;
59+
int4 v = {1, 1, 1, 1};
60+
varargs_simple(0, v);
61+
}
62+
63+
typedef struct {long x; long y;} S;
64+
extern void varargs_complex(S, S, ...);
65+
66+
// CHECK-LABEL: define dso_local void @bar(
67+
// CHECK-SAME: ) #[[ATTR0]] {
68+
// CHECK-NEXT: [[ENTRY:.*:]]
69+
// CHECK-NEXT: [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 8
70+
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 8 [[S]], ptr align 8 @__const.bar.s, i64 16, i1 false)
71+
// CHECK-NEXT: call void (ptr, ptr, ...) @varargs_complex(ptr noundef byval([[STRUCT_S]]) align 8 [[S]], ptr noundef byval([[STRUCT_S]]) align 8 [[S]], i32 noundef 1, i64 noundef 1, double noundef 1.000000e+00)
72+
// CHECK-NEXT: ret void
73+
//
74+
void bar() {
75+
S s = {1l, 1l};
76+
varargs_complex(s, s, 1, 1l, 1.0);
77+
}

libc/config/gpu/entrypoints.txt

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,3 @@
1-
if(LIBC_TARGET_ARCHITECTURE_IS_AMDGPU)
2-
set(extra_entrypoints
3-
# stdio.h entrypoints
4-
libc.src.stdio.sprintf
5-
libc.src.stdio.snprintf
6-
libc.src.stdio.vsprintf
7-
libc.src.stdio.vsnprintf
8-
)
9-
endif()
10-
111
set(TARGET_LIBC_ENTRYPOINTS
122
# assert.h entrypoints
133
libc.src.assert.__assert_fail
@@ -185,7 +175,10 @@ set(TARGET_LIBC_ENTRYPOINTS
185175
libc.src.errno.errno
186176

187177
# stdio.h entrypoints
188-
${extra_entrypoints}
178+
libc.src.stdio.sprintf
179+
libc.src.stdio.snprintf
180+
libc.src.stdio.vsprintf
181+
libc.src.stdio.vsnprintf
189182
libc.src.stdio.feof
190183
libc.src.stdio.ferror
191184
libc.src.stdio.fseek

libc/test/src/__support/CMakeLists.txt

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -128,18 +128,15 @@ add_libc_test(
128128
libc.src.__support.uint128
129129
)
130130

131-
# NVPTX does not support varargs currently.
132-
if(NOT LIBC_TARGET_ARCHITECTURE_IS_NVPTX)
133-
add_libc_test(
134-
arg_list_test
135-
SUITE
136-
libc-support-tests
137-
SRCS
138-
arg_list_test.cpp
139-
DEPENDS
140-
libc.src.__support.arg_list
141-
)
142-
endif()
131+
add_libc_test(
132+
arg_list_test
133+
SUITE
134+
libc-support-tests
135+
SRCS
136+
arg_list_test.cpp
137+
DEPENDS
138+
libc.src.__support.arg_list
139+
)
143140

144141
if(NOT LIBC_TARGET_ARCHITECTURE_IS_NVPTX)
145142
add_libc_test(

llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/Target/TargetMachine.h"
3434
#include "llvm/Target/TargetOptions.h"
3535
#include "llvm/TargetParser/Triple.h"
36+
#include "llvm/Transforms/IPO/ExpandVariadics.h"
3637
#include "llvm/Transforms/Scalar.h"
3738
#include "llvm/Transforms/Scalar/GVN.h"
3839
#include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
@@ -343,6 +344,7 @@ void NVPTXPassConfig::addIRPasses() {
343344
}
344345

345346
addPass(createAtomicExpandLegacyPass());
347+
addPass(createExpandVariadicsPass(ExpandVariadicsMode::Lowering));
346348
addPass(createNVPTXCtorDtorLoweringLegacyPass());
347349

348350
// === LSR and other generic IR passes ===

llvm/lib/Transforms/IPO/ExpandVariadics.cpp

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -456,8 +456,8 @@ bool ExpandVariadics::runOnFunction(Module &M, IRBuilder<> &Builder,
456456
// Replace known calls to the variadic with calls to the va_list equivalent
457457
for (User *U : make_early_inc_range(VariadicWrapper->users())) {
458458
if (CallBase *CB = dyn_cast<CallBase>(U)) {
459-
Value *calledOperand = CB->getCalledOperand();
460-
if (VariadicWrapper == calledOperand)
459+
Value *CalledOperand = CB->getCalledOperand();
460+
if (VariadicWrapper == CalledOperand)
461461
Changed |=
462462
expandCall(M, Builder, CB, VariadicWrapper->getFunctionType(),
463463
FixedArityReplacement);
@@ -938,6 +938,36 @@ struct Amdgpu final : public VariadicABIInfo {
938938
}
939939
};
940940

941+
struct NVPTX final : public VariadicABIInfo {
942+
943+
bool enableForTarget() override { return true; }
944+
945+
bool vaListPassedInSSARegister() override { return true; }
946+
947+
Type *vaListType(LLVMContext &Ctx) override {
948+
return PointerType::getUnqual(Ctx);
949+
}
950+
951+
Type *vaListParameterType(Module &M) override {
952+
return PointerType::getUnqual(M.getContext());
953+
}
954+
955+
Value *initializeVaList(Module &M, LLVMContext &Ctx, IRBuilder<> &Builder,
956+
AllocaInst *, Value *Buffer) override {
957+
return Builder.CreateAddrSpaceCast(Buffer, vaListParameterType(M));
958+
}
959+
960+
VAArgSlotInfo slotInfo(const DataLayout &DL, Type *Parameter) override {
961+
// NVPTX expects natural alignment in all cases. The variadic call ABI will
962+
// handle promoting types to their appropriate size and alignment.
963+
const unsigned MinAlign = 1;
964+
Align A = DL.getABITypeAlign(Parameter);
965+
if (A < MinAlign)
966+
A = Align(MinAlign);
967+
return {A, false};
968+
}
969+
};
970+
941971
struct Wasm final : public VariadicABIInfo {
942972

943973
bool enableForTarget() override {
@@ -967,8 +997,8 @@ struct Wasm final : public VariadicABIInfo {
967997
if (A < MinAlign)
968998
A = Align(MinAlign);
969999

970-
if (auto s = dyn_cast<StructType>(Parameter)) {
971-
if (s->getNumElements() > 1) {
1000+
if (auto *S = dyn_cast<StructType>(Parameter)) {
1001+
if (S->getNumElements() > 1) {
9721002
return {DL.getABITypeAlign(PointerType::getUnqual(Ctx)), true};
9731003
}
9741004
}
@@ -988,6 +1018,11 @@ std::unique_ptr<VariadicABIInfo> VariadicABIInfo::create(const Triple &T) {
9881018
return std::make_unique<Wasm>();
9891019
}
9901020

1021+
case Triple::nvptx:
1022+
case Triple::nvptx64: {
1023+
return std::make_unique<NVPTX>();
1024+
}
1025+
9911026
default:
9921027
return {};
9931028
}

0 commit comments

Comments
 (0)