Skip to content

Commit fa22fed

Browse files
yxsamliuDavid Salinas
authored and
David Salinas
committed
Reland "[CUDA][HIP] Fix overloading resolution in global var init" (llvm#65606)
Cherry-pick 9b77638 https://reviews.llvm.org/D158247 caused regressions for HIP on Windows and was reverted. A reduced test case is: ``` typedef void (__stdcall* funcTy)(); void invoke(funcTy f); static void __stdcall callee() noexcept { } void foo() { invoke(callee); } ``` It is due to clang missing handling host/device attributes for calling convention at a few places This patch fixes that. Change-Id: Ibbc5cbe232d73ddaeb91f13f6afbff3151c9bf0b
1 parent 2f3a56b commit fa22fed

13 files changed

+272
-93
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,4 @@ pythonenv*
7070
/clang/utils/analyzer/projects/*/RefScanBuildResults
7171
# automodapi puts generated documentation files here.
7272
/lldb/docs/python_api/
73+
/Debug/

clang/include/clang/Sema/Sema.h

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,14 @@ class Sema final {
994994
}
995995
} DelayedDiagnostics;
996996

997+
enum CUDAFunctionTarget {
998+
CFT_Device,
999+
CFT_Global,
1000+
CFT_Host,
1001+
CFT_HostDevice,
1002+
CFT_InvalidTarget
1003+
};
1004+
9971005
/// A RAII object to temporarily push a declaration context.
9981006
class ContextRAII {
9991007
private:
@@ -4696,8 +4704,13 @@ class Sema final {
46964704
bool isValidPointerAttrType(QualType T, bool RefOkay = false);
46974705

46984706
bool CheckRegparmAttr(const ParsedAttr &attr, unsigned &value);
4707+
4708+
/// Check validaty of calling convention attribute \p attr. If \p FD
4709+
/// is not null pointer, use \p FD to determine the CUDA/HIP host/device
4710+
/// target. Otherwise, it is specified by \p CFT.
46994711
bool CheckCallingConvAttr(const ParsedAttr &attr, CallingConv &CC,
4700-
const FunctionDecl *FD = nullptr);
4712+
const FunctionDecl *FD = nullptr,
4713+
CUDAFunctionTarget CFT = CFT_InvalidTarget);
47014714
bool CheckAttrTarget(const ParsedAttr &CurrAttr);
47024715
bool CheckAttrNoArgs(const ParsedAttr &CurrAttr);
47034716
bool checkStringLiteralArgumentAttr(const AttributeCommonInfo &CI,
@@ -13094,14 +13107,6 @@ class Sema final {
1309413107
void checkTypeSupport(QualType Ty, SourceLocation Loc,
1309513108
ValueDecl *D = nullptr);
1309613109

13097-
enum CUDAFunctionTarget {
13098-
CFT_Device,
13099-
CFT_Global,
13100-
CFT_Host,
13101-
CFT_HostDevice,
13102-
CFT_InvalidTarget
13103-
};
13104-
1310513110
/// Determines whether the given function is a CUDA device/host/kernel/etc.
1310613111
/// function.
1310713112
///
@@ -13120,6 +13125,29 @@ class Sema final {
1312013125
/// Determines whether the given variable is emitted on host or device side.
1312113126
CUDAVariableTarget IdentifyCUDATarget(const VarDecl *D);
1312213127

13128+
/// Defines kinds of CUDA global host/device context where a function may be
13129+
/// called.
13130+
enum CUDATargetContextKind {
13131+
CTCK_Unknown, /// Unknown context
13132+
CTCK_InitGlobalVar, /// Function called during global variable
13133+
/// initialization
13134+
};
13135+
13136+
/// Define the current global CUDA host/device context where a function may be
13137+
/// called. Only used when a function is called outside of any functions.
13138+
struct CUDATargetContext {
13139+
CUDAFunctionTarget Target = CFT_HostDevice;
13140+
CUDATargetContextKind Kind = CTCK_Unknown;
13141+
Decl *D = nullptr;
13142+
} CurCUDATargetCtx;
13143+
13144+
struct CUDATargetContextRAII {
13145+
Sema &S;
13146+
CUDATargetContext SavedCtx;
13147+
CUDATargetContextRAII(Sema &S_, CUDATargetContextKind K, Decl *D);
13148+
~CUDATargetContextRAII() { S.CurCUDATargetCtx = SavedCtx; }
13149+
};
13150+
1312313151
/// Gets the CUDA target for the current context.
1312413152
CUDAFunctionTarget CurrentCUDATarget() {
1312513153
return IdentifyCUDATarget(dyn_cast<FunctionDecl>(CurContext));

clang/lib/Parse/ParseDecl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2479,6 +2479,7 @@ Decl *Parser::ParseDeclarationAfterDeclaratorAndAttributes(
24792479
}
24802480
}
24812481

2482+
Sema::CUDATargetContextRAII X(Actions, Sema::CTCK_InitGlobalVar, ThisDecl);
24822483
switch (TheInitKind) {
24832484
// Parse declarator '=' initializer.
24842485
case InitKind::Equal: {

clang/lib/Sema/SemaCUDA.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,19 +105,37 @@ Sema::IdentifyCUDATarget(const ParsedAttributesView &Attrs) {
105105
}
106106

107107
template <typename A>
108-
static bool hasAttr(const FunctionDecl *D, bool IgnoreImplicitAttr) {
108+
static bool hasAttr(const Decl *D, bool IgnoreImplicitAttr) {
109109
return D->hasAttrs() && llvm::any_of(D->getAttrs(), [&](Attr *Attribute) {
110110
return isa<A>(Attribute) &&
111111
!(IgnoreImplicitAttr && Attribute->isImplicit());
112112
});
113113
}
114114

115+
Sema::CUDATargetContextRAII::CUDATargetContextRAII(Sema &S_,
116+
CUDATargetContextKind K,
117+
Decl *D)
118+
: S(S_) {
119+
SavedCtx = S.CurCUDATargetCtx;
120+
assert(K == CTCK_InitGlobalVar);
121+
auto *VD = dyn_cast_or_null<VarDecl>(D);
122+
if (VD && VD->hasGlobalStorage() && !VD->isStaticLocal()) {
123+
auto Target = CFT_Host;
124+
if ((hasAttr<CUDADeviceAttr>(VD, /*IgnoreImplicit=*/true) &&
125+
!hasAttr<CUDAHostAttr>(VD, /*IgnoreImplicit=*/true)) ||
126+
hasAttr<CUDASharedAttr>(VD, /*IgnoreImplicit=*/true) ||
127+
hasAttr<CUDAConstantAttr>(VD, /*IgnoreImplicit=*/true))
128+
Target = CFT_Device;
129+
S.CurCUDATargetCtx = {Target, K, VD};
130+
}
131+
}
132+
115133
/// IdentifyCUDATarget - Determine the CUDA compilation target for this function
116134
Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D,
117135
bool IgnoreImplicitHDAttr) {
118-
// Code that lives outside a function is run on the host.
136+
// Code that lives outside a function gets the target from CurCUDATargetCtx.
119137
if (D == nullptr)
120-
return CFT_Host;
138+
return CurCUDATargetCtx.Target;
121139

122140
if (D->hasAttr<CUDAInvalidTargetAttr>())
123141
return CFT_InvalidTarget;

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5030,7 +5030,8 @@ static void handleCallConvAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
50305030
// Diagnostic is emitted elsewhere: here we store the (valid) AL
50315031
// in the Decl node for syntactic reasoning, e.g., pretty-printing.
50325032
CallingConv CC;
5033-
if (S.CheckCallingConvAttr(AL, CC, /*FD*/nullptr))
5033+
if (S.CheckCallingConvAttr(AL, CC, /*FD*/ nullptr,
5034+
S.IdentifyCUDATarget(dyn_cast<FunctionDecl>(D))))
50345035
return;
50355036

50365037
if (!isa<ObjCMethodDecl>(D)) {
@@ -5211,7 +5212,8 @@ static void handleNoRandomizeLayoutAttr(Sema &S, Decl *D,
52115212
}
52125213

52135214
bool Sema::CheckCallingConvAttr(const ParsedAttr &Attrs, CallingConv &CC,
5214-
const FunctionDecl *FD) {
5215+
const FunctionDecl *FD,
5216+
CUDAFunctionTarget CFT) {
52155217
if (Attrs.isInvalid())
52165218
return true;
52175219

@@ -5310,7 +5312,8 @@ bool Sema::CheckCallingConvAttr(const ParsedAttr &Attrs, CallingConv &CC,
53105312
// on their host/device attributes.
53115313
if (LangOpts.CUDA) {
53125314
auto *Aux = Context.getAuxTargetInfo();
5313-
auto CudaTarget = IdentifyCUDATarget(FD);
5315+
assert(FD || CFT != CFT_InvalidTarget);
5316+
auto CudaTarget = FD ? IdentifyCUDATarget(FD) : CFT;
53145317
bool CheckHost = false, CheckDevice = false;
53155318
switch (CudaTarget) {
53165319
case CFT_HostDevice:

clang/lib/Sema/SemaOverload.cpp

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6608,17 +6608,19 @@ void Sema::AddOverloadCandidate(
66086608
}
66096609

66106610
// (CUDA B.1): Check for invalid calls between targets.
6611-
if (getLangOpts().CUDA)
6612-
if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true))
6613-
// Skip the check for callers that are implicit members, because in this
6614-
// case we may not yet know what the member's target is; the target is
6615-
// inferred for the member automatically, based on the bases and fields of
6616-
// the class.
6617-
if (!Caller->isImplicit() && !IsAllowedCUDACall(Caller, Function)) {
6618-
Candidate.Viable = false;
6619-
Candidate.FailureKind = ovl_fail_bad_target;
6620-
return;
6621-
}
6611+
if (getLangOpts().CUDA) {
6612+
const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true);
6613+
// Skip the check for callers that are implicit members, because in this
6614+
// case we may not yet know what the member's target is; the target is
6615+
// inferred for the member automatically, based on the bases and fields of
6616+
// the class.
6617+
if (!(Caller && Caller->isImplicit()) &&
6618+
!IsAllowedCUDACall(Caller, Function)) {
6619+
Candidate.Viable = false;
6620+
Candidate.FailureKind = ovl_fail_bad_target;
6621+
return;
6622+
}
6623+
}
66226624

66236625
if (Function->getTrailingRequiresClause()) {
66246626
ConstraintSatisfaction Satisfaction;
@@ -7130,12 +7132,11 @@ Sema::AddMethodCandidate(CXXMethodDecl *Method, DeclAccessPair FoundDecl,
71307132

71317133
// (CUDA B.1): Check for invalid calls between targets.
71327134
if (getLangOpts().CUDA)
7133-
if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true))
7134-
if (!IsAllowedCUDACall(Caller, Method)) {
7135-
Candidate.Viable = false;
7136-
Candidate.FailureKind = ovl_fail_bad_target;
7137-
return;
7138-
}
7135+
if (!IsAllowedCUDACall(getCurFunctionDecl(/*AllowLambda=*/true), Method)) {
7136+
Candidate.Viable = false;
7137+
Candidate.FailureKind = ovl_fail_bad_target;
7138+
return;
7139+
}
71397140

71407141
if (Method->getTrailingRequiresClause()) {
71417142
ConstraintSatisfaction Satisfaction;
@@ -12383,10 +12384,12 @@ class AddressOfFunctionResolver {
1238312384
return false;
1238412385

1238512386
if (FunctionDecl *FunDecl = dyn_cast<FunctionDecl>(Fn)) {
12386-
if (S.getLangOpts().CUDA)
12387-
if (FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true))
12388-
if (!Caller->isImplicit() && !S.IsAllowedCUDACall(Caller, FunDecl))
12389-
return false;
12387+
if (S.getLangOpts().CUDA) {
12388+
FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true);
12389+
if (!(Caller && Caller->isImplicit()) &&
12390+
!S.IsAllowedCUDACall(Caller, FunDecl))
12391+
return false;
12392+
}
1239012393
if (FunDecl->isMultiVersion()) {
1239112394
const auto *TA = FunDecl->getAttr<TargetAttr>();
1239212395
if (TA && !TA->isDefaultVersion())

clang/lib/Sema/SemaType.cpp

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -359,12 +359,14 @@ enum TypeAttrLocation {
359359
TAL_DeclName
360360
};
361361

362-
static void processTypeAttrs(TypeProcessingState &state, QualType &type,
363-
TypeAttrLocation TAL,
364-
const ParsedAttributesView &attrs);
362+
static void
363+
processTypeAttrs(TypeProcessingState &state, QualType &type,
364+
TypeAttrLocation TAL, const ParsedAttributesView &attrs,
365+
Sema::CUDAFunctionTarget CFT = Sema::CFT_HostDevice);
365366

366367
static bool handleFunctionTypeAttr(TypeProcessingState &state, ParsedAttr &attr,
367-
QualType &type);
368+
QualType &type,
369+
Sema::CUDAFunctionTarget CFT);
368370

369371
static bool handleMSPointerTypeQualifierAttr(TypeProcessingState &state,
370372
ParsedAttr &attr, QualType &type);
@@ -610,7 +612,8 @@ static void distributeFunctionTypeAttr(TypeProcessingState &state,
610612
/// distributed, false if no location was found.
611613
static bool distributeFunctionTypeAttrToInnermost(
612614
TypeProcessingState &state, ParsedAttr &attr,
613-
ParsedAttributesView &attrList, QualType &declSpecType) {
615+
ParsedAttributesView &attrList, QualType &declSpecType,
616+
Sema::CUDAFunctionTarget CFT) {
614617
Declarator &declarator = state.getDeclarator();
615618

616619
// Put it on the innermost function chunk, if there is one.
@@ -622,19 +625,20 @@ static bool distributeFunctionTypeAttrToInnermost(
622625
return true;
623626
}
624627

625-
return handleFunctionTypeAttr(state, attr, declSpecType);
628+
return handleFunctionTypeAttr(state, attr, declSpecType, CFT);
626629
}
627630

628631
/// A function type attribute was written in the decl spec. Try to
629632
/// apply it somewhere.
630-
static void distributeFunctionTypeAttrFromDeclSpec(TypeProcessingState &state,
631-
ParsedAttr &attr,
632-
QualType &declSpecType) {
633+
static void
634+
distributeFunctionTypeAttrFromDeclSpec(TypeProcessingState &state,
635+
ParsedAttr &attr, QualType &declSpecType,
636+
Sema::CUDAFunctionTarget CFT) {
633637
state.saveDeclSpecAttrs();
634638

635639
// Try to distribute to the innermost.
636640
if (distributeFunctionTypeAttrToInnermost(
637-
state, attr, state.getCurrentAttributes(), declSpecType))
641+
state, attr, state.getCurrentAttributes(), declSpecType, CFT))
638642
return;
639643

640644
// If that failed, diagnose the bad attribute when the declarator is
@@ -646,14 +650,14 @@ static void distributeFunctionTypeAttrFromDeclSpec(TypeProcessingState &state,
646650
/// Try to apply it somewhere.
647651
/// `Attrs` is the attribute list containing the declaration (either of the
648652
/// declarator or the declaration).
649-
static void distributeFunctionTypeAttrFromDeclarator(TypeProcessingState &state,
650-
ParsedAttr &attr,
651-
QualType &declSpecType) {
653+
static void distributeFunctionTypeAttrFromDeclarator(
654+
TypeProcessingState &state, ParsedAttr &attr, QualType &declSpecType,
655+
Sema::CUDAFunctionTarget CFT) {
652656
Declarator &declarator = state.getDeclarator();
653657

654658
// Try to distribute to the innermost.
655659
if (distributeFunctionTypeAttrToInnermost(
656-
state, attr, declarator.getAttributes(), declSpecType))
660+
state, attr, declarator.getAttributes(), declSpecType, CFT))
657661
return;
658662

659663
// If that failed, diagnose the bad attribute when the declarator is
@@ -675,7 +679,8 @@ static void distributeFunctionTypeAttrFromDeclarator(TypeProcessingState &state,
675679
/// `Attrs` is the attribute list containing the declaration (either of the
676680
/// declarator or the declaration).
677681
static void distributeTypeAttrsFromDeclarator(TypeProcessingState &state,
678-
QualType &declSpecType) {
682+
QualType &declSpecType,
683+
Sema::CUDAFunctionTarget CFT) {
679684
// The called functions in this loop actually remove things from the current
680685
// list, so iterating over the existing list isn't possible. Instead, make a
681686
// non-owning copy and iterate over that.
@@ -692,7 +697,7 @@ static void distributeTypeAttrsFromDeclarator(TypeProcessingState &state,
692697
break;
693698

694699
FUNCTION_TYPE_ATTRS_CASELIST:
695-
distributeFunctionTypeAttrFromDeclarator(state, attr, declSpecType);
700+
distributeFunctionTypeAttrFromDeclarator(state, attr, declSpecType, CFT);
696701
break;
697702

698703
MS_TYPE_ATTRS_CASELIST:
@@ -3510,7 +3515,8 @@ static QualType GetDeclSpecTypeForDeclarator(TypeProcessingState &state,
35103515
// Note: We don't need to distribute declaration attributes (i.e.
35113516
// D.getDeclarationAttributes()) because those are always C++11 attributes,
35123517
// and those don't get distributed.
3513-
distributeTypeAttrsFromDeclarator(state, T);
3518+
distributeTypeAttrsFromDeclarator(
3519+
state, T, SemaRef.IdentifyCUDATarget(D.getAttributes()));
35143520

35153521
// Find the deduced type in this type. Look in the trailing return type if we
35163522
// have one, otherwise in the DeclSpec type.
@@ -4021,7 +4027,8 @@ static CallingConv getCCForDeclaratorChunk(
40214027
// function type. We'll diagnose the failure to apply them in
40224028
// handleFunctionTypeAttr.
40234029
CallingConv CC;
4024-
if (!S.CheckCallingConvAttr(AL, CC) &&
4030+
if (!S.CheckCallingConvAttr(AL, CC, /*FunctionDecl=*/nullptr,
4031+
S.IdentifyCUDATarget(D.getAttributes())) &&
40254032
(!FTI.isVariadic || supportsVariadicCall(CC))) {
40264033
return CC;
40274034
}
@@ -5665,7 +5672,8 @@ static TypeSourceInfo *GetFullTypeForDeclarator(TypeProcessingState &state,
56655672
}
56665673

56675674
// See if there are any attributes on this declarator chunk.
5668-
processTypeAttrs(state, T, TAL_DeclChunk, DeclType.getAttrs());
5675+
processTypeAttrs(state, T, TAL_DeclChunk, DeclType.getAttrs(),
5676+
S.IdentifyCUDATarget(D.getAttributes()));
56695677

56705678
if (DeclType.Kind != DeclaratorChunk::Paren) {
56715679
if (ExpectNoDerefChunk && !IsNoDerefableChunk(DeclType))
@@ -7709,7 +7717,8 @@ static Attr *getCCTypeAttr(ASTContext &Ctx, ParsedAttr &Attr) {
77097717
/// Process an individual function attribute. Returns true to
77107718
/// indicate that the attribute was handled, false if it wasn't.
77117719
static bool handleFunctionTypeAttr(TypeProcessingState &state, ParsedAttr &attr,
7712-
QualType &type) {
7720+
QualType &type,
7721+
Sema::CUDAFunctionTarget CFT) {
77137722
Sema &S = state.getSema();
77147723

77157724
FunctionTypeUnwrapper unwrapped(S, type);
@@ -7891,7 +7900,7 @@ static bool handleFunctionTypeAttr(TypeProcessingState &state, ParsedAttr &attr,
78917900

78927901
// Otherwise, a calling convention.
78937902
CallingConv CC;
7894-
if (S.CheckCallingConvAttr(attr, CC))
7903+
if (S.CheckCallingConvAttr(attr, CC, /*FunctionDecl=*/nullptr, CFT))
78957904
return true;
78967905

78977906
const FunctionType *fn = unwrapped.get();
@@ -8369,7 +8378,8 @@ static void HandleLifetimeBoundAttr(TypeProcessingState &State,
83698378

83708379
static void processTypeAttrs(TypeProcessingState &state, QualType &type,
83718380
TypeAttrLocation TAL,
8372-
const ParsedAttributesView &attrs) {
8381+
const ParsedAttributesView &attrs,
8382+
Sema::CUDAFunctionTarget CFT) {
83738383

83748384
state.setParsedNoDeref(false);
83758385
if (attrs.empty())
@@ -8603,7 +8613,7 @@ static void processTypeAttrs(TypeProcessingState &state, QualType &type,
86038613
// Attributes with standard syntax have strict rules for what they
86048614
// appertain to and hence should not use the "distribution" logic below.
86058615
if (attr.isStandardAttributeSyntax()) {
8606-
if (!handleFunctionTypeAttr(state, attr, type)) {
8616+
if (!handleFunctionTypeAttr(state, attr, type, CFT)) {
86078617
diagnoseBadTypeAttribute(state.getSema(), attr, type);
86088618
attr.setInvalid();
86098619
}
@@ -8613,10 +8623,10 @@ static void processTypeAttrs(TypeProcessingState &state, QualType &type,
86138623
// Never process function type attributes as part of the
86148624
// declaration-specifiers.
86158625
if (TAL == TAL_DeclSpec)
8616-
distributeFunctionTypeAttrFromDeclSpec(state, attr, type);
8626+
distributeFunctionTypeAttrFromDeclSpec(state, attr, type, CFT);
86178627

86188628
// Otherwise, handle the possible delays.
8619-
else if (!handleFunctionTypeAttr(state, attr, type))
8629+
else if (!handleFunctionTypeAttr(state, attr, type, CFT))
86208630
distributeFunctionTypeAttr(state, attr, type);
86218631
break;
86228632
case ParsedAttr::AT_AcquireHandle: {

0 commit comments

Comments
 (0)