Skip to content

[OpenACC] Private Clause on Compute Constructs #90521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 59 additions & 26 deletions clang/include/clang/AST/OpenACCClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,51 +156,50 @@ class OpenACCSelfClause : public OpenACCClauseWithCondition {
Expr *ConditionExpr, SourceLocation EndLoc);
};

/// Represents a clause that has one or more IntExprs. It does not own the
/// IntExprs, but provides 'children' and other accessors.
class OpenACCClauseWithIntExprs : public OpenACCClauseWithParams {
MutableArrayRef<Expr *> IntExprs;
/// Represents a clause that has one or more expressions associated with it.
class OpenACCClauseWithExprs : public OpenACCClauseWithParams {
MutableArrayRef<Expr *> Exprs;

protected:
OpenACCClauseWithIntExprs(OpenACCClauseKind K, SourceLocation BeginLoc,
SourceLocation LParenLoc, SourceLocation EndLoc)
OpenACCClauseWithExprs(OpenACCClauseKind K, SourceLocation BeginLoc,
SourceLocation LParenLoc, SourceLocation EndLoc)
: OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc) {}

/// Used only for initialization, the leaf class can initialize this to
/// trailing storage.
void setIntExprs(MutableArrayRef<Expr *> NewIntExprs) {
assert(IntExprs.empty() && "Cannot change IntExprs list");
IntExprs = NewIntExprs;
void setExprs(MutableArrayRef<Expr *> NewExprs) {
assert(Exprs.empty() && "Cannot change Exprs list");
Exprs = NewExprs;
}

/// Gets the entire list of integer expressions, but leave it to the
/// Gets the entire list of expressions, but leave it to the
/// individual clauses to expose this how they'd like.
llvm::ArrayRef<Expr *> getIntExprs() const { return IntExprs; }
llvm::ArrayRef<Expr *> getExprs() const { return Exprs; }

public:
child_range children() {
return child_range(reinterpret_cast<Stmt **>(IntExprs.begin()),
reinterpret_cast<Stmt **>(IntExprs.end()));
return child_range(reinterpret_cast<Stmt **>(Exprs.begin()),
reinterpret_cast<Stmt **>(Exprs.end()));
}

const_child_range children() const {
child_range Children =
const_cast<OpenACCClauseWithIntExprs *>(this)->children();
const_cast<OpenACCClauseWithExprs *>(this)->children();
return const_child_range(Children.begin(), Children.end());
}
};

class OpenACCNumGangsClause final
: public OpenACCClauseWithIntExprs,
: public OpenACCClauseWithExprs,
public llvm::TrailingObjects<OpenACCNumGangsClause, Expr *> {

OpenACCNumGangsClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
ArrayRef<Expr *> IntExprs, SourceLocation EndLoc)
: OpenACCClauseWithIntExprs(OpenACCClauseKind::NumGangs, BeginLoc,
LParenLoc, EndLoc) {
: OpenACCClauseWithExprs(OpenACCClauseKind::NumGangs, BeginLoc, LParenLoc,
EndLoc) {
std::uninitialized_copy(IntExprs.begin(), IntExprs.end(),
getTrailingObjects<Expr *>());
setIntExprs(MutableArrayRef(getTrailingObjects<Expr *>(), IntExprs.size()));
setExprs(MutableArrayRef(getTrailingObjects<Expr *>(), IntExprs.size()));
}

public:
Expand All @@ -209,35 +208,35 @@ class OpenACCNumGangsClause final
ArrayRef<Expr *> IntExprs, SourceLocation EndLoc);

llvm::ArrayRef<Expr *> getIntExprs() {
return OpenACCClauseWithIntExprs::getIntExprs();
return OpenACCClauseWithExprs::getExprs();
}

llvm::ArrayRef<Expr *> getIntExprs() const {
return OpenACCClauseWithIntExprs::getIntExprs();
return OpenACCClauseWithExprs::getExprs();
}
};

