Skip to content

Commit a05b24a

Browse files
committed
[NVPTX] Implement variadic functions using IR lowering
Summary: This patch implements support for variadic functions for NVPTX targets. The implementation here mainly follows what was done to implement it for AMDGPU in #93362. We change the NVPTX codegen to lower all variadic arguments to functions by-value. This creates a flattened set of arguments that the IR lowering pass converts into a struct with the proper alignment. The behavior of this function was determined by iteratively checking what the NVCC copmiler generates for its output. See examples like https://godbolt.org/z/KavfTGY93. I have noted the main methods that NVIDIA uses to lower variadic functions. 1. All arguments are passed in a pointer to aggregate. 2. The minimum alignment for a plain argument is 4 bytes. 3. Alignment is dictated by the underlying type 4. Structs are flattened and do not have their alignment changed. 5. NVPTX never passes any arguments indirectly, even very large ones. This patch passes the tests in the `libc` project currently, including support for `sprintf`.
1 parent ffed34e commit a05b24a

File tree

8 files changed

+915
-29
lines changed

8 files changed

+915
-29
lines changed

clang/lib/CodeGen/Targets/NVPTX.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,12 @@ ABIArgInfo NVPTXABIInfo::classifyArgumentType(QualType Ty) const {
203203
void NVPTXABIInfo::computeInfo(CGFunctionInfo &FI) const {
204204
if (!getCXXABI().classifyReturnType(FI))
205205
FI.getReturnInfo() = classifyReturnType(FI.getReturnType());
206+
207+
unsigned ArgumentsCount = 0;
206208
for (auto &I : FI.arguments())
207-
I.info = classifyArgumentType(I.type);
209+
I.info = ArgumentsCount++ < FI.getNumRequiredArgs()
210+
? classifyArgumentType(I.type)
211+
: ABIArgInfo::getDirect();
208212

209213
// Always honor user-specified calling convention.
210214
if (FI.getCallingConvention() != llvm::CallingConv::C)
@@ -215,7 +219,10 @@ void NVPTXABIInfo::computeInfo(CGFunctionInfo &FI) const {
215219

216220
RValue NVPTXABIInfo::EmitVAArg(CodeGenFunction &CGF, Address VAListAddr,
217221
QualType Ty, AggValueSlot Slot) const {
218-
llvm_unreachable("NVPTX does not support varargs");
222+
return emitVoidPtrVAArg(CGF, VAListAddr, Ty, /*IsIndirect=*/false,
223+
getContext().getTypeInfoInChars(Ty),
224+
CharUnits::fromQuantity(4),
225+
/*AllowHigherAlign=*/true, Slot);
219226
}
220227

221228
void NVPTXTargetCodeGenInfo::setTargetAttributes(

clang/test/CodeGen/variadic-nvptx.c

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
2+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -emit-llvm -o - %s | FileCheck %s
3+
4+
extern void varargs_simple(int, ...);
5+
6+
// CHECK-LABEL: define dso_local void @foo(
7+
// CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
8+
// CHECK-NEXT: [[ENTRY:.*:]]
9+
// CHECK-NEXT: [[C:%.*]] = alloca i8, align 1
10+
// CHECK-NEXT: [[S:%.*]] = alloca i16, align 2
11+
// CHECK-NEXT: [[I:%.*]] = alloca i32, align 4
12+
// CHECK-NEXT: [[L:%.*]] = alloca i64, align 8
13+
// CHECK-NEXT: [[F:%.*]] = alloca float, align 4
14+
// CHECK-NEXT: [[D:%.*]] = alloca double, align 8
15+
// CHECK-NEXT: [[A:%.*]] = alloca [[STRUCT_ANON:%.*]], align 4
16+
// CHECK-NEXT: [[V:%.*]] = alloca <4 x i32>, align 16
17+
// CHECK-NEXT: store i8 1, ptr [[C]], align 1
18+
// CHECK-NEXT: store i16 1, ptr [[S]], align 2
19+
// CHECK-NEXT: store i32 1, ptr [[I]], align 4
20+
// CHECK-NEXT: store i64 1, ptr [[L]], align 8
21+
// CHECK-NEXT: store float 1.000000e+00, ptr [[F]], align 4
22+
// CHECK-NEXT: store double 1.000000e+00, ptr [[D]], align 8
23+
// CHECK-NEXT: [[TMP0:%.*]] = load i8, ptr [[C]], align 1
24+
// CHECK-NEXT: [[CONV:%.*]] = sext i8 [[TMP0]] to i32
25+
// CHECK-NEXT: [[TMP1:%.*]] = load i16, ptr [[S]], align 2
26+
// CHECK-NEXT: [[CONV1:%.*]] = sext i16 [[TMP1]] to i32
27+
// CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[I]], align 4
28+
// CHECK-NEXT: [[TMP3:%.*]] = load i64, ptr [[L]], align 8
29+
// CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[F]], align 4
30+
// CHECK-NEXT: [[CONV2:%.*]] = fpext float [[TMP4]] to double
31+
// CHECK-NEXT: [[TMP5:%.*]] = load double, ptr [[D]], align 8
32+
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, i32 noundef [[CONV]], i32 noundef [[CONV1]], i32 noundef [[TMP2]], i64 noundef [[TMP3]], double noundef [[CONV2]], double noundef [[TMP5]])
33+
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[A]], ptr align 4 @__const.foo.a, i64 12, i1 false)
34+
// CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[STRUCT_ANON]], ptr [[A]], i32 0, i32 0
35+
// CHECK-NEXT: [[TMP7:%.*]] = load i32, ptr [[TMP6]], align 4
36+
// CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[STRUCT_ANON]], ptr [[A]], i32 0, i32 1
37+
// CHECK-NEXT: [[TMP9:%.*]] = load i8, ptr [[TMP8]], align 4
38+
// CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [[STRUCT_ANON]], ptr [[A]], i32 0, i32 2
39+
// CHECK-NEXT: [[TMP11:%.*]] = load i32, ptr [[TMP10]], align 4
40+
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, i32 [[TMP7]], i8 [[TMP9]], i32 [[TMP11]])
41+
// CHECK-NEXT: store <4 x i32> <i32 1, i32 1, i32 1, i32 1>, ptr [[V]], align 16
42+
// CHECK-NEXT: [[TMP12:%.*]] = load <4 x i32>, ptr [[V]], align 16
43+
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, <4 x i32> noundef [[TMP12]])
44+
// CHECK-NEXT: ret void
45+
//
46+
void foo() {
47+
char c = '\x1';
48+
short s = 1;
49+
int i = 1;
50+
long l = 1;
51+
float f = 1.f;
52+
double d = 1.;
53+
varargs_simple(0, c, s, i, l, f, d);
54+
55+
struct {int x; char c; int y;} a = {1, '\x1', 1};
56+
varargs_simple(0, a);
57+
58+
typedef int __attribute__((ext_vector_type(4))) int4;
59+
int4 v = {1, 1, 1, 1};
60+
varargs_simple(0, v);
61+
}
62+
63+
typedef struct {long x; long y;} S;
64+
extern void varargs_complex(S, S, ...);
65+
66+
// CHECK-LABEL: define dso_local void @bar(
67+
// CHECK-SAME: ) #[[ATTR0]] {
68+
// CHECK-NEXT: [[ENTRY:.*:]]
69+
// CHECK-NEXT: [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 8
70+
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 8 [[S]], ptr align 8 @__const.bar.s, i64 16, i1 false)
71+
// CHECK-NEXT: call void (ptr, ptr, ...) @varargs_complex(ptr noundef byval([[STRUCT_S]]) align 8 [[S]], ptr noundef byval([[STRUCT_S]]) align 8 [[S]], i32 noundef 1, i64 noundef 1, double noundef 1.000000e+00)
72+
// CHECK-NEXT: ret void
73+
//
74+
void bar() {
75+
S s = {1l, 1l};
76+
varargs_complex(s, s, 1, 1l, 1.0);
77+
}

