Skip to content

Commit cec17d7

Browse files
martinboehmechencha3
authored andcommitted
[clang][dataflow] Make optional checker work for types derived from optional. (llvm#84138)
`llvm::MaybeAlign` does this, for example. It's not an option to simply ignore these derived classes because they get cast back to the optional classes (for example, simply when calling the optional member functions), and our transfer functions will then run on those optional classes and therefore require them to be properly initialized.
1 parent 51c3f66 commit cec17d7

File tree

2 files changed

+194
-54
lines changed

2 files changed

+194
-54
lines changed

clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp

Lines changed: 134 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -64,39 +64,125 @@ static bool hasOptionalClassName(const CXXRecordDecl &RD) {
6464
return false;
6565
}
6666

67+
static const CXXRecordDecl *getOptionalBaseClass(const CXXRecordDecl *RD) {
68+
if (RD == nullptr)
69+
return nullptr;
70+
if (hasOptionalClassName(*RD))
71+
return RD;
72+
73+
if (!RD->hasDefinition())
74+
return nullptr;
75+
76+
for (const CXXBaseSpecifier &Base : RD->bases())
77+
if (const CXXRecordDecl *BaseClass =
78+
getOptionalBaseClass(Base.getType()->getAsCXXRecordDecl()))
79+
return BaseClass;
80+
81+
return nullptr;
82+
}
83+
6784
namespace {
6885

6986
using namespace ::clang::ast_matchers;
7087
using LatticeTransferState = TransferState<NoopLattice>;
7188

72-
AST_MATCHER(CXXRecordDecl, hasOptionalClassNameMatcher) {
73-
return hasOptionalClassName(Node);
89+
AST_MATCHER(CXXRecordDecl, optionalClass) { return hasOptionalClassName(Node); }
90+
91+
AST_MATCHER(CXXRecordDecl, optionalOrDerivedClass) {
92+
return getOptionalBaseClass(&Node) != nullptr;
7493
}
7594

76-
DeclarationMatcher optionalClass() {
77-
return classTemplateSpecializationDecl(
78-
hasOptionalClassNameMatcher(),
79-
hasTemplateArgument(0, refersToType(type().bind("T"))));
95+
auto desugarsToOptionalType() {
96+
return hasUnqualifiedDesugaredType(
97+
recordType(hasDeclaration(cxxRecordDecl(optionalClass()))));
8098
}
8199

82-
auto optionalOrAliasType() {
100+
auto desugarsToOptionalOrDerivedType() {
83101
return hasUnqualifiedDesugaredType(
84-
recordType(hasDeclaration(optionalClass())));
102+
recordType(hasDeclaration(cxxRecordDecl(optionalOrDerivedClass()))));
103+
}
104+
105+
auto hasOptionalType() { return hasType(desugarsToOptionalType()); }
106+
107+
/// Matches any of the spellings of the optional types and sugar, aliases,
108+
/// derived classes, etc.
109+
auto hasOptionalOrDerivedType() {
110+
return hasType(desugarsToOptionalOrDerivedType());
111+
}
112+
113+
QualType getPublicType(const Expr *E) {
114+
auto *Cast = dyn_cast<ImplicitCastExpr>(E->IgnoreParens());
115+
if (Cast == nullptr || Cast->getCastKind() != CK_UncheckedDerivedToBase) {
116+
QualType Ty = E->getType();
117+
if (Ty->isPointerType())
118+
return Ty->getPointeeType();
119+
return Ty;
120+
}
121+
122+
// Is the derived type that we're casting from the type of `*this`? In this
123+
// special case, we can upcast to the base class even if the base is
124+
// non-public.
125+
bool CastingFromThis = isa<CXXThisExpr>(Cast->getSubExpr());
126+
127+
// Find the least-derived type in the path (i.e. the last entry in the list)
128+
// that we can access.
129+
const CXXBaseSpecifier *PublicBase = nullptr;
130+
for (const CXXBaseSpecifier *Base : Cast->path()) {
131+
if (Base->getAccessSpecifier() != AS_public && !CastingFromThis)
132+
break;
133+
PublicBase = Base;
134+
CastingFromThis = false;
135+
}
136+
137+
if (PublicBase != nullptr)
138+
return PublicBase->getType();
139+
140+
// We didn't find any public type that we could cast to. There may be more
141+
// casts in `getSubExpr()`, so recurse. (If there aren't any more casts, this
142+
// will return the type of `getSubExpr()`.)
143+
return getPublicType(Cast->getSubExpr());
85144
}
86145

87-
/// Matches any of the spellings of the optional types and sugar, aliases, etc.
88-
auto hasOptionalType() { return hasType(optionalOrAliasType()); }
146+
// Returns the least-derived type for the receiver of `MCE` that
147+
// `MCE.getImplicitObjectArgument()->IgnoreParentImpCasts()` can be downcast to.
148+
// Effectively, we upcast until we reach a non-public base class, unless that
149+
// base is a base of `*this`.
150+
//
151+
// This is needed to correctly match methods called on types derived from
152+
// `std::optional`.
153+
//
154+
// Say we have a `struct Derived : public std::optional<int> {} d;` For a call
155+
// `d.has_value()`, the `getImplicitObjectArgument()` looks like this:
156+
//
157+
// ImplicitCastExpr 'const std::__optional_storage_base<int>' lvalue
158+
// | <UncheckedDerivedToBase (optional -> __optional_storage_base)>
159+
// `-DeclRefExpr 'Derived' lvalue Var 'd' 'Derived'
160+
//
161+
// The type of the implicit object argument is `__optional_storage_base`
162+
// (since this is the internal type that `has_value()` is declared on). If we
163+
// call `IgnoreParenImpCasts()` on the implicit object argument, we get the
164+
// `DeclRefExpr`, which has type `Derived`. Neither of these types is
165+
// `optional`, and hence neither is sufficient for querying whether we are
166+
// calling a method on `optional`.
167+
//
168+
// Instead, starting with the most derived type, we need to follow the chain of
169+
// casts
170+
QualType getPublicReceiverType(const CXXMemberCallExpr &MCE) {
171+
return getPublicType(MCE.getImplicitObjectArgument());
172+
}
173+
174+
AST_MATCHER_P(CXXMemberCallExpr, publicReceiverType,
175+
ast_matchers::internal::Matcher<QualType>, InnerMatcher) {
176+
return InnerMatcher.matches(getPublicReceiverType(Node), Finder, Builder);
177+
}
89178

90179
auto isOptionalMemberCallWithNameMatcher(
91180
ast_matchers::internal::Matcher<NamedDecl> matcher,
92181
const std::optional<StatementMatcher> &Ignorable = std::nullopt) {
93-
auto Exception = unless(Ignorable ? expr(anyOf(*Ignorable, cxxThisExpr()))
94-
: cxxThisExpr());
95-
return cxxMemberCallExpr(
96-
on(expr(Exception,
97-
anyOf(hasOptionalType(),
98-
hasType(pointerType(pointee(optionalOrAliasType())))))),
99-
callee(cxxMethodDecl(matcher)));
182+
return cxxMemberCallExpr(Ignorable ? on(expr(unless(*Ignorable)))
183+
: anything(),
184+
publicReceiverType(desugarsToOptionalType()),
185+
callee(cxxMethodDecl(matcher)));
100186
}
101187

102188
auto isOptionalOperatorCallWithName(
@@ -129,49 +215,51 @@ auto inPlaceClass() {
129215

130216
auto isOptionalNulloptConstructor() {
131217
return cxxConstructExpr(
132-
hasOptionalType(),
133218
hasDeclaration(cxxConstructorDecl(parameterCountIs(1),
134-
hasParameter(0, hasNulloptType()))));
219+
hasParameter(0, hasNulloptType()))),
220+
hasOptionalOrDerivedType());
135221
}
136222

137223
auto isOptionalInPlaceConstructor() {
138-
return cxxConstructExpr(hasOptionalType(),
139-
hasArgument(0, hasType(inPlaceClass())));
224+
return cxxConstructExpr(hasArgument(0, hasType(inPlaceClass())),
225+
hasOptionalOrDerivedType());
140226
}
141227

142228
auto isOptionalValueOrConversionConstructor() {
143229
return cxxConstructExpr(
144-
hasOptionalType(),
145230
unless(hasDeclaration(
146231
cxxConstructorDecl(anyOf(isCopyConstructor(), isMoveConstructor())))),
147-
argumentCountIs(1), hasArgument(0, unless(hasNulloptType())));
232+
argumentCountIs(1), hasArgument(0, unless(hasNulloptType())),
233+
hasOptionalOrDerivedType());
148234
}
149235

150236
auto isOptionalValueOrConversionAssignment() {
151237
return cxxOperatorCallExpr(
152238
hasOverloadedOperatorName("="),
153-
callee(cxxMethodDecl(ofClass(optionalClass()))),
239+
callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))),
154240
unless(hasDeclaration(cxxMethodDecl(
155241
anyOf(isCopyAssignmentOperator(), isMoveAssignmentOperator())))),
156242
argumentCountIs(2), hasArgument(1, unless(hasNulloptType())));
157243
}
158244

