Skip to content

Commit 6af4118

Browse files
authored
Reapply #91116 with fix (#93160)
This PR contains 2 commits: 1. A commit to reapply changes introduced #91116 (was reverted earlier due to test suite failures) 2. A commit containing a possible solution for the issue causing the test suite failures. In particular, it introduces a simple symbol visitor class to keep track of the current active OMP construct and marking this active construct as the scope defining the symbol being visisted.
1 parent 8760d4b commit 6af4118

File tree

61 files changed

+420
-217
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+420
-217
lines changed

flang/include/flang/Lower/AbstractConverter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ class AbstractConverter {
131131

132132
/// For a given symbol, check if it is present in the inner-most
133133
/// level of the symbol map.
134-
virtual bool isPresentShallowLookup(Fortran::semantics::Symbol &sym) = 0;
134+
virtual bool
135+
isPresentShallowLookup(const Fortran::semantics::Symbol &sym) = 0;
135136

136137
/// Collect the set of symbols with \p flag in \p eval
137138
/// region if \p collectSymbols is true. Otherwise, collect the

flang/lib/Lower/Bridge.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
602602
return typeConstructionStack;
603603
}
604604

605-
bool isPresentShallowLookup(Fortran::semantics::Symbol &sym) override final {
605+
bool
606+
isPresentShallowLookup(const Fortran::semantics::Symbol &sym) override final {
606607
return bool(shallowLookupSymbol(sym));
607608
}
608609

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 70 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,41 @@
2222
namespace Fortran {
2323
namespace lower {
2424
namespace omp {
25+
bool DataSharingProcessor::OMPConstructSymbolVisitor::isSymbolDefineBy(
26+
const semantics::Symbol *symbol, lower::pft::Evaluation &eval) const {
27+
return eval.visit(
28+
common::visitors{[&](const parser::OpenMPConstruct &functionParserNode) {
29+
return symDefMap.count(symbol) &&
30+
symDefMap.at(symbol) == &functionParserNode;
31+
},
32+
[](const auto &functionParserNode) { return false; }});
33+
}
34+
35+
DataSharingProcessor::DataSharingProcessor(
36+
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
37+
const List<Clause> &clauses, lower::pft::Evaluation &eval,
38+
bool shouldCollectPreDeterminedSymbols, bool useDelayedPrivatization,
39+
lower::SymMap *symTable)
40+
: hasLastPrivateOp(false), converter(converter), semaCtx(semaCtx),
41+
firOpBuilder(converter.getFirOpBuilder()), clauses(clauses), eval(eval),
42+
shouldCollectPreDeterminedSymbols(shouldCollectPreDeterminedSymbols),
43+
useDelayedPrivatization(useDelayedPrivatization), symTable(symTable),
44+
visitor() {
45+
eval.visit([&](const auto &functionParserNode) {
46+
parser::Walk(functionParserNode, visitor);
47+
});
48+
}
2549

2650
void DataSharingProcessor::processStep1(
2751
mlir::omp::PrivateClauseOps *clauseOps,
2852
llvm::SmallVectorImpl<const semantics::Symbol *> *privateSyms) {
2953
collectSymbolsForPrivatization();
3054
collectDefaultSymbols();
3155
collectImplicitSymbols();
56+
collectPreDeterminedSymbols();
57+
3258
privatize(clauseOps, privateSyms);
33-
defaultPrivatize(clauseOps, privateSyms);
34-
implicitPrivatize(clauseOps, privateSyms);
59+
3560
insertBarrier();
3661
}
3762

@@ -57,7 +82,7 @@ void DataSharingProcessor::processStep2(mlir::Operation *op, bool isLoop) {
5782
}
5883

5984
void DataSharingProcessor::insertDeallocs() {
60-
for (const semantics::Symbol *sym : privatizedSymbols)
85+
for (const semantics::Symbol *sym : allPrivatizedSymbols)
6186
if (semantics::IsAllocatable(sym->GetUltimate())) {
6287
if (!useDelayedPrivatization) {
6388
converter.createHostAssociateVarCloneDealloc(*sym);
@@ -92,10 +117,6 @@ void DataSharingProcessor::insertDeallocs() {
92117
}
93118

94119
void DataSharingProcessor::cloneSymbol(const semantics::Symbol *sym) {
95-
// Privatization for symbols which are pre-determined (like loop index
96-
// variables) happen separately, for everything else privatize here.
97-
if (sym->test(semantics::Symbol::Flag::OmpPreDetermined))
98-
return;
99120
bool success = converter.createHostAssociateVarClone(*sym);
100121
(void)success;
101122
assert(success && "Privatization failed due to existing binding");
@@ -126,20 +147,24 @@ void DataSharingProcessor::collectSymbolsForPrivatization() {
126147
for (const omp::Clause &clause : clauses) {
127148
if (const auto &privateClause =
128149
std::get_if<omp::clause::Private>(&clause.u)) {
129-
collectOmpObjectListSymbol(privateClause->v, privatizedSymbols);
150+
collectOmpObjectListSymbol(privateClause->v, explicitlyPrivatizedSymbols);
130151
} else if (const auto &firstPrivateClause =
131152
std::get_if<omp::clause::Firstprivate>(&clause.u)) {
132-
collectOmpObjectListSymbol(firstPrivateClause->v, privatizedSymbols);
153+
collectOmpObjectListSymbol(firstPrivateClause->v,
154+
explicitlyPrivatizedSymbols);
133155
} else if (const auto &lastPrivateClause =
134156
std::get_if<omp::clause::Lastprivate>(&clause.u)) {
135157
const ObjectList &objects = std::get<ObjectList>(lastPrivateClause->t);
136-
collectOmpObjectListSymbol(objects, privatizedSymbols);
158+
collectOmpObjectListSymbol(objects, explicitlyPrivatizedSymbols);
137159
hasLastPrivateOp = true;
138160
} else if (std::get_if<omp::clause::Collapse>(&clause.u)) {
139161
hasCollapse = true;
140162
}
141163
}
142164

165+
for (auto *sym : explicitlyPrivatizedSymbols)
166+
allPrivatizedSymbols.insert(sym);
167+
143168
if (hasCollapse && hasLastPrivateOp)
144169
TODO(converter.getCurrentLocation(), "Collapse clause with lastprivate");
145170
}
@@ -149,7 +174,7 @@ bool DataSharingProcessor::needBarrier() {
149174
// initialization of firstprivate variables and post-update of lastprivate
150175
// variables.
151176
// Emit implicit barrier for linear clause. Maybe on somewhere else.
152-
for (const semantics::Symbol *sym : privatizedSymbols) {
177+
for (const semantics::Symbol *sym : allPrivatizedSymbols) {
153178
if (sym->test(semantics::Symbol::Flag::OmpFirstPrivate) &&
154179
sym->test(semantics::Symbol::Flag::OmpLastPrivate))
155180
return true;
@@ -283,10 +308,11 @@ void DataSharingProcessor::collectSymbolsInNestedRegions(
283308
if (nestedEval.isConstruct())
284309
// Recursively look for OpenMP constructs within `nestedEval`'s region
285310
collectSymbolsInNestedRegions(nestedEval, flag, symbolsInNestedRegions);
286-
else
311+
else {
287312
converter.collectSymbolSet(nestedEval, symbolsInNestedRegions, flag,
288313
/*collectSymbols=*/true,
289314
/*collectHostAssociatedSymbols=*/false);
315+
}
290316
}
291317
}
292318
}
@@ -322,24 +348,44 @@ void DataSharingProcessor::collectSymbols(
322348
converter.collectSymbolSet(eval, allSymbols, flag,
323349
/*collectSymbols=*/true,
324350
/*collectHostAssociatedSymbols=*/true);
351+
325352
llvm::SetVector<const semantics::Symbol *> symbolsInNestedRegions;
326353
collectSymbolsInNestedRegions(eval, flag, symbolsInNestedRegions);
354+
355+
for (auto *symbol : allSymbols)
356+
if (visitor.isSymbolDefineBy(symbol, eval))
357+
symbolsInNestedRegions.remove(symbol);
358+
327359
// Filter-out symbols that must not be privatized.
328360
bool collectImplicit = flag == semantics::Symbol::Flag::OmpImplicit;
361+
bool collectPreDetermined = flag == semantics::Symbol::Flag::OmpPreDetermined;
362+
329363
auto isPrivatizable = [](const semantics::Symbol &sym) -> bool {
330364
return !semantics::IsProcedure(sym) &&
331365
!sym.GetUltimate().has<semantics::DerivedTypeDetails>() &&
332366
!sym.GetUltimate().has<semantics::NamelistDetails>() &&
333367
!semantics::IsImpliedDoIndex(sym.GetUltimate());
334368
};
369+
370+
auto shouldCollectSymbol = [&](const semantics::Symbol *sym) {
371+
if (collectImplicit)
372+
return sym->test(semantics::Symbol::Flag::OmpImplicit);
373+
374+
if (collectPreDetermined)
375+
return sym->test(semantics::Symbol::Flag::OmpPreDetermined);
376+
377+
return !sym->test(semantics::Symbol::Flag::OmpImplicit) &&
378+
!sym->test(semantics::Symbol::Flag::OmpPreDetermined);
379+
};
380+
335381
for (const auto *sym : allSymbols) {
336382
assert(curScope && "couldn't find current scope");
337383
if (isPrivatizable(*sym) && !symbolsInNestedRegions.contains(sym) &&
338-
!privatizedSymbols.contains(sym) &&
339-
!sym->test(semantics::Symbol::Flag::OmpPreDetermined) &&
340-
(collectImplicit || !sym->test(semantics::Symbol::Flag::OmpImplicit)) &&
341-
clauseScopes.contains(&sym->owner()))
384+
!explicitlyPrivatizedSymbols.contains(sym) &&
385+
shouldCollectSymbol(sym) && clauseScopes.contains(&sym->owner())) {
386+
allPrivatizedSymbols.insert(sym);
342387
symbols.insert(sym);
388+
}
343389
}
344390
}
345391

@@ -363,10 +409,16 @@ void DataSharingProcessor::collectImplicitSymbols() {
363409
collectSymbols(semantics::Symbol::Flag::OmpImplicit, implicitSymbols);
364410
}
365411

412+
void DataSharingProcessor::collectPreDeterminedSymbols() {
413+
if (shouldCollectPreDeterminedSymbols)
414+
collectSymbols(semantics::Symbol::Flag::OmpPreDetermined,
415+
preDeterminedSymbols);
416+
}
417+
366418
void DataSharingProcessor::privatize(
367419
mlir::omp::PrivateClauseOps *clauseOps,
368420
llvm::SmallVectorImpl<const semantics::Symbol *> *privateSyms) {
369-
for (const semantics::Symbol *sym : privatizedSymbols) {
421+
for (const semantics::Symbol *sym : allPrivatizedSymbols) {
370422
if (const auto *commonDet =
371423
sym->detailsIf<semantics::CommonBlockDetails>()) {
372424
for (const auto &mem : commonDet->objects())
@@ -378,7 +430,7 @@ void DataSharingProcessor::privatize(
378430

379431
void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) {
380432
insertLastPrivateCompare(op);
381-
for (const semantics::Symbol *sym : privatizedSymbols)
433+
for (const semantics::Symbol *sym : allPrivatizedSymbols)
382434
if (const auto *commonDet =
383435
sym->detailsIf<semantics::CommonBlockDetails>()) {
384436
for (const auto &mem : commonDet->objects()) {
@@ -389,20 +441,6 @@ void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) {
389441
}
390442
}
391443

392-
void DataSharingProcessor::defaultPrivatize(
393-
mlir::omp::PrivateClauseOps *clauseOps,
394-
llvm::SmallVectorImpl<const semantics::Symbol *> *privateSyms) {
395-
for (const semantics::Symbol *sym : defaultSymbols)
396-
doPrivatize(sym, clauseOps, privateSyms);
397-
}
398-
399-
void DataSharingProcessor::implicitPrivatize(
400-
mlir::omp::PrivateClauseOps *clauseOps,
401-
llvm::SmallVectorImpl<const semantics::Symbol *> *privateSyms) {
402-
for (const semantics::Symbol *sym : implicitSymbols)
403-
doPrivatize(sym, clauseOps, privateSyms);
404-
}
405-
406444
void DataSharingProcessor::doPrivatize(
407445
const semantics::Symbol *sym, mlir::omp::PrivateClauseOps *clauseOps,
408446
llvm::SmallVectorImpl<const semantics::Symbol *> *privateSyms) {

flang/lib/Lower/OpenMP/DataSharingProcessor.h

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,62 @@ namespace omp {
3232

3333
class DataSharingProcessor {
3434
private:
35+
/// A symbol visitor that keeps track of the currently active OpenMPConstruct
36+
/// at any point in time. This is used to track Symbol definition scopes in
37+
/// order to tell which OMP scope defined vs. references a certain Symbol.
38+
struct OMPConstructSymbolVisitor {
39+
template <typename T>
40+
bool Pre(const T &) {
41+
return true;
42+
}
43+
template <typename T>
44+
void Post(const T &) {}
45+
46+
bool Pre(const parser::OpenMPConstruct &omp) {
47+
currentConstruct = &omp;
48+
return true;
49+
}
50+
51+
void Post(const parser::OpenMPConstruct &omp) {
52+
currentConstruct = nullptr;
53+
}
54+
55+
void Post(const parser::Name &name) {
56+
symDefMap.try_emplace(name.symbol, currentConstruct);
57+
}
58+
59+
const parser::OpenMPConstruct *currentConstruct = nullptr;
60+
llvm::DenseMap<semantics::Symbol *, const parser::OpenMPConstruct *>
61+
symDefMap;
62+
63+
/// Given a \p symbol and an \p eval, returns true if eval is the OMP
64+
/// construct that defines symbol.
65+
bool isSymbolDefineBy(const semantics::Symbol *symbol,
66+
lower::pft::Evaluation &eval) const;
67+
};
68+
3569
bool hasLastPrivateOp;
3670
mlir::OpBuilder::InsertPoint lastPrivIP;
3771
mlir::OpBuilder::InsertPoint insPt;
3872
mlir::Value loopIV;
3973
// Symbols in private, firstprivate, and/or lastprivate clauses.
40-
llvm::SetVector<const semantics::Symbol *> privatizedSymbols;
74+
llvm::SetVector<const semantics::Symbol *> explicitlyPrivatizedSymbols;
4175
llvm::SetVector<const semantics::Symbol *> defaultSymbols;
4276
llvm::SetVector<const semantics::Symbol *> implicitSymbols;
77+
llvm::SetVector<const semantics::Symbol *> preDeterminedSymbols;
78+
llvm::SetVector<const semantics::Symbol *> allPrivatizedSymbols;
79+
4380
llvm::DenseMap<const semantics::Symbol *, mlir::omp::PrivateClauseOp>
4481
symToPrivatizer;
4582
lower::AbstractConverter &converter;
4683
semantics::SemanticsContext &semaCtx;
4784
fir::FirOpBuilder &firOpBuilder;
4885
omp::List<omp::Clause> clauses;
4986
lower::pft::Evaluation &eval;
87+
bool shouldCollectPreDeterminedSymbols;
5088
bool useDelayedPrivatization;
5189
lower::SymMap *symTable;
90+
OMPConstructSymbolVisitor visitor;
5291

5392
bool needBarrier();
5493
void collectSymbols(semantics::Symbol::Flag flag,
@@ -63,6 +102,7 @@ class DataSharingProcessor {
63102
void insertBarrier();
64103
void collectDefaultSymbols();
65104
void collectImplicitSymbols();
105+
void collectPreDeterminedSymbols();
66106
void privatize(mlir::omp::PrivateClauseOps *clauseOps,
67107
llvm::SmallVectorImpl<const semantics::Symbol *> *privateSyms);
68108
void defaultPrivatize(
@@ -90,11 +130,9 @@ class DataSharingProcessor {
90130
semantics::SemanticsContext &semaCtx,
91131
const List<Clause> &clauses,
92132
lower::pft::Evaluation &eval,
133+
bool shouldCollectPreDeterminedSymbols,
93134
bool useDelayedPrivatization = false,
94-
lower::SymMap *symTable = nullptr)
95-
: hasLastPrivateOp(false), converter(converter), semaCtx(semaCtx),
96-
firOpBuilder(converter.getFirOpBuilder()), clauses(clauses), eval(eval),
97-
useDelayedPrivatization(useDelayedPrivatization), symTable(symTable) {}
135+
lower::SymMap *symTable = nullptr);
98136

99137
// Privatisation is split into two steps.
100138
// Step1 performs cloning of all privatisation clauses and copying for

flang/lib/Lower/OpenMP/Decomposer.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,4 +123,9 @@ ConstructQueue buildConstructQueue(
123123

124124
return constructs;
125125
}
126+
127+
bool isLastItemInQueue(ConstructQueue::iterator item,
128+
const ConstructQueue &queue) {
129+
return std::next(item) == queue.end();
130+
}
126131
} // namespace Fortran::lower::omp

flang/lib/Lower/OpenMP/Decomposer.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ ConstructQueue buildConstructQueue(mlir::ModuleOp modOp,
4646
const parser::CharBlock &source,
4747
llvm::omp::Directive compound,
4848
const List<Clause> &clauses);
49+
50+
bool isLastItemInQueue(ConstructQueue::iterator item,
51+
const ConstructQueue &queue);
4952
} // namespace Fortran::lower::omp
5053

5154
#endif // FORTRAN_LOWER_OPENMP_DECOMPOSER_H

0 commit comments

Comments
 (0)