libc/config/gpu/entrypoints.txt

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,3 @@
1-
if(LIBC_TARGET_ARCHITECTURE_IS_AMDGPU)
2-
set(extra_entrypoints
3-
# stdio.h entrypoints
4-
libc.src.stdio.sprintf
5-
libc.src.stdio.snprintf
6-
libc.src.stdio.vsprintf
7-
libc.src.stdio.vsnprintf
8-
)
9-
endif()
10-
111
set(TARGET_LIBC_ENTRYPOINTS
122
# assert.h entrypoints
133
libc.src.assert.__assert_fail
@@ -185,7 +175,10 @@ set(TARGET_LIBC_ENTRYPOINTS
185175
libc.src.errno.errno
186176

187177
# stdio.h entrypoints
188-
${extra_entrypoints}
178+
libc.src.stdio.sprintf
179+
libc.src.stdio.snprintf
180+
libc.src.stdio.vsprintf
181+
libc.src.stdio.vsnprintf
189182
libc.src.stdio.feof
190183
libc.src.stdio.ferror
191184
libc.src.stdio.fseek

libc/test/src/__support/CMakeLists.txt

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,15 @@ add_libc_test(
8686
libc.src.__support.uint128
8787
)
8888

89-
# NVPTX does not support varargs currently.
90-
if(NOT LIBC_TARGET_ARCHITECTURE_IS_NVPTX)
91-
add_libc_test(
92-
arg_list_test
93-
SUITE
94-
libc-support-tests
95-
SRCS
96-
arg_list_test.cpp
97-
DEPENDS
98-
libc.src.__support.arg_list
99-
)
100-
endif()
89+
add_libc_test(
90+
arg_list_test
91+
SUITE
92+
libc-support-tests
93+
SRCS
94+
arg_list_test.cpp
95+
DEPENDS
96+
libc.src.__support.arg_list
97+
)
10198

10299
if(NOT LIBC_TARGET_ARCHITECTURE_IS_NVPTX)
103100
add_libc_test(

llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/Target/TargetMachine.h"
3434
#include "llvm/Target/TargetOptions.h"
3535
#include "llvm/TargetParser/Triple.h"
36+
#include "llvm/Transforms/IPO/ExpandVariadics.h"
3637
#include "llvm/Transforms/Scalar.h"
3738
#include "llvm/Transforms/Scalar/GVN.h"
3839
#include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
@@ -343,6 +344,7 @@ void NVPTXPassConfig::addIRPasses() {
343344
}
344345

345346
addPass(createAtomicExpandLegacyPass());
347+
addPass(createExpandVariadicsPass(ExpandVariadicsMode::Lowering));
346348
addPass(createNVPTXCtorDtorLoweringLegacyPass());
347349

348350
// === LSR and other generic IR passes ===

llvm/lib/Transforms/IPO/ExpandVariadics.cpp

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -456,8 +456,8 @@ bool ExpandVariadics::runOnFunction(Module &M, IRBuilder<> &Builder,
456456
// Replace known calls to the variadic with calls to the va_list equivalent
457457
for (User *U : make_early_inc_range(VariadicWrapper->users())) {
458458
if (CallBase *CB = dyn_cast<CallBase>(U)) {
459-
Value *calledOperand = CB->getCalledOperand();
460-
if (VariadicWrapper == calledOperand)
459+
Value *CalledOperand = CB->getCalledOperand();
460+
if (VariadicWrapper == CalledOperand)
461461
Changed |=
462462
expandCall(M, Builder, CB, VariadicWrapper->getFunctionType(),
463463
FixedArityReplacement);
@@ -938,6 +938,36 @@ struct Amdgpu final : public VariadicABIInfo {
938938
}
939939
};
940940

941+
struct NVPTX final : public VariadicABIInfo {
942+
943+
bool enableForTarget() override { return true; }
944+
945+
bool vaListPassedInSSARegister() override { return true; }
946+
947+
Type *vaListType(LLVMContext &Ctx) override {
948+
return PointerType::getUnqual(Ctx);
949+
}
950+
951+
Type *vaListParameterType(Module &M) override {
952+
return PointerType::getUnqual(M.getContext());
953+
}
954+
955+
Value *initializeVaList(Module &M, LLVMContext &Ctx, IRBuilder<> &Builder,
956+
AllocaInst *, Value *Buffer) override {
957+
return Builder.CreateAddrSpaceCast(Buffer, vaListParameterType(M));
958+
}
959+
960+
VAArgSlotInfo slotInfo(const DataLayout &DL, Type *Parameter) override {
961+
// NVPTX expects natural alignment in all cases. The variadic call ABI will
962+
// handle promoting types to their appropriate size and alignment.
963+
const unsigned MinAlign = 1;
964+
Align A = DL.getABITypeAlign(Parameter);
965+
if (A < MinAlign)
966+
A = Align(MinAlign);
967+
return {A, false};
968+
}
969+
};
970+
941971
struct Wasm final : public VariadicABIInfo {
942972

943973
bool enableForTarget() override {
@@ -967,8 +997,8 @@ struct Wasm final : public VariadicABIInfo {
967997
if (A < MinAlign)
968998
A = Align(MinAlign);
969999

970-
if (auto s = dyn_cast<StructType>(Parameter)) {
971-
if (s->getNumElements() > 1) {
1000+
if (auto *S = dyn_cast<StructType>(Parameter)) {
1001+
if (S->getNumElements() > 1) {
9721002
return {DL.getABITypeAlign(PointerType::getUnqual(Ctx)), true};
9731003
}
9741004
}
@@ -988,6 +1018,11 @@ std::unique_ptr<VariadicABIInfo> VariadicABIInfo::create(const Triple &T) {
9881018
return std::make_unique<Wasm>();
9891019
}
9901020

1021+
case Triple::nvptx:
1022+
case Triple::nvptx64: {
1023+
return std::make_unique<NVPTX>();
1024+
}
1025+
9911026
default:
9921027
return {};
9931028
}

0 commit comments

Comments
 (0)