Skip to content

[OpenMP] atomic compare weak : Parser & AST support #79475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions clang/include/clang/AST/OpenMPClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -2513,6 +2513,46 @@ class OMPRelaxedClause final : public OMPClause {
}
};

/// This represents 'weak' clause in the '#pragma omp atomic'
/// directives.
///
/// \code
/// #pragma omp atomic compare weak
/// \endcode
/// In this example directive '#pragma omp atomic' has 'weak' clause.
class OMPWeakClause final : public OMPClause {
public:
/// Build 'weak' clause.
///
/// \param StartLoc Starting location of the clause.
/// \param EndLoc Ending location of the clause.
OMPWeakClause(SourceLocation StartLoc, SourceLocation EndLoc)
: OMPClause(llvm::omp::OMPC_weak, StartLoc, EndLoc) {}

/// Build an empty clause.
OMPWeakClause()
: OMPClause(llvm::omp::OMPC_weak, SourceLocation(), SourceLocation()) {}

child_range children() {
return child_range(child_iterator(), child_iterator());
}

const_child_range children() const {
return const_child_range(const_child_iterator(), const_child_iterator());
}

child_range used_children() {
return child_range(child_iterator(), child_iterator());
}
const_child_range used_children() const {
return const_child_range(const_child_iterator(), const_child_iterator());
}

static bool classof(const OMPClause *T) {
return T->getClauseKind() == llvm::omp::OMPC_weak;
}
};

/// This represents 'fail' clause in the '#pragma omp atomic'
/// directive.
///
Expand Down
5 changes: 5 additions & 0 deletions clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3436,6 +3436,11 @@ bool RecursiveASTVisitor<Derived>::VisitOMPRelaxedClause(OMPRelaxedClause *) {
return true;
}

template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPWeakClause(OMPWeakClause *) {
return true;
}

