Skip to content

Commit f5cb9cb

Browse files
[ASTMatchers] Fix classIsDerivedFrom for recusrive cases (#67307)
By ensuring the base is only visited once. This avoids infinite recursion and expontential running times in some corner cases. See the added tests for examples. Apart from the cases that caused infinite recursion and used to crash, this change is an NFC and results of the matchers are the same.
1 parent 3e97db8 commit f5cb9cb

File tree

2 files changed

+97
-1
lines changed

2 files changed

+97
-1
lines changed

clang/lib/ASTMatchers/ASTMatchFinder.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
#include "clang/ASTMatchers/ASTMatchFinder.h"
1919
#include "clang/AST/ASTConsumer.h"
2020
#include "clang/AST/ASTContext.h"
21+
#include "clang/AST/DeclCXX.h"
2122
#include "clang/AST/RecursiveASTVisitor.h"
2223
#include "llvm/ADT/DenseMap.h"
24+
#include "llvm/ADT/SmallPtrSet.h"
2325
#include "llvm/ADT/StringMap.h"
2426
#include "llvm/Support/PrettyStackTrace.h"
2527
#include "llvm/Support/Timer.h"
@@ -651,11 +653,20 @@ class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
651653
BoundNodesTreeBuilder *Builder,
652654
bool Directly) override;
653655

656+
private:
657+
bool
658+
classIsDerivedFromImpl(const CXXRecordDecl *Declaration,
659+
const Matcher<NamedDecl> &Base,
660+
BoundNodesTreeBuilder *Builder, bool Directly,
661+
llvm::SmallPtrSetImpl<const CXXRecordDecl *> &Visited);
662+
663+
public:
654664
bool objcClassIsDerivedFrom(const ObjCInterfaceDecl *Declaration,
655665
const Matcher<NamedDecl> &Base,
656666
BoundNodesTreeBuilder *Builder,
657667
bool Directly) override;
658668

669+
public:
659670
// Implements ASTMatchFinder::matchesChildOf.
660671
bool matchesChildOf(const DynTypedNode &Node, ASTContext &Ctx,
661672
const DynTypedMatcher &Matcher,
@@ -1361,8 +1372,18 @@ bool MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration,
13611372
const Matcher<NamedDecl> &Base,
13621373
BoundNodesTreeBuilder *Builder,
13631374
bool Directly) {
1375+
llvm::SmallPtrSet<const CXXRecordDecl *, 8> Visited;
1376+
return classIsDerivedFromImpl(Declaration, Base, Builder, Directly, Visited);
1377+
}
1378+
1379+
bool MatchASTVisitor::classIsDerivedFromImpl(
1380+
const CXXRecordDecl *Declaration, const Matcher<NamedDecl> &Base,
1381+
BoundNodesTreeBuilder *Builder, bool Directly,
1382+
llvm::SmallPtrSetImpl<const CXXRecordDecl *> &Visited) {
13641383
if (!Declaration->hasDefinition())
13651384
return false;
1385+
if (!Visited.insert(Declaration).second)
1386+
return false;
13661387
for (const auto &It : Declaration->bases()) {
13671388
const Type *TypeNode = It.getType().getTypePtr();
13681389

@@ -1384,7 +1405,8 @@ bool MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration,
13841405
*Builder = std::move(Result);
13851406
return true;
13861407
}
1387-
if (!Directly && classIsDerivedFrom(ClassDecl, Base, Builder, Directly))
1408+
if (!Directly &&
1409+
classIsDerivedFromImpl(ClassDecl, Base, Builder, Directly, Visited))
13881410
return true;
13891411
}
13901412
return false;

clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2369,6 +2369,80 @@ TEST_P(ASTMatchersTest, LambdaCaptureTest_BindsToCaptureOfReferenceType) {
23692369
"}", matcher));
23702370
}
23712371