159245
auto isOptionalNulloptAssignment() {
160-
return cxxOperatorCallExpr(hasOverloadedOperatorName("="),
161-
callee(cxxMethodDecl(ofClass(optionalClass()))),
162-
argumentCountIs(2),
163-
hasArgument(1, hasNulloptType()));
246+
return cxxOperatorCallExpr(
247+
hasOverloadedOperatorName("="),
248+
callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))),
249+
argumentCountIs(2), hasArgument(1, hasNulloptType()));
164250
}
165251

166252
auto isStdSwapCall() {
167253
return callExpr(callee(functionDecl(hasName("std::swap"))),
168-
argumentCountIs(2), hasArgument(0, hasOptionalType()),
169-
hasArgument(1, hasOptionalType()));
254+
argumentCountIs(2),
255+
hasArgument(0, hasOptionalOrDerivedType()),
256+
hasArgument(1, hasOptionalOrDerivedType()));
170257
}
171258

172259
auto isStdForwardCall() {
173260
return callExpr(callee(functionDecl(hasName("std::forward"))),
174-
argumentCountIs(1), hasArgument(0, hasOptionalType()));
261+
argumentCountIs(1),
262+
hasArgument(0, hasOptionalOrDerivedType()));
175263
}
176264

177265
constexpr llvm::StringLiteral ValueOrCallID = "ValueOrCall";
@@ -212,8 +300,9 @@ auto isValueOrNotEqX() {
212300
}
213301

