|
13 | 13 | #include "resolve-names-utils.h"
|
14 | 14 | #include "flang/Common/idioms.h"
|
15 | 15 | #include "flang/Evaluate/fold.h"
|
| 16 | +#include "flang/Evaluate/tools.h" |
16 | 17 | #include "flang/Evaluate/type.h"
|
17 | 18 | #include "flang/Parser/parse-tree-visitor.h"
|
18 | 19 | #include "flang/Parser/parse-tree.h"
|
@@ -266,7 +267,7 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
|
266 | 267 | Symbol::Flag::AccDevicePtr, Symbol::Flag::AccDeviceResident,
|
267 | 268 | Symbol::Flag::AccLink, Symbol::Flag::AccPresent};
|
268 | 269 |
|
269 |
| - void CheckAssociatedLoopIndex(const parser::OpenACCLoopConstruct &); |
| 270 | + void CheckAssociatedLoop(const parser::DoConstruct &); |
270 | 271 | void ResolveAccObjectList(const parser::AccObjectList &, Symbol::Flag);
|
271 | 272 | void ResolveAccObject(const parser::AccObject &, Symbol::Flag);
|
272 | 273 | Symbol *ResolveAcc(const parser::Name &, Symbol::Flag, Scope &);
|
@@ -882,7 +883,8 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCLoopConstruct &x) {
|
882 | 883 | }
|
883 | 884 | ClearDataSharingAttributeObjects();
|
884 | 885 | SetContextAssociatedLoopLevel(GetAssociatedLoopLevelFromClauses(clauseList));
|
885 |
| - CheckAssociatedLoopIndex(x); |
| 886 | + const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)}; |
| 887 | + CheckAssociatedLoop(*outer); |
886 | 888 | return true;
|
887 | 889 | }
|
888 | 890 |
|
@@ -1087,6 +1089,10 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCCombinedConstruct &x) {
|
1087 | 1089 | default:
|
1088 | 1090 | break;
|
1089 | 1091 | }
|
| 1092 | + const auto &clauseList{std::get<parser::AccClauseList>(beginBlockDir.t)}; |
| 1093 | + SetContextAssociatedLoopLevel(GetAssociatedLoopLevelFromClauses(clauseList)); |
| 1094 | + const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)}; |
| 1095 | + CheckAssociatedLoop(*outer); |
1090 | 1096 | ClearDataSharingAttributeObjects();
|
1091 | 1097 | return true;
|
1092 | 1098 | }
|
@@ -1218,8 +1224,8 @@ std::int64_t AccAttributeVisitor::GetAssociatedLoopLevelFromClauses(
|
1218 | 1224 | return 1; // default is outermost loop
|
1219 | 1225 | }
|
1220 | 1226 |
|
1221 |
| -void AccAttributeVisitor::CheckAssociatedLoopIndex( |
1222 |
| - const parser::OpenACCLoopConstruct &x) { |
| 1227 | +void AccAttributeVisitor::CheckAssociatedLoop( |
| 1228 | + const parser::DoConstruct &outerDoConstruct) { |
1223 | 1229 | std::int64_t level{GetContext().associatedLoopLevel};
|
1224 | 1230 | if (level <= 0) { // collapse value was negative or 0
|
1225 | 1231 | return;
|
@@ -1250,10 +1256,41 @@ void AccAttributeVisitor::CheckAssociatedLoopIndex(
|
1250 | 1256 | return nullptr;
|
1251 | 1257 | };
|
1252 | 1258 |
|
1253 |
| - const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)}; |
1254 |
| - for (const parser::DoConstruct *loop{&*outer}; loop && level > 0;) { |
| 1259 | + auto checkExprHasSymbols = [&](llvm::SmallVector<Symbol *> &ivs, |
| 1260 | + semantics::UnorderedSymbolSet &symbols) { |
| 1261 | + for (auto iv : ivs) { |
| 1262 | + if (symbols.count(*iv) != 0) { |
| 1263 | + context_.Say(GetContext().directiveSource, |
| 1264 | + "Trip count must be computable and invariant"_err_en_US); |
| 1265 | + } |
| 1266 | + } |
| 1267 | + }; |
| 1268 | + |
| 1269 | + Symbol::Flag flag = Symbol::Flag::AccPrivate; |
| 1270 | + llvm::SmallVector<Symbol *> ivs; |
| 1271 | + using Bounds = parser::LoopControl::Bounds; |
| 1272 | + for (const parser::DoConstruct *loop{&outerDoConstruct}; loop && level > 0;) { |
1255 | 1273 | // Go through all nested loops to ensure index variable exists.
|
1256 |
| - GetLoopIndex(*loop); |
| 1274 | + if (const parser::Name * ivName{GetLoopIndex(*loop)}) { |
| 1275 | + if (auto *symbol{ResolveAcc(*ivName, flag, currScope())}) { |
| 1276 | + if (auto &control{loop->GetLoopControl()}) { |
| 1277 | + if (const Bounds * b{std::get_if<Bounds>(&control->u)}) { |
| 1278 | + if (auto lowerExpr{semantics::AnalyzeExpr(context_, b->lower)}) { |
| 1279 | + semantics::UnorderedSymbolSet lowerSyms = |
| 1280 | + evaluate::CollectSymbols(*lowerExpr); |
| 1281 | + checkExprHasSymbols(ivs, lowerSyms); |
| 1282 | + } |
| 1283 | + if (auto upperExpr{semantics::AnalyzeExpr(context_, b->upper)}) { |
| 1284 | + semantics::UnorderedSymbolSet upperSyms = |
| 1285 | + evaluate::CollectSymbols(*upperExpr); |
| 1286 | + checkExprHasSymbols(ivs, upperSyms); |
| 1287 | + } |
| 1288 | + } |
| 1289 | + } |
| 1290 | + ivs.push_back(symbol); |
| 1291 | + } |
| 1292 | + } |
| 1293 | + |
1257 | 1294 | const auto &block{std::get<parser::Block>(loop->t)};
|
1258 | 1295 | --level;
|
1259 | 1296 | loop = getNextDoConstruct(block, level);
|
|
0 commit comments