/// Represents one of a handful of clauses that have a single integer
/// expression.
class OpenACCClauseWithSingleIntExpr : public OpenACCClauseWithIntExprs {
class OpenACCClauseWithSingleIntExpr : public OpenACCClauseWithExprs {
Expr *IntExpr;

protected:
OpenACCClauseWithSingleIntExpr(OpenACCClauseKind K, SourceLocation BeginLoc,
SourceLocation LParenLoc, Expr *IntExpr,
SourceLocation EndLoc)
: OpenACCClauseWithIntExprs(K, BeginLoc, LParenLoc, EndLoc),
: OpenACCClauseWithExprs(K, BeginLoc, LParenLoc, EndLoc),
IntExpr(IntExpr) {
setIntExprs(MutableArrayRef<Expr *>{&this->IntExpr, 1});
setExprs(MutableArrayRef<Expr *>{&this->IntExpr, 1});
}

public:
bool hasIntExpr() const { return !getIntExprs().empty(); }
bool hasIntExpr() const { return !getExprs().empty(); }
const Expr *getIntExpr() const {
return hasIntExpr() ? getIntExprs()[0] : nullptr;
return hasIntExpr() ? getExprs()[0] : nullptr;
}

Expr *getIntExpr() { return hasIntExpr() ? getIntExprs()[0] : nullptr; };
Expr *getIntExpr() { return hasIntExpr() ? getExprs()[0] : nullptr; };
};

class OpenACCNumWorkersClause : public OpenACCClauseWithSingleIntExpr {
Expand All @@ -261,6 +260,40 @@ class OpenACCVectorLengthClause : public OpenACCClauseWithSingleIntExpr {
Expr *IntExpr, SourceLocation EndLoc);
};

/// Represents a clause with one or more 'var' objects, represented as an expr,
/// as its arguments. Var-list is expected to be stored in trailing storage.
/// For now, we're just storing the original expression in its entirety, unlike
/// OMP which has to do a bunch of work to create a private.
class OpenACCClauseWithVarList : public OpenACCClauseWithExprs {
protected:
OpenACCClauseWithVarList(OpenACCClauseKind K, SourceLocation BeginLoc,
SourceLocation LParenLoc, SourceLocation EndLoc)
: OpenACCClauseWithExprs(K, BeginLoc, LParenLoc, EndLoc) {}

public:
ArrayRef<Expr *> getVarList() { return getExprs(); }
ArrayRef<Expr *> getVarList() const { return getExprs(); }
};

class OpenACCPrivateClause final
: public OpenACCClauseWithVarList,
public llvm::TrailingObjects<OpenACCPrivateClause, Expr *> {

OpenACCPrivateClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
ArrayRef<Expr *> VarList, SourceLocation EndLoc)
: OpenACCClauseWithVarList(OpenACCClauseKind::Private, BeginLoc,
LParenLoc, EndLoc) {
std::uninitialized_copy(VarList.begin(), VarList.end(),
getTrailingObjects<Expr *>());
setExprs(MutableArrayRef(getTrailingObjects<Expr *>(), VarList.size()));
}

public:
static OpenACCPrivateClause *
Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
ArrayRef<Expr *> VarList, SourceLocation EndLoc);
};

template <class Impl> class OpenACCClauseVisitor {
Impl &getDerived() { return static_cast<Impl &>(*this); }

Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12305,4 +12305,7 @@ def err_acc_num_gangs_num_args
"OpenACC 'num_gangs' "
"%select{|clause: '%1' directive expects maximum of %2, %3 were "
"provided}0">;
def err_acc_not_a_var_ref
: Error<"OpenACC variable is not a valid variable name, sub-array, array "
"element, or composite variable member">;
} // end of sema component.
1 change: 1 addition & 0 deletions clang/include/clang/Basic/OpenACCClauses.def
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ VISIT_CLAUSE(If)
VISIT_CLAUSE(Self)
VISIT_CLAUSE(NumGangs)
VISIT_CLAUSE(NumWorkers)
VISIT_CLAUSE(Private)
VISIT_CLAUSE(VectorLength)

