Skip to content

Commit 4d4af15

Browse files
authored
[NFC][flang][OpenMP] Split DataSharing and Clause processors (#81973)
This started as an experiment to reduce the compilation time of iterating over `Lower/OpenMP.cpp` a bit since it is too slow at the moment. Trying to do that, I split the `DataSharingProcessor`, `ReductionProcessor`, and `ClauseProcessor` into their own files and extracted some shared code into a util file. All of these new `.h/.cpp` files as well as `OpenMP.cpp` are now under a `Lower/OpenMP/` directory. This resulted is a slightly better organization of the OpenMP lowering code and hence opening this NFC. As for the compilation time, this unfortunately does not affect it much (it shaves off a few seconds of `OpenMP.cpp` compilation) since from what I learned the bottleneck is in `DirectivesCommon.h` and `PFTBuilder.h` which both consume a lot of time in template instantiation it seems.
1 parent a0869b1 commit 4d4af15

10 files changed

+2371
-2035
lines changed

flang/lib/Lower/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ add_flang_library(FortranLower
2424
LoweringOptions.cpp
2525
Mangler.cpp
2626
OpenACC.cpp
27-
OpenMP.cpp
27+
OpenMP/ClauseProcessor.cpp
28+
OpenMP/DataSharingProcessor.cpp
29+
OpenMP/OpenMP.cpp
30+
OpenMP/ReductionProcessor.cpp
31+
OpenMP/Utils.cpp
2832
PFTBuilder.cpp
2933
Runtime.cpp
3034
SymbolMap.cpp

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 880 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
//===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10+
//
11+
//===----------------------------------------------------------------------===//
12+
#ifndef FORTRAN_LOWER_CLAUASEPROCESSOR_H
13+
#define FORTRAN_LOWER_CLAUASEPROCESSOR_H
14+
15+
#include "DirectivesCommon.h"
16+
#include "ReductionProcessor.h"
17+
#include "Utils.h"
18+
#include "flang/Lower/AbstractConverter.h"
19+
#include "flang/Lower/Bridge.h"
20+
#include "flang/Optimizer/Builder/Todo.h"
21+
#include "flang/Parser/dump-parse-tree.h"
22+
#include "flang/Parser/parse-tree.h"
23+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
24+
25+
namespace fir {
26+
class FirOpBuilder;
27+
} // namespace fir
28+
29+
namespace Fortran {
30+
namespace lower {
31+
namespace omp {
32+
33+
/// Class that handles the processing of OpenMP clauses.
34+
///
35+
/// Its `process<ClauseName>()` methods perform MLIR code generation for their
36+
/// corresponding clause if it is present in the clause list. Otherwise, they
37+
/// will return `false` to signal that the clause was not found.
38+
///
39+
/// The intended use is of this class is to move clause processing outside of
40+
/// construct processing, since the same clauses can appear attached to
41+
/// different constructs and constructs can be combined, so that code
42+
/// duplication is minimized.
43+
///
44+
/// Each construct-lowering function only calls the `process<ClauseName>()`
45+
/// methods that relate to clauses that can impact the lowering of that
46+
/// construct.
47+
class ClauseProcessor {
48+
using ClauseTy = Fortran::parser::OmpClause;
49+
50+
public:
51+
ClauseProcessor(Fortran::lower::AbstractConverter &converter,
52+
Fortran::semantics::SemanticsContext &semaCtx,
53+
const Fortran::parser::OmpClauseList &clauses)
54+
: converter(converter), semaCtx(semaCtx), clauses(clauses) {}
55+
56+
// 'Unique' clauses: They can appear at most once in the clause list.
57+
bool
58+
processCollapse(mlir::Location currentLocation,
59+
Fortran::lower::pft::Evaluation &eval,
60+
llvm::SmallVectorImpl<mlir::Value> &lowerBound,
61+
llvm::SmallVectorImpl<mlir::Value> &upperBound,
62+
llvm::SmallVectorImpl<mlir::Value> &step,
63+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
64+
std::size_t &loopVarTypeSize) const;
65+
bool processDefault() const;
66+
bool processDevice(Fortran::lower::StatementContext &stmtCtx,
67+
mlir::Value &result) const;
68+
bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const;
69+
bool processFinal(Fortran::lower::StatementContext &stmtCtx,
70+
mlir::Value &result) const;
71+
bool processHint(mlir::IntegerAttr &result) const;
72+
bool processMergeable(mlir::UnitAttr &result) const;
73+
bool processNowait(mlir::UnitAttr &result) const;
74+
bool processNumTeams(Fortran::lower::StatementContext &stmtCtx,
75+
mlir::Value &result) const;
76+
bool processNumThreads(Fortran::lower::StatementContext &stmtCtx,
77+
mlir::Value &result) const;
78+
bool processOrdered(mlir::IntegerAttr &result) const;
79+
bool processPriority(Fortran::lower::StatementContext &stmtCtx,
80+
mlir::Value &result) const;
81+
bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const;
82+
bool processSafelen(mlir::IntegerAttr &result) const;
83+
bool processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr,
84+
mlir::omp::ScheduleModifierAttr &modifierAttr,
85+
mlir::UnitAttr &simdModifierAttr) const;
86+
bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx,
87+
mlir::Value &result) const;
88+
bool processSimdlen(mlir::IntegerAttr &result) const;
89+
bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx,
90+
mlir::Value &result) const;
91+
bool processUntied(mlir::UnitAttr &result) const;
92+
93+
// 'Repeatable' clauses: They can appear multiple times in the clause list.
94+
bool
95+
processAllocate(llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
96+
llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const;
97+
bool processCopyin() const;
98+
bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
99+
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
100+
bool
101+
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
102+
bool
103+
processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
104+
mlir::Value &result) const;
105+
bool
106+
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
107+
108+
// This method is used to process a map clause.
109+
// The optional parameters - mapSymTypes, mapSymLocs & mapSymbols are used to
110+
// store the original type, location and Fortran symbol for the map operands.
111+
// They may be used later on to create the block_arguments for some of the
112+
// target directives that require it.
113+
bool processMap(mlir::Location currentLocation,
114+
const llvm::omp::Directive &directive,
115+
Fortran::lower::StatementContext &stmtCtx,
116+
llvm::SmallVectorImpl<mlir::Value> &mapOperands,
117+
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
118+
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
119+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
120+
*mapSymbols = nullptr) const;
121+
bool
122+
processReduction(mlir::Location currentLocation,
123+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
124+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
125+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
126+
*reductionSymbols = nullptr) const;
127+
bool processSectionsReduction(mlir::Location currentLocation) const;
128+
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
129+
bool
130+
processUseDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
131+
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
132+
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
133+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
134+
&useDeviceSymbols) const;
135+
bool
136+
processUseDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
137+
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
138+
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
139+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
140+
&useDeviceSymbols) const;
141+
142+
template <typename T>
143+
bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx,
144+
llvm::SmallVectorImpl<mlir::Value> &mapOperands);
145+
146+
// Call this method for these clauses that should be supported but are not
147+
// implemented yet. It triggers a compilation error if any of the given
148+
// clauses is found.
149+
template <typename... Ts>
150+
void processTODO(mlir::Location currentLocation,
151+
llvm::omp::Directive directive) const;
152+
153+
private:
154+
using ClauseIterator = std::list<ClauseTy>::const_iterator;
155+
156+
/// Utility to find a clause within a range in the clause list.
157+
template <typename T>
158+
static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end);
159+
160+
/// Return the first instance of the given clause found in the clause list or
161+
/// `nullptr` if not present. If more than one instance is expected, use
162+
/// `findRepeatableClause` instead.
163+
template <typename T>
164+
const T *
165+
findUniqueClause(const Fortran::parser::CharBlock **source = nullptr) const;
166+
167+
/// Call `callbackFn` for each occurrence of the given clause. Return `true`
168+
/// if at least one instance was found.
169+
template <typename T>
170+
bool findRepeatableClause(
171+
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
172+
callbackFn) const;
173+
174+
/// Set the `result` to a new `mlir::UnitAttr` if the clause is present.
175+
template <typename T>
176+
bool markClauseOccurrence(mlir::UnitAttr &result) const;
177+
178+
Fortran::lower::AbstractConverter &converter;
179+
Fortran::semantics::SemanticsContext &semaCtx;
180+
const Fortran::parser::OmpClauseList &clauses;
181+
};
182+
183+
template <typename T>
184+
bool ClauseProcessor::processMotionClauses(
185+
Fortran::lower::StatementContext &stmtCtx,
186+
llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
187+
return findRepeatableClause<T>(
188+
[&](const T *motionClause, const Fortran::parser::CharBlock &source) {
189+
mlir::Location clauseLocation = converter.genLocation(source);
190+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
191+
192+
static_assert(std::is_same_v<T, ClauseProcessor::ClauseTy::To> ||
193+
std::is_same_v<T, ClauseProcessor::ClauseTy::From>);
194+
195+
// TODO Support motion modifiers: present, mapper, iterator.
196+
constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
197+
std::is_same_v<T, ClauseProcessor::ClauseTy::To>
198+
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
199+
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
200+
201+
for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) {
202+
llvm::SmallVector<mlir::Value> bounds;
203+
std::stringstream asFortran;
204+
Fortran::lower::AddrAndBoundsInfo info =
205+
Fortran::lower::gatherDataOperandAddrAndBounds<
206+
Fortran::parser::OmpObject, mlir::omp::DataBoundsOp,
207+
mlir::omp::DataBoundsType>(
208+
converter, firOpBuilder, semaCtx, stmtCtx, ompObject,
209+
clauseLocation, asFortran, bounds, treatIndexAsSection);
210+
211+
auto origSymbol =
212+
converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
213+
mlir::Value symAddr = info.addr;
214+
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
215+
symAddr = origSymbol;
216+
217+
// Explicit map captures are captured ByRef by default,
218+
// optimisation passes may alter this to ByCopy or other capture
219+
// types to optimise
220+
mlir::Value mapOp = createMapInfoOp(
221+
firOpBuilder, clauseLocation, symAddr, mlir::Value{},
222+
asFortran.str(), bounds, {},
223+
static_cast<
224+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
225+
mapTypeBits),
226+
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
227+
228+
mapOperands.push_back(mapOp);
229+
}
230+
});
231+
}
232+
233+
template <typename... Ts>
234+
void ClauseProcessor::processTODO(mlir::Location currentLocation,
235+
llvm::omp::Directive directive) const {
236+
auto checkUnhandledClause = [&](const auto *x) {
237+
if (!x)
238+
return;
239+
TODO(currentLocation,
240+
"Unhandled clause " +
241+
llvm::StringRef(Fortran::parser::ParseTreeDumper::GetNodeName(*x))
242+
.upper() +
243+
" in " + llvm::omp::getOpenMPDirectiveName(directive).upper() +
244+
" construct");
245+
};
246+
247+
for (ClauseIterator it = clauses.v.begin(); it != clauses.v.end(); ++it)
248+
(checkUnhandledClause(std::get_if<Ts>(&it->u)), ...);
249+
}
250+
251+
template <typename T>
252+
ClauseProcessor::ClauseIterator
253+
ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) {
254+
for (ClauseIterator it = begin; it != end; ++it) {
255+
if (std::get_if<T>(&it->u))
256+
return it;
257+
}
258+
259+
return end;
260+
}
261+
262+
template <typename T>
263+
const T *ClauseProcessor::findUniqueClause(
264+
const Fortran::parser::CharBlock **source) const {
265+
ClauseIterator it = findClause<T>(clauses.v.begin(), clauses.v.end());
266+
if (it != clauses.v.end()) {
267+
if (source)
268+
*source = &it->source;
269+
return &std::get<T>(it->u);
270+
}
271+
return nullptr;
272+
}
273+
274+
template <typename T>
275+
bool ClauseProcessor::findRepeatableClause(
276+
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
277+
callbackFn) const {
278+
bool found = false;
279+
ClauseIterator nextIt, endIt = clauses.v.end();
280+
for (ClauseIterator it = clauses.v.begin(); it != endIt; it = nextIt) {
281+
nextIt = findClause<T>(it, endIt);
282+
283+
if (nextIt != endIt) {
284+
callbackFn(&std::get<T>(nextIt->u), nextIt->source);
285+
found = true;
286+
++nextIt;
287+
}
288+
}
289+
return found;
290+
}
291+
292+
template <typename T>
293+
bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const {
294+
if (findUniqueClause<T>()) {
295+
result = converter.getFirOpBuilder().getUnitAttr();
296+
return true;
297+
}
298+
return false;
299+
}
300+
301+
} // namespace omp
302+
} // namespace lower
303+
} // namespace Fortran
304+
305+
#endif // FORTRAN_LOWER_CLAUASEPROCESSOR_H

0 commit comments

Comments
 (0)