214302
auto isCallReturningOptional() {
215-
return callExpr(hasType(qualType(anyOf(
216-
optionalOrAliasType(), referenceType(pointee(optionalOrAliasType()))))));
303+
return callExpr(hasType(qualType(
304+
anyOf(desugarsToOptionalOrDerivedType(),
305+
referenceType(pointee(desugarsToOptionalOrDerivedType()))))));
217306
}
218307

219308
template <typename L, typename R>
@@ -275,28 +364,23 @@ BoolValue *getHasValue(Environment &Env, RecordStorageLocation *OptionalLoc) {
275364
return HasValueVal;
276365
}
277366

278-
/// Returns true if and only if `Type` is an optional type.
279-
bool isOptionalType(QualType Type) {
280-
if (!Type->isRecordType())
281-
return false;
282-
const CXXRecordDecl *D = Type->getAsCXXRecordDecl();
283-
return D != nullptr && hasOptionalClassName(*D);
367+
QualType valueTypeFromOptionalDecl(const CXXRecordDecl &RD) {
368+
auto &CTSD = cast<ClassTemplateSpecializationDecl>(RD);
369+
return CTSD.getTemplateArgs()[0].getAsType();
284370
}
285371

286372
/// Returns the number of optional wrappers in `Type`.
287373
///
288374
/// For example, if `Type` is `optional<optional<int>>`, the result of this
289375
/// function will be 2.
290376
int countOptionalWrappers(const ASTContext &ASTCtx, QualType Type) {
291-
if (!isOptionalType(Type))
377+
const CXXRecordDecl *Optional =
378+
getOptionalBaseClass(Type->getAsCXXRecordDecl());
379+
if (Optional == nullptr)
292380
return 0;
293381
return 1 + countOptionalWrappers(
294382
ASTCtx,
295-
cast<ClassTemplateSpecializationDecl>(Type->getAsRecordDecl())
296-
->getTemplateArgs()
297-
.get(0)
298-
.getAsType()
299-
.getDesugaredType(ASTCtx));
383+
valueTypeFromOptionalDecl(*Optional).getDesugaredType(ASTCtx));
300384
}
301385

