Skip to content

Commit 5530474

Browse files
authored
[Flang][OpenMP] fix crash on sematic error in atomic capture clause (#140710)
Fix a crash caused by an invalid expression in the atomic capture clause, due to the `checkForSymbolMatch` function not accounting for `GetExpr` potentially returning null. Fix #139884
1 parent d4f0f43 commit 5530474

File tree

5 files changed

+66
-44
lines changed

5 files changed

+66
-44
lines changed

flang/include/flang/Semantics/tools.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -764,19 +764,14 @@ inline bool checkForSingleVariableOnRHS(
764764
return designator != nullptr;
765765
}
766766

767-
/// Checks if the symbol on the LHS of the assignment statement is present in
768-
/// the RHS expression.
769-
inline bool checkForSymbolMatch(
770-
const Fortran::parser::AssignmentStmt &assignmentStmt) {
771-
const auto &var{std::get<Fortran::parser::Variable>(assignmentStmt.t)};
772-
const auto &expr{std::get<Fortran::parser::Expr>(assignmentStmt.t)};
773-
const auto *e{Fortran::semantics::GetExpr(expr)};
774-
const auto *v{Fortran::semantics::GetExpr(var)};
775-
auto varSyms{Fortran::evaluate::GetSymbolVector(*v)};
776-
const Fortran::semantics::Symbol &varSymbol{*varSyms.front()};
767+
/// Checks if the symbol on the LHS is present in the RHS expression.
768+
inline bool checkForSymbolMatch(const Fortran::semantics::SomeExpr *lhs,
769+
const Fortran::semantics::SomeExpr *rhs) {
770+
auto lhsSyms{Fortran::evaluate::GetSymbolVector(*lhs)};
771+
const Fortran::semantics::Symbol &lhsSymbol{*lhsSyms.front()};
777772
for (const Fortran::semantics::Symbol &symbol :
778-
Fortran::evaluate::GetSymbolVector(*e)) {
779-
if (varSymbol == symbol) {
773+
Fortran::evaluate::GetSymbolVector(*rhs)) {
774+
if (lhsSymbol == symbol) {
780775
return true;
781776
}
782777
}

flang/lib/Lower/OpenACC.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,9 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
654654
mlir::Block &block = atomicCaptureOp->getRegion(0).back();
655655
firOpBuilder.setInsertionPointToStart(&block);
656656
if (Fortran::semantics::checkForSingleVariableOnRHS(stmt1)) {
657-
if (Fortran::semantics::checkForSymbolMatch(stmt2)) {
657+
if (Fortran::semantics::checkForSymbolMatch(
658+
Fortran::semantics::GetExpr(stmt2Var),
659+
Fortran::semantics::GetExpr(stmt2Expr))) {
658660
// Atomic capture construct is of the form [capture-stmt, update-stmt]
659661
const Fortran::semantics::SomeExpr &fromExpr =
660662
*Fortran::semantics::GetExpr(stmt1Expr);

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3198,7 +3198,8 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
31983198
mlir::Block &block = atomicCaptureOp->getRegion(0).back();
31993199
firOpBuilder.setInsertionPointToStart(&block);
32003200
if (semantics::checkForSingleVariableOnRHS(stmt1)) {
3201-
if (semantics::checkForSymbolMatch(stmt2)) {
3201+
if (semantics::checkForSymbolMatch(semantics::GetExpr(stmt2Var),
3202+
semantics::GetExpr(stmt2Expr))) {
32023203
// Atomic capture construct is of the form [capture-stmt, update-stmt]
32033204
const semantics::SomeExpr &fromExpr = *semantics::GetExpr(stmt1Expr);
32043205
mlir::Type elementType = converter.genType(fromExpr);

flang/lib/Semantics/check-omp-structure.cpp

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2910,45 +2910,47 @@ void OmpStructureChecker::CheckAtomicCaptureConstruct(
29102910
.v.statement;
29112911
const auto &stmt1Var{std::get<parser::Variable>(stmt1.t)};
29122912
const auto &stmt1Expr{std::get<parser::Expr>(stmt1.t)};
2913+
const auto *v1 = GetExpr(context_, stmt1Var);
2914+
const auto *e1 = GetExpr(context_, stmt1Expr);
29132915

29142916
const parser::AssignmentStmt &stmt2 =
29152917
std::get<parser::OmpAtomicCapture::Stmt2>(atomicCaptureConstruct.t)
29162918
.v.statement;
29172919
const auto &stmt2Var{std::get<parser::Variable>(stmt2.t)};
29182920
const auto &stmt2Expr{std::get<parser::Expr>(stmt2.t)};
2919-
2920-
if (semantics::checkForSingleVariableOnRHS(stmt1)) {
2921-
CheckAtomicCaptureStmt(stmt1);
2922-
if (semantics::checkForSymbolMatch(stmt2)) {
2923-
// ATOMIC CAPTURE construct is of the form [capture-stmt, update-stmt]
2924-
CheckAtomicUpdateStmt(stmt2);
2921+
const auto *v2 = GetExpr(context_, stmt2Var);
2922+
const auto *e2 = GetExpr(context_, stmt2Expr);
2923+
2924+
if (e1 && v1 && e2 && v2) {
2925+
if (semantics::checkForSingleVariableOnRHS(stmt1)) {
2926+
CheckAtomicCaptureStmt(stmt1);
2927+
if (semantics::checkForSymbolMatch(v2, e2)) {
2928+
// ATOMIC CAPTURE construct is of the form [capture-stmt, update-stmt]
2929+
CheckAtomicUpdateStmt(stmt2);
2930+
} else {
2931+
// ATOMIC CAPTURE construct is of the form [capture-stmt, write-stmt]
2932+
CheckAtomicWriteStmt(stmt2);
2933+
}
2934+
if (!(*e1 == *v2)) {
2935+
context_.Say(stmt1Expr.source,
2936+
"Captured variable/array element/derived-type component %s expected to be assigned in the second statement of ATOMIC CAPTURE construct"_err_en_US,
2937+
stmt1Expr.source);
2938+
}
2939+
} else if (semantics::checkForSymbolMatch(v1, e1) &&
2940+
semantics::checkForSingleVariableOnRHS(stmt2)) {
2941+
// ATOMIC CAPTURE construct is of the form [update-stmt, capture-stmt]
2942+
CheckAtomicUpdateStmt(stmt1);
2943+
CheckAtomicCaptureStmt(stmt2);
2944+
// Variable updated in stmt1 should be captured in stmt2
2945+
if (!(*v1 == *e2)) {
2946+
context_.Say(stmt1Var.GetSource(),
2947+
"Updated variable/array element/derived-type component %s expected to be captured in the second statement of ATOMIC CAPTURE construct"_err_en_US,
2948+
stmt1Var.GetSource());
2949+
}
29252950
} else {
2926-
// ATOMIC CAPTURE construct is of the form [capture-stmt, write-stmt]
2927-
CheckAtomicWriteStmt(stmt2);
2928-
}
2929-
auto *v{stmt2Var.typedExpr.get()};
2930-
auto *e{stmt1Expr.typedExpr.get()};
2931-
if (v && e && !(v->v == e->v)) {
29322951
context_.Say(stmt1Expr.source,
2933-
"Captured variable/array element/derived-type component %s expected to be assigned in the second statement of ATOMIC CAPTURE construct"_err_en_US,
2934-
stmt1Expr.source);
2935-
}
2936-
} else if (semantics::checkForSymbolMatch(stmt1) &&
2937-
semantics::checkForSingleVariableOnRHS(stmt2)) {
2938-
// ATOMIC CAPTURE construct is of the form [update-stmt, capture-stmt]
2939-
CheckAtomicUpdateStmt(stmt1);
2940-
CheckAtomicCaptureStmt(stmt2);
2941-
// Variable updated in stmt1 should be captured in stmt2
2942-
auto *v{stmt1Var.typedExpr.get()};
2943-
auto *e{stmt2Expr.typedExpr.get()};
2944-
if (v && e && !(v->v == e->v)) {
2945-
context_.Say(stmt1Var.GetSource(),
2946-
"Updated variable/array element/derived-type component %s expected to be captured in the second statement of ATOMIC CAPTURE construct"_err_en_US,
2947-
stmt1Var.GetSource());
2952+
"Invalid ATOMIC CAPTURE construct statements. Expected one of [update-stmt, capture-stmt], [capture-stmt, update-stmt], or [capture-stmt, write-stmt]"_err_en_US);
29482953
}
2949-
} else {
2950-
context_.Say(stmt1Expr.source,
2951-
"Invalid ATOMIC CAPTURE construct statements. Expected one of [update-stmt, capture-stmt], [capture-stmt, update-stmt], or [capture-stmt, write-stmt]"_err_en_US);
29522954
}
29532955
}
29542956

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
! REQUIRES: openmp_runtime
2+
3+
! RUN: %python %S/../test_errors.py %s %flang_fc1 %openmp_flags
4+
! Semantic checks on invalid atomic capture clause
5+
6+
use omp_lib
7+
logical x
8+
complex y
9+
!$omp atomic capture
10+
!ERROR: No intrinsic or user-defined ASSIGNMENT(=) matches operand types LOGICAL(4) and COMPLEX(4)
11+
x = y
12+
!ERROR: Operands of + must be numeric; have COMPLEX(4) and LOGICAL(4)
13+
y = y + x
14+
!$omp end atomic
15+
16+
!$omp atomic capture
17+
!ERROR: Operands of + must be numeric; have COMPLEX(4) and LOGICAL(4)
18+
y = y + x
19+
!ERROR: No intrinsic or user-defined ASSIGNMENT(=) matches operand types LOGICAL(4) and COMPLEX(4)
20+
x = y
21+
!$omp end atomic
22+
end

0 commit comments

Comments
 (0)