#undef VISIT_CLAUSE
9 changes: 5 additions & 4 deletions clang/include/clang/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -3645,11 +3645,12 @@ class Parser : public CodeCompletionHandler {
ExprResult ParseOpenACCIDExpression();
/// Parses the variable list for the `cache` construct.
void ParseOpenACCCacheVarList();

using OpenACCVarParseResult = std::pair<ExprResult, OpenACCParseCanContinue>;
/// Parses a single variable in a variable list for OpenACC.
bool ParseOpenACCVar();
/// Parses the variable list for the variety of clauses that take a var-list,
/// including the optional Special Token listed for some,based on clause type.
bool ParseOpenACCClauseVarList(OpenACCClauseKind Kind);
OpenACCVarParseResult ParseOpenACCVar();
/// Parses the variable list for the variety of places that take a var-list.
llvm::SmallVector<Expr *> ParseOpenACCVarList();
/// Parses any parameters for an OpenACC Clause, including required/optional
/// parens.
OpenACCClauseParseResult
Expand Down
34 changes: 32 additions & 2 deletions clang/include/clang/Sema/SemaOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@ class SemaOpenACC : public SemaBase {
SmallVector<Expr *> IntExprs;
};

struct VarListDetails {
SmallVector<Expr *> VarList;
};

std::variant<std::monostate, DefaultDetails, ConditionDetails,
IntExprDetails>
IntExprDetails, VarListDetails>
Details = std::monostate{};

public:
Expand Down Expand Up @@ -112,6 +116,16 @@ class SemaOpenACC : public SemaBase {
return const_cast<OpenACCParsedClause *>(this)->getIntExprs();
}

ArrayRef<Expr *> getVarList() {
assert(ClauseKind == OpenACCClauseKind::Private &&
"Parsed clause kind does not have a var-list");
return std::get<VarListDetails>(Details).VarList;
}

ArrayRef<Expr *> getVarList() const {
return const_cast<OpenACCParsedClause *>(this)->getVarList();
}

void setLParenLoc(SourceLocation EndLoc) { LParenLoc = EndLoc; }
void setEndLoc(SourceLocation EndLoc) { ClauseRange.setEnd(EndLoc); }

Expand Down Expand Up @@ -147,7 +161,19 @@ class SemaOpenACC : public SemaBase {
ClauseKind == OpenACCClauseKind::NumWorkers ||
ClauseKind == OpenACCClauseKind::VectorLength) &&
"Parsed clause kind does not have a int exprs");
Details = IntExprDetails{IntExprs};
Details = IntExprDetails{std::move(IntExprs)};
}

void setVarListDetails(ArrayRef<Expr *> VarList) {
assert(ClauseKind == OpenACCClauseKind::Private &&
"Parsed clause kind does not have a var-list");
Details = VarListDetails{{VarList.begin(), VarList.end()}};
}

void setVarListDetails(llvm::SmallVector<Expr *> &&VarList) {
assert(ClauseKind == OpenACCClauseKind::Private &&
"Parsed clause kind does not have a var-list");
Details = VarListDetails{std::move(VarList)};
}
};

