Skip to content

[NFC] [ASTMatchers] Share code of forEachArgumentWithParamType with UnsafeBufferUsage #132387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 17 additions & 63 deletions clang/include/clang/ASTMatchers/ASTMatchers.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
#include "clang/AST/TypeLoc.h"
#include "clang/ASTMatchers/ASTMatchersInternal.h"
#include "clang/ASTMatchers/ASTMatchersMacros.h"
#include "clang/ASTMatchers/LowLevelHelpers.h"
#include "clang/Basic/AttrKinds.h"
#include "clang/Basic/ExceptionSpecificationType.h"
#include "clang/Basic/FileManager.h"
Expand Down Expand Up @@ -5211,72 +5212,25 @@ AST_POLYMORPHIC_MATCHER_P2(forEachArgumentWithParamType,
internal::Matcher<Expr>, ArgMatcher,
internal::Matcher<QualType>, ParamMatcher) {
BoundNodesTreeBuilder Result;
// The first argument of an overloaded member operator is the implicit object
// argument of the method which should not be matched against a parameter, so
// we skip over it here.
BoundNodesTreeBuilder Matches;
unsigned ArgIndex =
cxxOperatorCallExpr(
callee(cxxMethodDecl(unless(isExplicitObjectMemberFunction()))))
.matches(Node, Finder, &Matches)
? 1
: 0;
const FunctionProtoType *FProto = nullptr;

if (const auto *Call = dyn_cast<CallExpr>(&Node)) {
if (const auto *Value =
dyn_cast_or_null<ValueDecl>(Call->getCalleeDecl())) {
QualType QT = Value->getType().getCanonicalType();

// This does not necessarily lead to a `FunctionProtoType`,
// e.g. K&R functions do not have a function prototype.
if (QT->isFunctionPointerType())
FProto = QT->getPointeeType()->getAs<FunctionProtoType>();

if (QT->isMemberFunctionPointerType()) {
const auto *MP = QT->getAs<MemberPointerType>();
assert(MP && "Must be member-pointer if its a memberfunctionpointer");
FProto = MP->getPointeeType()->getAs<FunctionProtoType>();
assert(FProto &&
"The call must have happened through a member function "
"pointer");
}
}
}

unsigned ParamIndex = 0;
bool Matched = false;
unsigned NumArgs = Node.getNumArgs();
if (FProto && FProto->isVariadic())
NumArgs = std::min(NumArgs, FProto->getNumParams());

for (; ArgIndex < NumArgs; ++ArgIndex, ++ParamIndex) {
auto ProcessParamAndArg = [&](QualType ParamType, const Expr *Arg) {
BoundNodesTreeBuilder ArgMatches(*Builder);
if (ArgMatcher.matches(*(Node.getArg(ArgIndex)->IgnoreParenCasts()), Finder,
&ArgMatches)) {
BoundNodesTreeBuilder ParamMatches(ArgMatches);
if (!ArgMatcher.matches(*Arg, Finder, &ArgMatches))
return;
BoundNodesTreeBuilder ParamMatches(std::move(ArgMatches));
if (!ParamMatcher.matches(ParamType, Finder, &ParamMatches))
return;
Result.addMatch(ParamMatches);
Matched = true;
return;
};
if (auto *Call = llvm::dyn_cast<CallExpr>(&Node))
matchEachArgumentWithParamType(*Call, ProcessParamAndArg);
else if (auto *Construct = llvm::dyn_cast<CXXConstructExpr>(&Node))
matchEachArgumentWithParamType(*Construct, ProcessParamAndArg);
else
llvm_unreachable("expected CallExpr or CXXConstructExpr");

// This test is cheaper compared to the big matcher in the next if.
// Therefore, please keep this order.
if (FProto && FProto->getNumParams() > ParamIndex) {
QualType ParamType = FProto->getParamType(ParamIndex);
if (ParamMatcher.matches(ParamType, Finder, &ParamMatches)) {
Result.addMatch(ParamMatches);
Matched = true;
continue;
}
}
if (expr(anyOf(cxxConstructExpr(hasDeclaration(cxxConstructorDecl(
hasParameter(ParamIndex, hasType(ParamMatcher))))),
callExpr(callee(functionDecl(
hasParameter(ParamIndex, hasType(ParamMatcher)))))))
.matches(Node, Finder, &ParamMatches)) {
Result.addMatch(ParamMatches);
Matched = true;
continue;
}
}
}
*Builder = std::move(Result);
return Matched;
}
Expand Down
37 changes: 37 additions & 0 deletions clang/include/clang/ASTMatchers/LowLevelHelpers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===- LowLevelHelpers.h - helpers with pure AST interface ---- *- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Collects a number of helpers that are used by matchers, but can be reused
// outside of them, e.g. when corresponding matchers cannot be used due to
// performance constraints.
//===----------------------------------------------------------------------===//