2372+
TEST_P(ASTMatchersTest, IsDerivedFromRecursion) {
2373+
if (!GetParam().isCXX11OrLater())
2374+
return;
2375+
2376+
// Check we don't crash on cycles in the traversal and inheritance hierarchy.
2377+
// Clang will normally enforce there are no cycles, but matchers opted to
2378+
// traverse primary template for dependent specializations, spuriously
2379+
// creating the cycles.
2380+
DeclarationMatcher matcher = cxxRecordDecl(isDerivedFrom("X"));
2381+
EXPECT_TRUE(notMatches(R"cpp(
2382+
template <typename T1, typename T2>
2383+
struct M;
2384+
2385+
template <typename T1>
2386+
struct M<T1, void> {};
2387+
2388+
template <typename T1, typename T2>
2389+
struct L : M<T1, T2> {};
2390+
2391+
template <typename T1, typename T2>
2392+
struct M : L<M<T1, T2>, M<T1, T2>> {};
2393+
)cpp",
2394+
matcher));
2395+
2396+
// Check the running time is not exponential. The number of subojects to
2397+
// traverse grows as fibonacci numbers even though the number of bases to
2398+
// traverse is quadratic.
2399+
// The test will hang if implementation of matchers traverses all subojects.
2400+
EXPECT_TRUE(notMatches(R"cpp(
2401+
template <class T> struct A0 {};
2402+
template <class T> struct A1 : A0<T> {};
2403+
template <class T> struct A2 : A1<T>, A0<T> {};
2404+
template <class T> struct A3 : A2<T>, A1<T> {};
2405+
template <class T> struct A4 : A3<T>, A2<T> {};
2406+
template <class T> struct A5 : A4<T>, A3<T> {};
2407+
template <class T> struct A6 : A5<T>, A4<T> {};
2408+
template <class T> struct A7 : A6<T>, A5<T> {};
2409+
template <class T> struct A8 : A7<T>, A6<T> {};
2410+
template <class T> struct A9 : A8<T>, A7<T> {};
2411+
template <class T> struct A10 : A9<T>, A8<T> {};
2412+
template <class T> struct A11 : A10<T>, A9<T> {};
2413+
template <class T> struct A12 : A11<T>, A10<T> {};
2414+
template <class T> struct A13 : A12<T>, A11<T> {};
2415+
template <class T> struct A14 : A13<T>, A12<T> {};
2416+
template <class T> struct A15 : A14<T>, A13<T> {};
2417+
template <class T> struct A16 : A15<T>, A14<T> {};
2418+
template <class T> struct A17 : A16<T>, A15<T> {};
2419+
template <class T> struct A18 : A17<T>, A16<T> {};
2420+
template <class T> struct A19 : A18<T>, A17<T> {};
2421+
template <class T> struct A20 : A19<T>, A18<T> {};
2422+
template <class T> struct A21 : A20<T>, A19<T> {};
2423+
template <class T> struct A22 : A21<T>, A20<T> {};
2424+
template <class T> struct A23 : A22<T>, A21<T> {};
2425+
template <class T> struct A24 : A23<T>, A22<T> {};
2426+
template <class T> struct A25 : A24<T>, A23<T> {};
2427+
template <class T> struct A26 : A25<T>, A24<T> {};
2428+
template <class T> struct A27 : A26<T>, A25<T> {};
2429+
template <class T> struct A28 : A27<T>, A26<T> {};
2430+
template <class T> struct A29 : A28<T>, A27<T> {};
2431+
template <class T> struct A30 : A29<T>, A28<T> {};
2432+
template <class T> struct A31 : A30<T>, A29<T> {};
2433+
template <class T> struct A32 : A31<T>, A30<T> {};
2434+
template <class T> struct A33 : A32<T>, A31<T> {};
2435+
template <class T> struct A34 : A33<T>, A32<T> {};
2436+
template <class T> struct A35 : A34<T>, A33<T> {};
2437+
template <class T> struct A36 : A35<T>, A34<T> {};
2438+
template <class T> struct A37 : A36<T>, A35<T> {};
2439+
template <class T> struct A38 : A37<T>, A36<T> {};
2440+
template <class T> struct A39 : A38<T>, A37<T> {};
2441+
template <class T> struct A40 : A39<T>, A38<T> {};
2442+
)cpp",
2443+
matcher));
2444+
}
2445+
23722446
TEST(ASTMatchersTestObjC, ObjCMessageCalees) {
23732447
StatementMatcher MessagingFoo =
23742448
objcMessageExpr(callee(objcMethodDecl(hasName("foo"))));

0 commit comments

Comments
 (0)