template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPThreadsClause(OMPThreadsClause *) {
return true;
Expand Down
3 changes: 2 additions & 1 deletion clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -11042,7 +11042,8 @@ def note_omp_atomic_compare: Note<
"expect lvalue for result value|expect scalar value|expect integer value|unexpected 'else' statement|expect '==' operator|expect an assignment statement 'v = x'|"
"expect a 'if' statement|expect no more than two statements|expect a compound statement|expect 'else' statement|expect a form 'r = x == e; if (r) ...'}0">;
def err_omp_atomic_fail_wrong_or_no_clauses : Error<"expected a memory order clause">;
def err_omp_atomic_fail_no_compare : Error<"expected 'compare' clause with the 'fail' modifier">;
def err_omp_atomic_no_compare : Error<"expected 'compare' clause with the '%0' modifier">;
def err_omp_atomic_weak_no_equality : Error<"expected '==' operator for 'weak' clause">;
def err_omp_atomic_several_clauses : Error<
"directive '#pragma omp atomic' cannot contain more than one 'read', 'write', 'update', 'capture', or 'compare' clause">;
def err_omp_several_mem_order_clauses : Error<
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -12377,6 +12377,9 @@ class Sema final {
/// Called on well-formed 'relaxed' clause.
OMPClause *ActOnOpenMPRelaxedClause(SourceLocation StartLoc,
SourceLocation EndLoc);
/// Called on well-formed 'weak' clause.
OMPClause *ActOnOpenMPWeakClause(SourceLocation StartLoc,
SourceLocation EndLoc);

/// Called on well-formed 'init' clause.
OMPClause *
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/AST/OpenMPClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1957,6 +1957,8 @@ void OMPClausePrinter::VisitOMPRelaxedClause(OMPRelaxedClause *) {
OS << "relaxed";
}

void OMPClausePrinter::VisitOMPWeakClause(OMPWeakClause *) { OS << "weak"; }

void OMPClausePrinter::VisitOMPThreadsClause(OMPThreadsClause *) {
OS << "threads";
}
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,8 @@ void OMPClauseProfiler::VisitOMPReleaseClause(const OMPReleaseClause *) {}

void OMPClauseProfiler::VisitOMPRelaxedClause(const OMPRelaxedClause *) {}

void OMPClauseProfiler::VisitOMPWeakClause(const OMPWeakClause *) {}

void OMPClauseProfiler::VisitOMPThreadsClause(const OMPThreadsClause *) {}

void OMPClauseProfiler::VisitOMPSIMDClause(const OMPSIMDClause *) {}
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/CodeGen/CGStmtOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6546,6 +6546,9 @@ void CodeGenFunction::EmitOMPAtomicDirective(const OMPAtomicDirective &S) {
// Find first clause (skip seq_cst|acq_rel|aqcuire|release|relaxed clause,
// if it is first).
OpenMPClauseKind K = C->getClauseKind();
// TBD
if (K == OMPC_weak)
return;
if (K == OMPC_seq_cst || K == OMPC_acq_rel || K == OMPC_acquire ||
K == OMPC_release || K == OMPC_relaxed || K == OMPC_hint)
continue;
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Parse/ParseOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3314,6 +3314,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
case OMPC_acquire:
case OMPC_release:
case OMPC_relaxed:
case OMPC_weak:
case OMPC_threads:
case OMPC_simd:
case OMPC_nogroup:
Expand Down
33 changes: 32 additions & 1 deletion clang/lib/Sema/SemaOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12708,9 +12708,11 @@ StmtResult Sema::ActOnOpenMPAtomicDirective(ArrayRef<OMPClause *> Clauses,
}
break;
}
case OMPC_weak:
case OMPC_fail: {
if (!EncounteredAtomicKinds.contains(OMPC_compare)) {
Diag(C->getBeginLoc(), diag::err_omp_atomic_fail_no_compare)
Diag(C->getBeginLoc(), diag::err_omp_atomic_no_compare)
<< getOpenMPClauseName(C->getClauseKind())
<< SourceRange(C->getBeginLoc(), C->getEndLoc());
return StmtError();
}
Expand Down Expand Up @@ -13202,6 +13204,27 @@ StmtResult Sema::ActOnOpenMPAtomicDirective(ArrayRef<OMPClause *> Clauses,
E = Checker.getE();
D = Checker.getD();
CE = Checker.getCond();
// The weak clause may only appear if the resulting atomic operation is
// an atomic conditional update for which the comparison tests for
// equality. It was not possible to do this check in
// OpenMPAtomicCompareChecker::checkStmt() as the check for OMPC_weak
// could not be performed (Clauses are not available).
auto *It = find_if(Clauses, [](OMPClause *C) {
return C->getClauseKind() == llvm::omp::Clause::OMPC_weak;
});
if (It != Clauses.end()) {
auto *Cond = dyn_cast<BinaryOperator>(CE);
if (Cond->getOpcode() != BO_EQ) {
ErrorInfo.Error = Checker.ErrorTy::NotAnAssignment;
ErrorInfo.ErrorLoc = Cond->getExprLoc();
ErrorInfo.NoteLoc = Cond->getOperatorLoc();
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();

Diag(ErrorInfo.ErrorLoc, diag::err_omp_atomic_weak_no_equality)
<< ErrorInfo.ErrorRange;
return StmtError();
}
}
// We reuse IsXLHSInRHSPart to tell if it is in the form 'x ordop expr'.
IsXLHSInRHSPart = Checker.isXBinopExpr();
}
Expand Down Expand Up @@ -17593,6 +17616,9 @@ OMPClause *Sema::ActOnOpenMPClause(OpenMPClauseKind Kind,
case OMPC_relaxed:
Res = ActOnOpenMPRelaxedClause(StartLoc, EndLoc);
break;
case OMPC_weak:
Res = ActOnOpenMPWeakClause(StartLoc, EndLoc);
break;
case OMPC_threads:
Res = ActOnOpenMPThreadsClause(StartLoc, EndLoc);
break;
Expand Down Expand Up @@ -17781,6 +17807,11 @@ OMPClause *Sema::ActOnOpenMPRelaxedClause(SourceLocation StartLoc,
return new (Context) OMPRelaxedClause(StartLoc, EndLoc);
}

OMPClause *Sema::ActOnOpenMPWeakClause(SourceLocation StartLoc,
SourceLocation EndLoc) {
return new (Context) OMPWeakClause(StartLoc, EndLoc);
}

OMPClause *Sema::ActOnOpenMPThreadsClause(SourceLocation StartLoc,
SourceLocation EndLoc) {
return new (Context) OMPThreadsClause(StartLoc, EndLoc);
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/Sema/TreeTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -10075,6 +10075,12 @@ TreeTransform<Derived>::TransformOMPRelaxedClause(OMPRelaxedClause *C) {
return C;
}

template <typename Derived>
OMPClause *TreeTransform<Derived>::TransformOMPWeakClause(OMPWeakClause *C) {
// No need to rebuild this clause, no template-dependent parameters.
return C;
}

template <typename Derived>
OMPClause *
TreeTransform<Derived>::TransformOMPThreadsClause(OMPThreadsClause *C) {
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/Serialization/ASTReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10339,6 +10339,9 @@ OMPClause *OMPClauseReader::readClause() {
case llvm::omp::OMPC_relaxed:
C = new (Context) OMPRelaxedClause();
break;
case llvm::omp::OMPC_weak:
C = new (Context) OMPWeakClause();
break;
case llvm::omp::OMPC_threads:
C = new (Context) OMPThreadsClause();
break;
Expand Down Expand Up @@ -10737,6 +10740,8 @@ void OMPClauseReader::VisitOMPReleaseClause(OMPReleaseClause *) {}

void OMPClauseReader::VisitOMPRelaxedClause(OMPRelaxedClause *) {}

void OMPClauseReader::VisitOMPWeakClause(OMPWeakClause *) {}

void OMPClauseReader::VisitOMPThreadsClause(OMPThreadsClause *) {}

void OMPClauseReader::VisitOMPSIMDClause(OMPSIMDClause *) {}
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/Serialization/ASTWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6662,6 +6662,8 @@ void OMPClauseWriter::VisitOMPReleaseClause(OMPReleaseClause *) {}

void OMPClauseWriter::VisitOMPRelaxedClause(OMPRelaxedClause *) {}

void OMPClauseWriter::VisitOMPWeakClause(OMPWeakClause *) {}

void OMPClauseWriter::VisitOMPThreadsClause(OMPThreadsClause *) {}

void OMPClauseWriter::VisitOMPSIMDClause(OMPSIMDClause *) {}
Expand Down
8 changes: 8 additions & 0 deletions clang/test/OpenMP/atomic_ast_print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ T foo(T argc) {
{ if (a < c) { a = c; } }
#pragma omp atomic compare fail(seq_cst)
{ if (a < c) { a = c; } }
#pragma omp atomic compare seq_cst weak
{ if(a == b) { a = c; } }
#endif
return T();
}
Expand Down Expand Up @@ -1111,6 +1113,8 @@ int main(int argc, char **argv) {
if(a < b) { a = b; }
#pragma omp atomic compare fail(seq_cst)
if(a < b) { a = b; }
#pragma omp atomic compare seq_cst weak
if(a == b) { a = c; }
#endif
// CHECK-NEXT: #pragma omp atomic
// CHECK-NEXT: a++;
Expand Down Expand Up @@ -1453,6 +1457,10 @@ int main(int argc, char **argv) {
// CHECK-51-NEXT: if (a < b) {
// CHECK-51-NEXT: a = b;
// CHECK-51-NEXT: }
// CHECK-51-NEXT: #pragma omp atomic compare seq_cst weak
// CHECK-51-NEXT: if (a == b) {
// CHECK-51-NEXT: a = c;
// CHECK-51-NEXT: }
// expect-note@+1 {{in instantiation of function template specialization 'foo<int>' requested here}}
return foo(a);
}
Expand Down
11 changes: 11 additions & 0 deletions clang/test/OpenMP/atomic_messages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,17 @@ int mixed() {
// expected-error@+1 {{directive '#pragma omp atomic' cannot contain more than one 'fail' clause}}
#pragma omp atomic compare fail(relaxed) fail(seq_cst)
if(v < a) { v = a; }
#pragma omp atomic compare seq_cst weak
if(v == a) { v = a; }
// expected-error@+1 {{expected 'compare' clause with the 'weak' modifier}}
#pragma omp atomic weak
if(v < a) { v = a; }
#pragma omp atomic compare release weak
// expected-error@+1 {{expected '==' operator for 'weak' clause}}
if(v < a) { v = a; }
// expected-error@+1 {{directive '#pragma omp atomic' cannot contain more than one 'weak' clause}}
#pragma omp atomic compare release weak fail(seq_cst) weak
if(v == a) { v = a; }


#endif
Expand Down
2 changes: 2 additions & 0 deletions clang/tools/libclang/CIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2425,6 +2425,8 @@ void OMPClauseEnqueue::VisitOMPReleaseClause(const OMPReleaseClause *) {}

void OMPClauseEnqueue::VisitOMPRelaxedClause(const OMPRelaxedClause *) {}

void OMPClauseEnqueue::VisitOMPWeakClause(const OMPWeakClause *) {}

void OMPClauseEnqueue::VisitOMPThreadsClause(const OMPThreadsClause *) {}

void OMPClauseEnqueue::VisitOMPSIMDClause(const OMPSIMDClause *) {}
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Semantics/check-omp-structure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2264,6 +2264,7 @@ CHECK_SIMPLE_CLAUSE(OmpxAttribute, OMPC_ompx_attribute)
CHECK_SIMPLE_CLAUSE(OmpxBare, OMPC_ompx_bare)
CHECK_SIMPLE_CLAUSE(Enter, OMPC_enter)
CHECK_SIMPLE_CLAUSE(Fail, OMPC_fail)
CHECK_SIMPLE_CLAUSE(Weak, OMPC_weak)

CHECK_REQ_SCALAR_INT_CLAUSE(Grainsize, OMPC_grainsize)
CHECK_REQ_SCALAR_INT_CLAUSE(NumTasks, OMPC_num_tasks)
Expand Down
4 changes: 3 additions & 1 deletion llvm/include/llvm/Frontend/OpenMP/OMP.td
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def OMPC_AcqRel : Clause<"acq_rel"> { let clangClass = "OMPAcqRelClause"; }
def OMPC_Acquire : Clause<"acquire"> { let clangClass = "OMPAcquireClause"; }
def OMPC_Release : Clause<"release"> { let clangClass = "OMPReleaseClause"; }
def OMPC_Relaxed : Clause<"relaxed"> { let clangClass = "OMPRelaxedClause"; }
def OMPC_Weak : Clause<"weak"> { let clangClass = "OMPWeakClause"; }
def OMPC_Depend : Clause<"depend"> {
let clangClass = "OMPDependClause";
let flangClass = "OmpDependClause";
Expand Down Expand Up @@ -642,7 +643,8 @@ def OMP_Atomic : Directive<"atomic"> {
VersionedClause<OMPC_Release, 50>,
VersionedClause<OMPC_Relaxed, 50>,
VersionedClause<OMPC_Hint, 50>,
VersionedClause<OMPC_Fail, 51>
VersionedClause<OMPC_Fail, 51>,
VersionedClause<OMPC_Weak, 51>
];
}
def OMP_Target : Directive<"target"> {
Expand Down