@@ -64,39 +64,125 @@ static bool hasOptionalClassName(const CXXRecordDecl &RD) {
64
64
return false ;
65
65
}
66
66
67
+ static const CXXRecordDecl *getOptionalBaseClass (const CXXRecordDecl *RD) {
68
+ if (RD == nullptr )
69
+ return nullptr ;
70
+ if (hasOptionalClassName (*RD))
71
+ return RD;
72
+
73
+ if (!RD->hasDefinition ())
74
+ return nullptr ;
75
+
76
+ for (const CXXBaseSpecifier &Base : RD->bases ())
77
+ if (const CXXRecordDecl *BaseClass =
78
+ getOptionalBaseClass (Base.getType ()->getAsCXXRecordDecl ()))
79
+ return BaseClass;
80
+
81
+ return nullptr ;
82
+ }
83
+
67
84
namespace {
68
85
69
86
using namespace ::clang::ast_matchers;
70
87
using LatticeTransferState = TransferState<NoopLattice>;
71
88
72
- AST_MATCHER (CXXRecordDecl, hasOptionalClassNameMatcher) {
73
- return hasOptionalClassName (Node);
89
+ AST_MATCHER (CXXRecordDecl, optionalClass) { return hasOptionalClassName (Node); }
90
+
91
+ AST_MATCHER (CXXRecordDecl, optionalOrDerivedClass) {
92
+ return getOptionalBaseClass (&Node) != nullptr ;
74
93
}
75
94
76
- DeclarationMatcher optionalClass () {
77
- return classTemplateSpecializationDecl (
78
- hasOptionalClassNameMatcher (),
79
- hasTemplateArgument (0 , refersToType (type ().bind (" T" ))));
95
+ auto desugarsToOptionalType () {
96
+ return hasUnqualifiedDesugaredType (
97
+ recordType (hasDeclaration (cxxRecordDecl (optionalClass ()))));
80
98
}
81
99
82
- auto optionalOrAliasType () {
100
+ auto desugarsToOptionalOrDerivedType () {
83
101
return hasUnqualifiedDesugaredType (
84
- recordType (hasDeclaration (optionalClass ())));
102
+ recordType (hasDeclaration (cxxRecordDecl (optionalOrDerivedClass ()))));
103
+ }
104
+
105
+ auto hasOptionalType () { return hasType (desugarsToOptionalType ()); }
106
+
107
+ // / Matches any of the spellings of the optional types and sugar, aliases,
108
+ // / derived classes, etc.
109
+ auto hasOptionalOrDerivedType () {
110
+ return hasType (desugarsToOptionalOrDerivedType ());
111
+ }
112
+
113
+ QualType getPublicType (const Expr *E) {
114
+ auto *Cast = dyn_cast<ImplicitCastExpr>(E->IgnoreParens ());
115
+ if (Cast == nullptr || Cast->getCastKind () != CK_UncheckedDerivedToBase) {
116
+ QualType Ty = E->getType ();
117
+ if (Ty->isPointerType ())
118
+ return Ty->getPointeeType ();
119
+ return Ty;
120
+ }
121
+
122
+ // Is the derived type that we're casting from the type of `*this`? In this
123
+ // special case, we can upcast to the base class even if the base is
124
+ // non-public.
125
+ bool CastingFromThis = isa<CXXThisExpr>(Cast->getSubExpr ());
126
+
127
+ // Find the least-derived type in the path (i.e. the last entry in the list)
128
+ // that we can access.
129
+ const CXXBaseSpecifier *PublicBase = nullptr ;
130
+ for (const CXXBaseSpecifier *Base : Cast->path ()) {
131
+ if (Base->getAccessSpecifier () != AS_public && !CastingFromThis)
132
+ break ;
133
+ PublicBase = Base;
134
+ CastingFromThis = false ;
135
+ }
136
+
137
+ if (PublicBase != nullptr )
138
+ return PublicBase->getType ();
139
+
140
+ // We didn't find any public type that we could cast to. There may be more
141
+ // casts in `getSubExpr()`, so recurse. (If there aren't any more casts, this
142
+ // will return the type of `getSubExpr()`.)
143
+ return getPublicType (Cast->getSubExpr ());
85
144
}
86
145
87
- // / Matches any of the spellings of the optional types and sugar, aliases, etc.
88
- auto hasOptionalType () { return hasType (optionalOrAliasType ()); }
146
+ // Returns the least-derived type for the receiver of `MCE` that
147
+ // `MCE.getImplicitObjectArgument()->IgnoreParentImpCasts()` can be downcast to.
148
+ // Effectively, we upcast until we reach a non-public base class, unless that
149
+ // base is a base of `*this`.
150
+ //
151
+ // This is needed to correctly match methods called on types derived from
152
+ // `std::optional`.
153
+ //
154
+ // Say we have a `struct Derived : public std::optional<int> {} d;` For a call
155
+ // `d.has_value()`, the `getImplicitObjectArgument()` looks like this:
156
+ //
157
+ // ImplicitCastExpr 'const std::__optional_storage_base<int>' lvalue
158
+ // | <UncheckedDerivedToBase (optional -> __optional_storage_base)>
159
+ // `-DeclRefExpr 'Derived' lvalue Var 'd' 'Derived'
160
+ //
161
+ // The type of the implicit object argument is `__optional_storage_base`
162
+ // (since this is the internal type that `has_value()` is declared on). If we
163
+ // call `IgnoreParenImpCasts()` on the implicit object argument, we get the
164
+ // `DeclRefExpr`, which has type `Derived`. Neither of these types is
165
+ // `optional`, and hence neither is sufficient for querying whether we are
166
+ // calling a method on `optional`.
167
+ //
168
+ // Instead, starting with the most derived type, we need to follow the chain of
169
+ // casts
170
+ QualType getPublicReceiverType (const CXXMemberCallExpr &MCE) {
171
+ return getPublicType (MCE.getImplicitObjectArgument ());
172
+ }
173
+
174
+ AST_MATCHER_P (CXXMemberCallExpr, publicReceiverType,
175
+ ast_matchers::internal::Matcher<QualType>, InnerMatcher) {
176
+ return InnerMatcher.matches (getPublicReceiverType (Node), Finder, Builder);
177
+ }
89
178
90
179
auto isOptionalMemberCallWithNameMatcher (
91
180
ast_matchers::internal::Matcher<NamedDecl> matcher,
92
181
const std::optional<StatementMatcher> &Ignorable = std::nullopt) {
93
- auto Exception = unless (Ignorable ? expr (anyOf (*Ignorable, cxxThisExpr ()))
94
- : cxxThisExpr ());
95
- return cxxMemberCallExpr (
96
- on (expr (Exception,
97
- anyOf (hasOptionalType (),
98
- hasType (pointerType (pointee (optionalOrAliasType ())))))),
99
- callee (cxxMethodDecl (matcher)));
182
+ return cxxMemberCallExpr (Ignorable ? on (expr (unless (*Ignorable)))
183
+ : anything (),
184
+ publicReceiverType (desugarsToOptionalType ()),
185
+ callee (cxxMethodDecl (matcher)));
100
186
}
101
187
102
188
auto isOptionalOperatorCallWithName (
@@ -129,49 +215,51 @@ auto inPlaceClass() {
129
215
130
216
auto isOptionalNulloptConstructor () {
131
217
return cxxConstructExpr (
132
- hasOptionalType (),
133
218
hasDeclaration (cxxConstructorDecl (parameterCountIs (1 ),
134
- hasParameter (0 , hasNulloptType ()))));
219
+ hasParameter (0 , hasNulloptType ()))),
220
+ hasOptionalOrDerivedType ());
135
221
}
136
222
137
223
auto isOptionalInPlaceConstructor () {
138
- return cxxConstructExpr (hasOptionalType ( ),
139
- hasArgument ( 0 , hasType ( inPlaceClass ()) ));
224
+ return cxxConstructExpr (hasArgument ( 0 , hasType ( inPlaceClass ()) ),
225
+ hasOptionalOrDerivedType ( ));
140
226
}
141
227
142
228
auto isOptionalValueOrConversionConstructor () {
143
229
return cxxConstructExpr (
144
- hasOptionalType (),
145
230
unless (hasDeclaration (
146
231
cxxConstructorDecl (anyOf (isCopyConstructor (), isMoveConstructor ())))),
147
- argumentCountIs (1 ), hasArgument (0 , unless (hasNulloptType ())));
232
+ argumentCountIs (1 ), hasArgument (0 , unless (hasNulloptType ())),
233
+ hasOptionalOrDerivedType ());
148
234
}
149
235
150
236
auto isOptionalValueOrConversionAssignment () {
151
237
return cxxOperatorCallExpr (
152
238
hasOverloadedOperatorName (" =" ),
153
- callee (cxxMethodDecl (ofClass (optionalClass ()))),
239
+ callee (cxxMethodDecl (ofClass (optionalOrDerivedClass ()))),
154
240
unless (hasDeclaration (cxxMethodDecl (
155
241
anyOf (isCopyAssignmentOperator (), isMoveAssignmentOperator ())))),
156
242
argumentCountIs (2 ), hasArgument (1 , unless (hasNulloptType ())));
157
243
}
158
244
159
245
auto isOptionalNulloptAssignment () {
160
- return cxxOperatorCallExpr (hasOverloadedOperatorName ( " = " ),
161
- callee ( cxxMethodDecl ( ofClass ( optionalClass ())) ),
162
- argumentCountIs ( 2 ),
163
- hasArgument (1 , hasNulloptType ()));
246
+ return cxxOperatorCallExpr (
247
+ hasOverloadedOperatorName ( " = " ),
248
+ callee ( cxxMethodDecl ( ofClass ( optionalOrDerivedClass ())) ),
249
+ argumentCountIs ( 2 ), hasArgument (1 , hasNulloptType ()));
164
250
}
165
251
166
252
auto isStdSwapCall () {
167
253
return callExpr (callee (functionDecl (hasName (" std::swap" ))),
168
- argumentCountIs (2 ), hasArgument (0 , hasOptionalType ()),
169
- hasArgument (1 , hasOptionalType ()));
254
+ argumentCountIs (2 ),
255
+ hasArgument (0 , hasOptionalOrDerivedType ()),
256
+ hasArgument (1 , hasOptionalOrDerivedType ()));
170
257
}
171
258
172
259
auto isStdForwardCall () {
173
260
return callExpr (callee (functionDecl (hasName (" std::forward" ))),
174
- argumentCountIs (1 ), hasArgument (0 , hasOptionalType ()));
261
+ argumentCountIs (1 ),
262
+ hasArgument (0 , hasOptionalOrDerivedType ()));
175
263
}
176
264
177
265
constexpr llvm::StringLiteral ValueOrCallID = " ValueOrCall" ;
@@ -212,8 +300,9 @@ auto isValueOrNotEqX() {
212
300
}
213
301
214
302
auto isCallReturningOptional () {
215
- return callExpr (hasType (qualType (anyOf (
216
- optionalOrAliasType (), referenceType (pointee (optionalOrAliasType ()))))));
303
+ return callExpr (hasType (qualType (
304
+ anyOf (desugarsToOptionalOrDerivedType (),
305
+ referenceType (pointee (desugarsToOptionalOrDerivedType ()))))));
217
306
}
218
307
219
308
template <typename L, typename R>
@@ -275,28 +364,23 @@ BoolValue *getHasValue(Environment &Env, RecordStorageLocation *OptionalLoc) {
275
364
return HasValueVal;
276
365
}
277
366
278
- // / Returns true if and only if `Type` is an optional type.
279
- bool isOptionalType (QualType Type) {
280
- if (!Type->isRecordType ())
281
- return false ;
282
- const CXXRecordDecl *D = Type->getAsCXXRecordDecl ();
283
- return D != nullptr && hasOptionalClassName (*D);
367
+ QualType valueTypeFromOptionalDecl (const CXXRecordDecl &RD) {
368
+ auto &CTSD = cast<ClassTemplateSpecializationDecl>(RD);
369
+ return CTSD.getTemplateArgs ()[0 ].getAsType ();
284
370
}
285
371
286
372
// / Returns the number of optional wrappers in `Type`.
287
373
// /
288
374
// / For example, if `Type` is `optional<optional<int>>`, the result of this
289
375
// / function will be 2.
290
376
int countOptionalWrappers (const ASTContext &ASTCtx, QualType Type) {
291
- if (!isOptionalType (Type))
377
+ const CXXRecordDecl *Optional =
378
+ getOptionalBaseClass (Type->getAsCXXRecordDecl ());
379
+ if (Optional == nullptr )
292
380
return 0 ;
293
381
return 1 + countOptionalWrappers (
294
382
ASTCtx,
295
- cast<ClassTemplateSpecializationDecl>(Type->getAsRecordDecl ())
296
- ->getTemplateArgs ()
297
- .get (0 )
298
- .getAsType ()
299
- .getDesugaredType (ASTCtx));
383
+ valueTypeFromOptionalDecl (*Optional).getDesugaredType (ASTCtx));
300
384
}
301
385
302
386
StorageLocation *getLocBehindPossiblePointer (const Expr &E,
@@ -843,13 +927,7 @@ auto buildDiagnoseMatchSwitch(
843
927
844
928
ast_matchers::DeclarationMatcher
845
929
UncheckedOptionalAccessModel::optionalClassDecl () {
846
- return optionalClass ();
847
- }
848
-
849
- static QualType valueTypeFromOptionalType (QualType OptionalTy) {
850
- auto *CTSD =
851
- cast<ClassTemplateSpecializationDecl>(OptionalTy->getAsCXXRecordDecl ());
852
- return CTSD->getTemplateArgs ()[0 ].getAsType ();
930
+ return cxxRecordDecl (optionalClass ());
853
931
}
854
932
855
933
UncheckedOptionalAccessModel::UncheckedOptionalAccessModel (ASTContext &Ctx,
@@ -858,9 +936,11 @@ UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
858
936
TransferMatchSwitch (buildTransferMatchSwitch()) {
859
937
Env.getDataflowAnalysisContext ().setSyntheticFieldCallback (
860
938
[&Ctx](QualType Ty) -> llvm::StringMap<QualType> {
861
- if (!isOptionalType (Ty))
939
+ const CXXRecordDecl *Optional =
940
+ getOptionalBaseClass (Ty->getAsCXXRecordDecl ());
941
+ if (Optional == nullptr )
862
942
return {};
863
- return {{" value" , valueTypeFromOptionalType (Ty )},
943
+ return {{" value" , valueTypeFromOptionalDecl (*Optional )},
864
944
{" has_value" , Ctx.BoolTy }};
865
945
});
866
946
}
0 commit comments