Skip to content

Commit 2f490f9

Browse files
committed
[flang][openacc] Check trip count invariance with other IVs (#79906)
2.9.1 The trip count for all loops associated with the collapse clause must be computable and invariant in all the loops. This patch checks that loops part of a collapse nest does not depends on outer loops induction variables. The check is also applied to combined construct with a loop.
1 parent bb770f0 commit 2f490f9

File tree

2 files changed

+83
-8
lines changed

2 files changed

+83
-8
lines changed

flang/lib/Semantics/resolve-directives.cpp

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "resolve-names-utils.h"
1414
#include "flang/Common/idioms.h"
1515
#include "flang/Evaluate/fold.h"
16+
#include "flang/Evaluate/tools.h"
1617
#include "flang/Evaluate/type.h"
1718
#include "flang/Parser/parse-tree-visitor.h"
1819
#include "flang/Parser/parse-tree.h"
@@ -266,7 +267,7 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
266267
Symbol::Flag::AccDevicePtr, Symbol::Flag::AccDeviceResident,
267268
Symbol::Flag::AccLink, Symbol::Flag::AccPresent};
268269

269-
void CheckAssociatedLoopIndex(const parser::OpenACCLoopConstruct &);
270+
void CheckAssociatedLoop(const parser::DoConstruct &);
270271
void ResolveAccObjectList(const parser::AccObjectList &, Symbol::Flag);
271272
void ResolveAccObject(const parser::AccObject &, Symbol::Flag);
272273
Symbol *ResolveAcc(const parser::Name &, Symbol::Flag, Scope &);
@@ -882,7 +883,8 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCLoopConstruct &x) {
882883
}
883884
ClearDataSharingAttributeObjects();
884885
SetContextAssociatedLoopLevel(GetAssociatedLoopLevelFromClauses(clauseList));
885-
CheckAssociatedLoopIndex(x);
886+
const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)};
887+
CheckAssociatedLoop(*outer);
886888
return true;
887889
}
888890

@@ -1087,6 +1089,10 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCCombinedConstruct &x) {
10871089
default:
10881090
break;
10891091
}
1092+
const auto &clauseList{std::get<parser::AccClauseList>(beginBlockDir.t)};
1093+
SetContextAssociatedLoopLevel(GetAssociatedLoopLevelFromClauses(clauseList));
1094+
const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)};
1095+
CheckAssociatedLoop(*outer);
10901096
ClearDataSharingAttributeObjects();
10911097
return true;
10921098
}
@@ -1218,8 +1224,8 @@ std::int64_t AccAttributeVisitor::GetAssociatedLoopLevelFromClauses(
12181224
return 1; // default is outermost loop
12191225
}
12201226

1221-
void AccAttributeVisitor::CheckAssociatedLoopIndex(
1222-
const parser::OpenACCLoopConstruct &x) {
1227+
void AccAttributeVisitor::CheckAssociatedLoop(
1228+
const parser::DoConstruct &outerDoConstruct) {
12231229
std::int64_t level{GetContext().associatedLoopLevel};
12241230
if (level <= 0) { // collapse value was negative or 0
12251231
return;
@@ -1250,10 +1256,41 @@ void AccAttributeVisitor::CheckAssociatedLoopIndex(
12501256
return nullptr;
12511257
};
12521258

1253-
const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)};
1254-
for (const parser::DoConstruct *loop{&*outer}; loop && level > 0;) {
1259+
auto checkExprHasSymbols = [&](llvm::SmallVector<Symbol *> &ivs,
1260+
semantics::UnorderedSymbolSet &symbols) {
1261+
for (auto iv : ivs) {
1262+
if (symbols.count(*iv) != 0) {
1263+
context_.Say(GetContext().directiveSource,
1264+
"Trip count must be computable and invariant"_err_en_US);
1265+
}
1266+
}
1267+
};
1268+
1269+
Symbol::Flag flag = Symbol::Flag::AccPrivate;
1270+
llvm::SmallVector<Symbol *> ivs;
1271+
using Bounds = parser::LoopControl::Bounds;
1272+
for (const parser::DoConstruct *loop{&outerDoConstruct}; loop && level > 0;) {
12551273
// Go through all nested loops to ensure index variable exists.
1256-
GetLoopIndex(*loop);
1274+
if (const parser::Name * ivName{GetLoopIndex(*loop)}) {
1275+
if (auto *symbol{ResolveAcc(*ivName, flag, currScope())}) {
1276+
if (auto &control{loop->GetLoopControl()}) {
1277+
if (const Bounds * b{std::get_if<Bounds>(&control->u)}) {
1278+
if (auto lowerExpr{semantics::AnalyzeExpr(context_, b->lower)}) {
1279+
semantics::UnorderedSymbolSet lowerSyms =
1280+
evaluate::CollectSymbols(*lowerExpr);
1281+
checkExprHasSymbols(ivs, lowerSyms);
1282+
}
1283+
if (auto upperExpr{semantics::AnalyzeExpr(context_, b->upper)}) {
1284+
semantics::UnorderedSymbolSet upperSyms =
1285+
evaluate::CollectSymbols(*upperExpr);
1286+
checkExprHasSymbols(ivs, upperSyms);
1287+
}
1288+
}
1289+
}
1290+
ivs.push_back(symbol);
1291+
}
1292+
}
1293+
12571294
const auto &block{std::get<parser::Block>(loop->t)};
12581295
--level;
12591296
loop = getNextDoConstruct(block, level);

flang/test/Semantics/OpenACC/acc-loop.f90

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@ program openacc_loop_validity
1010
type atype
1111
real(8), dimension(10) :: arr
1212
real(8) :: s
13+
integer :: n
1314
end type atype
1415

15-
integer :: i, j, b, gang_size, vector_size, worker_size
16+
integer :: i, j, k, b, gang_size, vector_size, worker_size
1617
integer, parameter :: N = 256
1718
integer, dimension(N) :: c
1819
logical, dimension(N) :: d, e
@@ -317,4 +318,41 @@ program openacc_loop_validity
317318
END DO
318319
END DO
319320

321+
!ERROR: Trip count must be computable and invariant
322+
!$acc loop collapse(2)
323+
DO i = 1, n
324+
DO j = 1, c(i)
325+
END DO
326+
END DO
327+
328+
!ERROR: Trip count must be computable and invariant
329+
!$acc loop collapse(2)
330+
DO i = 1, n
331+
DO j = 1, i
332+
END DO
333+
END DO
334+
335+
!ERROR: Trip count must be computable and invariant
336+
!$acc loop collapse(2)
337+
DO i = 1, n
338+
DO j = 1, ta(i)%n
339+
END DO
340+
END DO
341+
342+
!ERROR: Trip count must be computable and invariant
343+
!$acc parallel loop collapse(2)
344+
DO i = 1, n
345+
DO j = 1, ta(i)%n
346+
END DO
347+
END DO
348+
349+
!ERROR: Trip count must be computable and invariant
350+
!$acc loop collapse(3)
351+
DO i = 1, n
352+
DO j = 1, n
353+
DO k = 1, i
354+
END DO
355+
END DO
356+
END DO
357+
320358
end program openacc_loop_validity

0 commit comments

Comments
 (0)