Skip to content

Commit a998f12

Browse files
committed
Reland "[CUDA][HIP] Fix overloading resolution in global var init"
https://reviews.llvm.org/D158247 caused regressions for HIP on Windows and was reverted. A reduced test case is: ``` typedef void (__stdcall* funcTy)(); void invoke(funcTy f); static void __stdcall callee() noexcept { } void foo() { invoke(callee); } ``` It is due to clang missing handling host/device attributes for calling convention at a few places This patch fixes that.
1 parent 463c9f4 commit a998f12

13 files changed

+272
-93
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,4 @@ pythonenv*
7070
/clang/utils/analyzer/projects/*/RefScanBuildResults
7171
# automodapi puts generated documentation files here.
7272
/lldb/docs/python_api/
73+
/Debug/

clang/include/clang/Sema/Sema.h

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,14 @@ class Sema final {
10121012
}
10131013
} DelayedDiagnostics;
10141014

1015+
enum CUDAFunctionTarget {
1016+
CFT_Device,
1017+
CFT_Global,
1018+
CFT_Host,
1019+
CFT_HostDevice,
1020+
CFT_InvalidTarget
1021+
};
1022+
10151023
/// A RAII object to temporarily push a declaration context.
10161024
class ContextRAII {
10171025
private:
@@ -4753,8 +4761,13 @@ class Sema final {
47534761
bool isValidPointerAttrType(QualType T, bool RefOkay = false);
47544762

47554763
bool CheckRegparmAttr(const ParsedAttr &attr, unsigned &value);
4764+
4765+
/// Check validaty of calling convention attribute \p attr. If \p FD
4766+
/// is not null pointer, use \p FD to determine the CUDA/HIP host/device
4767+
/// target. Otherwise, it is specified by \p CFT.
47564768
bool CheckCallingConvAttr(const ParsedAttr &attr, CallingConv &CC,
4757-
const FunctionDecl *FD = nullptr);
4769+
const FunctionDecl *FD = nullptr,
4770+
CUDAFunctionTarget CFT = CFT_InvalidTarget);
47584771
bool CheckAttrTarget(const ParsedAttr &CurrAttr);
47594772
bool CheckAttrNoArgs(const ParsedAttr &CurrAttr);
47604773
bool checkStringLiteralArgumentAttr(const AttributeCommonInfo &CI,
@@ -13266,14 +13279,6 @@ class Sema final {
1326613279
void checkTypeSupport(QualType Ty, SourceLocation Loc,
1326713280
ValueDecl *D = nullptr);
1326813281

13269-
enum CUDAFunctionTarget {
13270-
CFT_Device,
13271-
CFT_Global,
13272-
CFT_Host,
13273-
CFT_HostDevice,
13274-
CFT_InvalidTarget
13275-
};
13276-
1327713282
/// Determines whether the given function is a CUDA device/host/kernel/etc.
1327813283
/// function.
1327913284
///
@@ -13292,6 +13297,29 @@ class Sema final {
1329213297
/// Determines whether the given variable is emitted on host or device side.
1329313298
CUDAVariableTarget IdentifyCUDATarget(const VarDecl *D);
1329413299

13300+
/// Defines kinds of CUDA global host/device context where a function may be
13301+
/// called.
13302+
enum CUDATargetContextKind {
13303+
CTCK_Unknown, /// Unknown context
13304+
CTCK_InitGlobalVar, /// Function called during global variable
13305+
/// initialization
13306+
};
13307+
13308+
/// Define the current global CUDA host/device context where a function may be
13309+
/// called. Only used when a function is called outside of any functions.
13310+
struct CUDATargetContext {
13311+
CUDAFunctionTarget Target = CFT_HostDevice;
13312+
CUDATargetContextKind Kind = CTCK_Unknown;
13313+
Decl *D = nullptr;
13314+
} CurCUDATargetCtx;
13315+
13316+
struct CUDATargetContextRAII {
13317+
Sema &S;
13318+
CUDATargetContext SavedCtx;
13319+
CUDATargetContextRAII(Sema &S_, CUDATargetContextKind K, Decl *D);
13320+
~CUDATargetContextRAII() { S.CurCUDATargetCtx = SavedCtx; }
13321+
};
13322+
1329513323
/// Gets the CUDA target for the current context.
1329613324
CUDAFunctionTarget CurrentCUDATarget() {
1329713325
return IdentifyCUDATarget(dyn_cast<FunctionDecl>(CurContext));

clang/lib/Parse/ParseDecl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2571,6 +2571,7 @@ Decl *Parser::ParseDeclarationAfterDeclaratorAndAttributes(
25712571
}
25722572
}
25732573

2574+
Sema::CUDATargetContextRAII X(Actions, Sema::CTCK_InitGlobalVar, ThisDecl);
25742575
switch (TheInitKind) {
25752576
// Parse declarator '=' initializer.
25762577
case InitKind::Equal: {

clang/lib/Sema/SemaCUDA.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,19 +105,37 @@ Sema::IdentifyCUDATarget(const ParsedAttributesView &Attrs) {
105105
}
106106

107107
template <typename A>
108-
static bool hasAttr(const FunctionDecl *D, bool IgnoreImplicitAttr) {
108+
static bool hasAttr(const Decl *D, bool IgnoreImplicitAttr) {
109109
return D->hasAttrs() && llvm::any_of(D->getAttrs(), [&](Attr *Attribute) {
110110
return isa<A>(Attribute) &&
111111
!(IgnoreImplicitAttr && Attribute->isImplicit());
112112
});
113113
}
114114

115+
Sema::CUDATargetContextRAII::CUDATargetContextRAII(Sema &S_,
116+
CUDATargetContextKind K,
117+
Decl *D)
118+
: S(S_) {
119+
SavedCtx = S.CurCUDATargetCtx;
120+
assert(K == CTCK_InitGlobalVar);
121+
auto *VD = dyn_cast_or_null<VarDecl>(D);
122+
if (VD && VD->hasGlobalStorage() && !VD->isStaticLocal()) {
123+
auto Target = CFT_Host;
124+
if ((hasAttr<CUDADeviceAttr>(VD, /*IgnoreImplicit=*/true) &&
125+
!hasAttr<CUDAHostAttr>(VD, /*IgnoreImplicit=*/true)) ||
126+
hasAttr<CUDASharedAttr>(VD, /*IgnoreImplicit=*/true) ||
127+
hasAttr<CUDAConstantAttr>(VD, /*IgnoreImplicit=*/true))
128+
Target = CFT_Device;
129+
S.CurCUDATargetCtx = {Target, K, VD};
130+
}
131+
}
132+
115133
/// IdentifyCUDATarget - Determine the CUDA compilation target for this function
116134
Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D,
117135
bool IgnoreImplicitHDAttr) {
118-
// Code that lives outside a function is run on the host.
136+
// Code that lives outside a function gets the target from CurCUDATargetCtx.
119137
if (D == nullptr)
120-
return CFT_Host;
138+
return CurCUDATargetCtx.Target;
121139

122140
if (D->hasAttr<CUDAInvalidTargetAttr>())
123141
return CFT_InvalidTarget;

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5132,7 +5132,8 @@ static void handleCallConvAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
51325132
// Diagnostic is emitted elsewhere: here we store the (valid) AL
51335133
// in the Decl node for syntactic reasoning, e.g., pretty-printing.
51345134
CallingConv CC;
5135-
if (S.CheckCallingConvAttr(AL, CC, /*FD*/nullptr))
5135+
if (S.CheckCallingConvAttr(AL, CC, /*FD*/ nullptr,
5136+
S.IdentifyCUDATarget(dyn_cast<FunctionDecl>(D))))
51365137
return;
51375138

51385139
if (!isa<ObjCMethodDecl>(D)) {
@@ -5317,7 +5318,8 @@ static void handleNoRandomizeLayoutAttr(Sema &S, Decl *D,
53175318
}
53185319

53195320
bool Sema::CheckCallingConvAttr(const ParsedAttr &Attrs, CallingConv &CC,
5320-
const FunctionDecl *FD) {
5321+
const FunctionDecl *FD,
5322+
CUDAFunctionTarget CFT) {
53215323
if (Attrs.isInvalid())
53225324
return true;
53235325

@@ -5416,7 +5418,8 @@ bool Sema::CheckCallingConvAttr(const ParsedAttr &Attrs, CallingConv &CC,
54165418
// on their host/device attributes.
54175419
if (LangOpts.CUDA) {
54185420
auto *Aux = Context.getAuxTargetInfo();
5419-
auto CudaTarget = IdentifyCUDATarget(FD);
5421+
assert(FD || CFT != CFT_InvalidTarget);
5422+
auto CudaTarget = FD ? IdentifyCUDATarget(FD) : CFT;
54205423
bool CheckHost = false, CheckDevice = false;
54215424
switch (CudaTarget) {
54225425
case CFT_HostDevice:

clang/lib/Sema/SemaOverload.cpp

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6697,17 +6697,19 @@ void Sema::AddOverloadCandidate(
66976697
}
66986698

66996699
// (CUDA B.1): Check for invalid calls between targets.
6700-
if (getLangOpts().CUDA)
6701-
if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true))
6702-
// Skip the check for callers that are implicit members, because in this
6703-
// case we may not yet know what the member's target is; the target is
6704-
// inferred for the member automatically, based on the bases and fields of
6705-
// the class.
6706-
if (!Caller->isImplicit() && !IsAllowedCUDACall(Caller, Function)) {
6707-
Candidate.Viable = false;
6708-
Candidate.FailureKind = ovl_fail_bad_target;
6709-
return;
6710-
}
6700+
if (getLangOpts().CUDA) {
6701+
const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true);
6702+
// Skip the check for callers that are implicit members, because in this
6703+
// case we may not yet know what the member's target is; the target is
6704+
// inferred for the member automatically, based on the bases and fields of
6705+
// the class.
6706+
if (!(Caller && Caller->isImplicit()) &&
6707+
!IsAllowedCUDACall(Caller, Function)) {
6708+
Candidate.Viable = false;
6709+
Candidate.FailureKind = ovl_fail_bad_target;
6710+
return;
6711+
}
6712+
}
67116713

67126714
if (Function->getTrailingRequiresClause()) {
67136715
ConstraintSatisfaction Satisfaction;
@@ -7219,12 +7221,11 @@ Sema::AddMethodCandidate(CXXMethodDecl *Method, DeclAccessPair FoundDecl,
72197221

72207222
// (CUDA B.1): Check for invalid calls between targets.
72217223
if (getLangOpts().CUDA)
7222-
if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true))
7223-
if (!IsAllowedCUDACall(Caller, Method)) {
7224-
Candidate.Viable = false;
7225-
Candidate.FailureKind = ovl_fail_bad_target;
7226-
return;
7227-
}
7224+
if (!IsAllowedCUDACall(getCurFunctionDecl(/*AllowLambda=*/true), Method)) {
7225+
Candidate.Viable = false;
7226+
Candidate.FailureKind = ovl_fail_bad_target;
7227+
return;
7228+
}
72287229

72297230
if (Method->getTrailingRequiresClause()) {
72307231
ConstraintSatisfaction Satisfaction;
@@ -12495,10 +12496,12 @@ class AddressOfFunctionResolver {
1249512496
return false;
1249612497

1249712498
if (FunctionDecl *FunDecl = dyn_cast<FunctionDecl>(Fn)) {
12498-
if (S.getLangOpts().CUDA)
12499-
if (FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true))
12500-
if (!Caller->isImplicit() && !S.IsAllowedCUDACall(Caller, FunDecl))
12501-
return false;
12499+
if (S.getLangOpts().CUDA) {
12500+
FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true);
12501+
if (!(Caller && Caller->isImplicit()) &&
12502+
!S.IsAllowedCUDACall(Caller, FunDecl))
12503+
return false;
12504+
}
1250212505
if (FunDecl->isMultiVersion()) {
1250312506
const auto *TA = FunDecl->getAttr<TargetAttr>();
1250412507
if (TA && !TA->isDefaultVersion())

clang/lib/Sema/SemaType.cpp

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -366,12 +366,14 @@ enum TypeAttrLocation {
366366
TAL_DeclName
367367
};
368368

369-
static void processTypeAttrs(TypeProcessingState &state, QualType &type,
370-
TypeAttrLocation TAL,
371-
const ParsedAttributesView &attrs);
369+
static void
370+
processTypeAttrs(TypeProcessingState &state, QualType &type,
371+
TypeAttrLocation TAL, const ParsedAttributesView &attrs,
372+
Sema::CUDAFunctionTarget CFT = Sema::CFT_HostDevice);
372373

373374
static bool handleFunctionTypeAttr(TypeProcessingState &state, ParsedAttr &attr,
374-
QualType &type);
375+
QualType &type,
376+
Sema::CUDAFunctionTarget CFT);
375377

376378
static bool handleMSPointerTypeQualifierAttr(TypeProcessingState &state,
377379
ParsedAttr &attr, QualType &type);
@@ -617,7 +619,8 @@ static void distributeFunctionTypeAttr(TypeProcessingState &state,
617619
/// distributed, false if no location was found.
618620
static bool distributeFunctionTypeAttrToInnermost(
619621
TypeProcessingState &state, ParsedAttr &attr,
620-
ParsedAttributesView &attrList, QualType &declSpecType) {
622+
ParsedAttributesView &attrList, QualType &declSpecType,
623+
Sema::CUDAFunctionTarget CFT) {
621624
Declarator &declarator = state.getDeclarator();
622625

623626
// Put it on the innermost function chunk, if there is one.
@@ -629,19 +632,20 @@ static bool distributeFunctionTypeAttrToInnermost(
629632
return true;
630633
}
631634

632-
return handleFunctionTypeAttr(state, attr, declSpecType);
635+
return handleFunctionTypeAttr(state, attr, declSpecType, CFT);
633636
}
634637

635638
/// A function type attribute was written in the decl spec. Try to
636639
/// apply it somewhere.
637-
static void distributeFunctionTypeAttrFromDeclSpec(TypeProcessingState &state,
638-
ParsedAttr &attr,
639-
QualType &declSpecType) {
640+
static void
641+
distributeFunctionTypeAttrFromDeclSpec(TypeProcessingState &state,
642+
ParsedAttr &attr, QualType &declSpecType,
643+
Sema::CUDAFunctionTarget CFT) {
640644
state.saveDeclSpecAttrs();
641645

642646
// Try to distribute to the innermost.
643647
if (distributeFunctionTypeAttrToInnermost(
644-
state, attr, state.getCurrentAttributes(), declSpecType))
648+
state, attr, state.getCurrentAttributes(), declSpecType, CFT))
645649
return;
646650

647651
// If that failed, diagnose the bad attribute when the declarator is
@@ -653,14 +657,14 @@ static void distributeFunctionTypeAttrFromDeclSpec(TypeProcessingState &state,
653657
/// Try to apply it somewhere.
654658
/// `Attrs` is the attribute list containing the declaration (either of the
655659
/// declarator or the declaration).
656-
static void distributeFunctionTypeAttrFromDeclarator(TypeProcessingState &state,
657-
ParsedAttr &attr,
658-
QualType &declSpecType) {
660+
static void distributeFunctionTypeAttrFromDeclarator(
661+
TypeProcessingState &state, ParsedAttr &attr, QualType &declSpecType,
662+
Sema::CUDAFunctionTarget CFT) {
659663
Declarator &declarator = state.getDeclarator();
660664

661665
// Try to distribute to the innermost.
662666
if (distributeFunctionTypeAttrToInnermost(
663-
state, attr, declarator.getAttributes(), declSpecType))
667+
state, attr, declarator.getAttributes(), declSpecType, CFT))
664668
return;
665669

666670
// If that failed, diagnose the bad attribute when the declarator is
@@ -682,7 +686,8 @@ static void distributeFunctionTypeAttrFromDeclarator(TypeProcessingState &state,
682686
/// `Attrs` is the attribute list containing the declaration (either of the
683687
/// declarator or the declaration).
684688
static void distributeTypeAttrsFromDeclarator(TypeProcessingState &state,
685-
QualType &declSpecType) {
689+
QualType &declSpecType,
690+
Sema::CUDAFunctionTarget CFT) {
686691
// The called functions in this loop actually remove things from the current
687692
// list, so iterating over the existing list isn't possible. Instead, make a
688693
// non-owning copy and iterate over that.
@@ -699,7 +704,7 @@ static void distributeTypeAttrsFromDeclarator(TypeProcessingState &state,
699704
break;
700705

701706
FUNCTION_TYPE_ATTRS_CASELIST:
702-
distributeFunctionTypeAttrFromDeclarator(state, attr, declSpecType);
707+
distributeFunctionTypeAttrFromDeclarator(state, attr, declSpecType, CFT);
703708
break;
704709

705710
MS_TYPE_ATTRS_CASELIST:
@@ -3544,7 +3549,8 @@ static QualType GetDeclSpecTypeForDeclarator(TypeProcessingState &state,
35443549
// Note: We don't need to distribute declaration attributes (i.e.
35453550
// D.getDeclarationAttributes()) because those are always C++11 attributes,
35463551
// and those don't get distributed.
3547-
distributeTypeAttrsFromDeclarator(state, T);
3552+
distributeTypeAttrsFromDeclarator(
3553+
state, T, SemaRef.IdentifyCUDATarget(D.getAttributes()));
35483554

35493555
// Find the deduced type in this type. Look in the trailing return type if we
35503556
// have one, otherwise in the DeclSpec type.
@@ -4055,7 +4061,8 @@ static CallingConv getCCForDeclaratorChunk(
40554061
// function type. We'll diagnose the failure to apply them in
40564062
// handleFunctionTypeAttr.
40574063
CallingConv CC;
4058-
if (!S.CheckCallingConvAttr(AL, CC) &&
4064+
if (!S.CheckCallingConvAttr(AL, CC, /*FunctionDecl=*/nullptr,
4065+
S.IdentifyCUDATarget(D.getAttributes())) &&
40594066
(!FTI.isVariadic || supportsVariadicCall(CC))) {
40604067
return CC;
40614068
}
@@ -5727,7 +5734,8 @@ static TypeSourceInfo *GetFullTypeForDeclarator(TypeProcessingState &state,
57275734
}
57285735

57295736
// See if there are any attributes on this declarator chunk.
5730-
processTypeAttrs(state, T, TAL_DeclChunk, DeclType.getAttrs());
5737+
processTypeAttrs(state, T, TAL_DeclChunk, DeclType.getAttrs(),
5738+
S.IdentifyCUDATarget(D.getAttributes()));
57315739

57325740
if (DeclType.Kind != DeclaratorChunk::Paren) {
57335741
if (ExpectNoDerefChunk && !IsNoDerefableChunk(DeclType))
@@ -7801,7 +7809,8 @@ static bool checkMutualExclusion(TypeProcessingState &state,
78017809
/// Process an individual function attribute. Returns true to
78027810
/// indicate that the attribute was handled, false if it wasn't.
78037811
static bool handleFunctionTypeAttr(TypeProcessingState &state, ParsedAttr &attr,
7804-
QualType &type) {
7812+
QualType &type,
7813+
Sema::CUDAFunctionTarget CFT) {
78057814
Sema &S = state.getSema();
78067815

78077816
FunctionTypeUnwrapper unwrapped(S, type);
@@ -8032,7 +8041,7 @@ static bool handleFunctionTypeAttr(TypeProcessingState &state, ParsedAttr &attr,
80328041

80338042
// Otherwise, a calling convention.
80348043
CallingConv CC;
8035-
if (S.CheckCallingConvAttr(attr, CC))
8044+
if (S.CheckCallingConvAttr(attr, CC, /*FunctionDecl=*/nullptr, CFT))
80368045
return true;
80378046

80388047
const FunctionType *fn = unwrapped.get();
@@ -8584,7 +8593,8 @@ static void HandleLifetimeBoundAttr(TypeProcessingState &State,
85848593

85858594
static void processTypeAttrs(TypeProcessingState &state, QualType &type,
85868595
TypeAttrLocation TAL,
8587-
const ParsedAttributesView &attrs) {
8596+
const ParsedAttributesView &attrs,
8597+
Sema::CUDAFunctionTarget CFT) {
85888598

85898599
state.setParsedNoDeref(false);
85908600
if (attrs.empty())
@@ -8826,7 +8836,7 @@ static void processTypeAttrs(TypeProcessingState &state, QualType &type,
88268836
// appertain to and hence should not use the "distribution" logic below.
88278837
if (attr.isStandardAttributeSyntax() ||
88288838
attr.isRegularKeywordAttribute()) {
8829-
if (!handleFunctionTypeAttr(state, attr, type)) {
8839+
if (!handleFunctionTypeAttr(state, attr, type, CFT)) {
88308840
diagnoseBadTypeAttribute(state.getSema(), attr, type);
88318841
attr.setInvalid();
88328842
}
@@ -8836,10 +8846,10 @@ static void processTypeAttrs(TypeProcessingState &state, QualType &type,
88368846
// Never process function type attributes as part of the
88378847
// declaration-specifiers.
88388848
if (TAL == TAL_DeclSpec)
8839-
distributeFunctionTypeAttrFromDeclSpec(state, attr, type);
8849+
distributeFunctionTypeAttrFromDeclSpec(state, attr, type, CFT);
88408850

88418851
// Otherwise, handle the possible delays.
8842-
else if (!handleFunctionTypeAttr(state, attr, type))
8852+
else if (!handleFunctionTypeAttr(state, attr, type, CFT))
88438853
distributeFunctionTypeAttr(state, attr, type);
88448854
break;
88458855
case ParsedAttr::AT_AcquireHandle: {

0 commit comments

Comments
 (0)