302386
StorageLocation *getLocBehindPossiblePointer(const Expr &E,
@@ -843,13 +927,7 @@ auto buildDiagnoseMatchSwitch(
843927

844928
ast_matchers::DeclarationMatcher
845929
UncheckedOptionalAccessModel::optionalClassDecl() {
846-
return optionalClass();
847-
}
848-
849-
static QualType valueTypeFromOptionalType(QualType OptionalTy) {
850-
auto *CTSD =
851-
cast<ClassTemplateSpecializationDecl>(OptionalTy->getAsCXXRecordDecl());
852-
return CTSD->getTemplateArgs()[0].getAsType();
930+
return cxxRecordDecl(optionalClass());
853931
}
854932

855933
UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
@@ -858,9 +936,11 @@ UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
858936
TransferMatchSwitch(buildTransferMatchSwitch()) {
859937
Env.getDataflowAnalysisContext().setSyntheticFieldCallback(
860938
[&Ctx](QualType Ty) -> llvm::StringMap<QualType> {
861-
if (!isOptionalType(Ty))
939+
const CXXRecordDecl *Optional =
940+
getOptionalBaseClass(Ty->getAsCXXRecordDecl());
941+
if (Optional == nullptr)
862942
return {};
863-
return {{"value", valueTypeFromOptionalType(Ty)},
943+
return {{"value", valueTypeFromOptionalDecl(*Optional)},
864944
{"has_value", Ctx.BoolTy}};
865945
});
866946
}

clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3383,6 +3383,66 @@ TEST_P(UncheckedOptionalAccessTest, LambdaCaptureStateNotPropagated) {
33833383
}
33843384
)");
33853385
}
3386+
3387+
TEST_P(UncheckedOptionalAccessTest, ClassDerivedFromOptional) {
3388+
ExpectDiagnosticsFor(R"(
3389+
#include "unchecked_optional_access_test.h"
3390+
3391+
struct Derived : public $ns::$optional<int> {};
3392+
3393+
void target(Derived opt) {
3394+
*opt; // [[unsafe]]
3395+
if (opt.has_value())
3396+
*opt;
3397+
3398+
// The same thing, but with a pointer receiver.
3399+
Derived *popt = &opt;
3400+
**popt; // [[unsafe]]
3401+
if (popt->has_value())
3402+
**popt;
3403+
}
3404+
)");
3405+
}
3406+
3407+
TEST_P(UncheckedOptionalAccessTest, ClassTemplateDerivedFromOptional) {
3408+
ExpectDiagnosticsFor(R"(
3409+
#include "unchecked_optional_access_test.h"
3410+
3411+
template <class T>
3412+
struct Derived : public $ns::$optional<T> {};
3413+
3414+
void target(Derived<int> opt) {
3415+
*opt; // [[unsafe]]
3416+
if (opt.has_value())
3417+
*opt;
3418+
3419+
// The same thing, but with a pointer receiver.
3420+
Derived<int> *popt = &opt;
3421+
**popt; // [[unsafe]]
3422+
if (popt->has_value())
3423+
**popt;
3424+
}
3425+
)");
3426+
}
3427+
3428+
TEST_P(UncheckedOptionalAccessTest, ClassDerivedPrivatelyFromOptional) {
3429+
// Classes that derive privately from optional can themselves still call
3430+
// member functions of optional. Check that we model the optional correctly
3431+
// in this situation.
3432+
ExpectDiagnosticsFor(R"(
3433+
#include "unchecked_optional_access_test.h"
3434+
3435+
struct Derived : private $ns::$optional<int> {
3436+
void Method() {
3437+
**this; // [[unsafe]]
3438+
if (this->has_value())
3439+
**this;
3440+
}
3441+
};
3442+
)",
3443+
ast_matchers::hasName("Method"));
3444+
}
3445+
33863446
// FIXME: Add support for:
33873447
// - constructors (copy, move)
33883448
// - assignment operators (default, copy, move)

0 commit comments

Comments
 (0)