Expand Down Expand Up @@ -194,6 +220,10 @@ class SemaOpenACC : public SemaBase {
ExprResult ActOnIntExpr(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
SourceLocation Loc, Expr *IntExpr);

/// Called when encountering a 'var' for OpenACC, ensures it is actually a
/// declaration reference to a variable of the correct type.
ExprResult ActOnVar(Expr *VarExpr);

/// Checks and creates an Array Section used in an OpenACC construct/clause.
ExprResult ActOnArraySectionExpr(Expr *Base, SourceLocation LBLoc,
Expr *LowerBound,
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Serialization/ASTRecordReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ class ASTRecordReader
/// Read an OpenMP children, advancing Idx.
void readOMPChildren(OMPChildren *Data);

/// Read a list of Exprs used for a var-list.
llvm::SmallVector<Expr *> readOpenACCVarList();

/// Read an OpenACC clause, advancing Idx.
OpenACCClause *readOpenACCClause();

Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Serialization/ASTRecordWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#define LLVM_CLANG_SERIALIZATION_ASTRECORDWRITER_H

#include "clang/AST/AbstractBasicWriter.h"
#include "clang/AST/OpenACCClause.h"
#include "clang/AST/OpenMPClause.h"
#include "clang/Serialization/ASTWriter.h"
#include "clang/Serialization/SourceLocationEncoding.h"
Expand Down Expand Up @@ -293,6 +294,8 @@ class ASTRecordWriter
/// Writes data related to the OpenMP directives.
void writeOMPChildren(OMPChildren *Data);

void writeOpenACCVarList(const OpenACCClauseWithVarList *C);

/// Writes out a single OpenACC Clause.
void writeOpenACCClause(const OpenACCClause *C);

Expand Down
24 changes: 24 additions & 0 deletions clang/lib/AST/OpenACCClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,24 @@ OpenACCNumGangsClause *OpenACCNumGangsClause::Create(const ASTContext &C,
return new (Mem) OpenACCNumGangsClause(BeginLoc, LParenLoc, IntExprs, EndLoc);
}

OpenACCPrivateClause *OpenACCPrivateClause::Create(const ASTContext &C,
SourceLocation BeginLoc,
SourceLocation LParenLoc,
ArrayRef<Expr *> VarList,
SourceLocation EndLoc) {
void *Mem = C.Allocate(
OpenACCPrivateClause::totalSizeToAlloc<Expr *>(VarList.size()));
return new (Mem) OpenACCPrivateClause(BeginLoc, LParenLoc, VarList, EndLoc);
}

// ValueDecl *getDeclFromExpr(Expr *RefExpr) {
// //RefExpr = RefExpr->IgnoreParenImpCasts();
//
// ////while (isa<ArraySubscriptExpr, ArraySectionExpr>(RefExpr)) {
// ////}
// // TODO:
// }

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove commented code

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

woops! Thanks, good catch!

//===----------------------------------------------------------------------===//
// OpenACC clauses printing methods
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -166,3 +184,9 @@ void OpenACCClausePrinter::VisitVectorLengthClause(
const OpenACCVectorLengthClause &C) {
OS << "vector_length(" << C.getIntExpr() << ")";
}

void OpenACCClausePrinter::VisitPrivateClause(const OpenACCPrivateClause &C) {
OS << "private(";
llvm::interleaveComma(C.getVarList(), OS);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a test for this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont! I THINK this uses 'ast-print', so give me a minute and I'll put one on this review that tests all of the variants here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test added :) A bit of exploration, plus a patch to fix the rest (upstreamed as review after commit), and now up to date here :)

OS << ")";
}
6 changes: 6 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2509,6 +2509,12 @@ void OpenACCClauseProfiler::VisitNumWorkersClause(
Profiler.VisitStmt(Clause.getIntExpr());
}

void OpenACCClauseProfiler::VisitPrivateClause(
const OpenACCPrivateClause &Clause) {
for (auto *E : Clause.getVarList())
Profiler.VisitStmt(E);
}

void OpenACCClauseProfiler::VisitVectorLengthClause(
const OpenACCVectorLengthClause &Clause) {
assert(Clause.hasIntExpr() &&
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
case OpenACCClauseKind::Self:
case OpenACCClauseKind::NumGangs:
case OpenACCClauseKind::NumWorkers:
case OpenACCClauseKind::Private:
case OpenACCClauseKind::VectorLength:
// The condition expression will be printed as a part of the 'children',
// but print 'clause' here so it is clear what is happening from the dump.
Expand Down
Loading
Loading