#ifndef LLVM_CLANG_ASTMATCHERS_LOWLEVELHELPERS_H
#define LLVM_CLANG_ASTMATCHERS_LOWLEVELHELPERS_H

#include "clang/AST/Expr.h"
#include "clang/AST/ExprCXX.h"
#include "clang/AST/Type.h"
#include "llvm/ADT/STLFunctionalExtras.h"

namespace clang {
namespace ast_matchers {

void matchEachArgumentWithParamType(
const CallExpr &Node,
llvm::function_ref<void(QualType /*Param*/, const Expr * /*Arg*/)>
OnParamAndArg);

void matchEachArgumentWithParamType(
const CXXConstructExpr &Node,
llvm::function_ref<void(QualType /*Param*/, const Expr * /*Arg*/)>
OnParamAndArg);

} // namespace ast_matchers
} // namespace clang

#endif
1 change: 1 addition & 0 deletions clang/lib/ASTMatchers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_clang_library(clangASTMatchers
ASTMatchFinder.cpp
ASTMatchersInternal.cpp
GtestMatchers.cpp
LowLevelHelpers.cpp

LINK_LIBS
clangAST
Expand Down
106 changes: 106 additions & 0 deletions clang/lib/ASTMatchers/LowLevelHelpers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
//===- LowLevelHelpers.cpp -------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "clang/ASTMatchers/LowLevelHelpers.h"
#include "clang/AST/Decl.h"
#include "clang/AST/DeclCXX.h"
#include "clang/AST/Expr.h"
#include "clang/AST/ExprCXX.h"
#include <type_traits>

namespace clang {
namespace ast_matchers {

static const FunctionDecl *getCallee(const CXXConstructExpr &D) {
return D.getConstructor();
}
static const FunctionDecl *getCallee(const CallExpr &D) {
return D.getDirectCallee();
}

template <class ExprNode>
static void matchEachArgumentWithParamTypeImpl(
const ExprNode &Node,
llvm::function_ref<void(QualType /*Param*/, const Expr * /*Arg*/)>
OnParamAndArg) {
static_assert(std::is_same_v<CallExpr, ExprNode> ||
std::is_same_v<CXXConstructExpr, ExprNode>);
// The first argument of an overloaded member operator is the implicit object
// argument of the method which should not be matched against a parameter, so
// we skip over it here.
unsigned ArgIndex = 0;
if (const auto *CE = dyn_cast<CXXOperatorCallExpr>(&Node)) {
const auto *MD = dyn_cast_or_null<CXXMethodDecl>(CE->getDirectCallee());
if (MD && !MD->isExplicitObjectMemberFunction()) {
// This is an overloaded operator call.
// We need to skip the first argument, which is the implicit object
// argument of the method which should not be matched against a
// parameter.
++ArgIndex;
}
}

const FunctionProtoType *FProto = nullptr;

if (const auto *Call = dyn_cast<CallExpr>(&Node)) {
if (const auto *Value =
dyn_cast_or_null<ValueDecl>(Call->getCalleeDecl())) {
QualType QT = Value->getType().getCanonicalType();

// This does not necessarily lead to a `FunctionProtoType`,
// e.g. K&R functions do not have a function prototype.
if (QT->isFunctionPointerType())
FProto = QT->getPointeeType()->getAs<FunctionProtoType>();

if (QT->isMemberFunctionPointerType()) {
const auto *MP = QT->getAs<MemberPointerType>();
assert(MP && "Must be member-pointer if its a memberfunctionpointer");
FProto = MP->getPointeeType()->getAs<FunctionProtoType>();
assert(FProto &&
"The call must have happened through a member function "
"pointer");
}
}
}

unsigned ParamIndex = 0;
unsigned NumArgs = Node.getNumArgs();
if (FProto && FProto->isVariadic())
NumArgs = std::min(NumArgs, FProto->getNumParams());

for (; ArgIndex < NumArgs; ++ArgIndex, ++ParamIndex) {
QualType ParamType;
if (FProto && FProto->getNumParams() > ParamIndex)
ParamType = FProto->getParamType(ParamIndex);
else if (const FunctionDecl *FD = getCallee(Node);
FD && FD->getNumParams() > ParamIndex)
ParamType = FD->getParamDecl(ParamIndex)->getType();
else
continue;

OnParamAndArg(ParamType, Node.getArg(ArgIndex)->IgnoreParenCasts());
}
}

void matchEachArgumentWithParamType(
const CallExpr &Node,
llvm::function_ref<void(QualType /*Param*/, const Expr * /*Arg*/)>
OnParamAndArg) {
matchEachArgumentWithParamTypeImpl(Node, OnParamAndArg);
}

void matchEachArgumentWithParamType(
const CXXConstructExpr &Node,
llvm::function_ref<void(QualType /*Param*/, const Expr * /*Arg*/)>
OnParamAndArg) {
matchEachArgumentWithParamTypeImpl(Node, OnParamAndArg);
}

} // namespace ast_matchers

} // namespace clang
95 changes: 2 additions & 93 deletions clang/lib/Analysis/UnsafeBufferUsage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtVisitor.h"
#include "clang/AST/Type.h"
#include "clang/ASTMatchers/LowLevelHelpers.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Lex/Lexer.h"
#include "clang/Lex/Preprocessor.h"
Expand Down Expand Up @@ -300,98 +301,6 @@ static void findStmtsInUnspecifiedLvalueContext(
OnResult(BO->getLHS());
}

/// Note: Copied and modified from ASTMatchers.
/// Matches all arguments and their respective types for a \c CallExpr.
/// It is very similar to \c forEachArgumentWithParam but
/// it works on calls through function pointers as well.
///
/// The difference is, that function pointers do not provide access to a
/// \c ParmVarDecl, but only the \c QualType for each argument.
///
/// Given
/// \code
/// void f(int i);
/// int y;
/// f(y);
/// void (*f_ptr)(int) = f;
/// f_ptr(y);
/// \endcode
/// callExpr(
/// forEachArgumentWithParamType(
/// declRefExpr(to(varDecl(hasName("y")))),
/// qualType(isInteger()).bind("type)
/// ))
/// matches f(y) and f_ptr(y)
/// with declRefExpr(...)
/// matching int y
/// and qualType(...)
/// matching int
static void forEachArgumentWithParamType(
const CallExpr &Node,
const llvm::function_ref<void(QualType /*Param*/, const Expr * /*Arg*/)>
OnParamAndArg) {
// The first argument of an overloaded member operator is the implicit object
// argument of the method which should not be matched against a parameter, so
// we skip over it here.
unsigned ArgIndex = 0;
if (const auto *CE = dyn_cast<CXXOperatorCallExpr>(&Node)) {
const auto *MD = dyn_cast_or_null<CXXMethodDecl>(CE->getDirectCallee());
if (MD && !MD->isExplicitObjectMemberFunction()) {
// This is an overloaded operator call.
// We need to skip the first argument, which is the implicit object
// argument of the method which should not be matched against a
// parameter.
++ArgIndex;
}
}

const FunctionProtoType *FProto = nullptr;

if (const auto *Call = dyn_cast<CallExpr>(&Node)) {
if (const auto *Value =
dyn_cast_or_null<ValueDecl>(Call->getCalleeDecl())) {
QualType QT = Value->getType().getCanonicalType();

// This does not necessarily lead to a `FunctionProtoType`,
// e.g. K&R functions do not have a function prototype.
if (QT->isFunctionPointerType())
FProto = QT->getPointeeType()->getAs<FunctionProtoType>();

if (QT->isMemberFunctionPointerType()) {
const auto *MP = QT->getAs<MemberPointerType>();
assert(MP && "Must be member-pointer if its a memberfunctionpointer");
FProto = MP->getPointeeType()->getAs<FunctionProtoType>();
assert(FProto &&
"The call must have happened through a member function "
"pointer");
}
}
}

unsigned ParamIndex = 0;
unsigned NumArgs = Node.getNumArgs();
if (FProto && FProto->isVariadic())
NumArgs = std::min(NumArgs, FProto->getNumParams());

const auto GetParamType =
[&FProto, &Node](unsigned int ParamIndex) -> std::optional<QualType> {
if (FProto && FProto->getNumParams() > ParamIndex) {
return FProto->getParamType(ParamIndex);
}
const auto *FD = Node.getDirectCallee();
if (FD && FD->getNumParams() > ParamIndex) {
return FD->getParamDecl(ParamIndex)->getType();
}
return std::nullopt;
};

for (; ArgIndex < NumArgs; ++ArgIndex, ++ParamIndex) {
auto ParamType = GetParamType(ParamIndex);
if (ParamType)
OnParamAndArg(*ParamType, Node.getArg(ArgIndex)->IgnoreParenCasts());
}
}

// Finds any expression `e` such that `InnerMatcher` matches `e` and
// `e` is in an Unspecified Pointer Context (UPC).
static void findStmtsInUnspecifiedPointerContext(
Expand All @@ -408,7 +317,7 @@ static void findStmtsInUnspecifiedPointerContext(
if (const auto *FnDecl = CE->getDirectCallee();
FnDecl && FnDecl->hasAttr<UnsafeBufferUsageAttr>())
return;
forEachArgumentWithParamType(
ast_matchers::matchEachArgumentWithParamType(
*CE, [&InnerMatcher](QualType Type, const Expr *Arg) {
if (Type->isAnyPointerType())
InnerMatcher(Arg);
Expand Down