|
| 1 | +//===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
| 10 | +// |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | +#ifndef FORTRAN_LOWER_CLAUASEPROCESSOR_H |
| 13 | +#define FORTRAN_LOWER_CLAUASEPROCESSOR_H |
| 14 | + |
| 15 | +#include "DirectivesCommon.h" |
| 16 | +#include "ReductionProcessor.h" |
| 17 | +#include "Utils.h" |
| 18 | +#include "flang/Lower/AbstractConverter.h" |
| 19 | +#include "flang/Lower/Bridge.h" |
| 20 | +#include "flang/Optimizer/Builder/Todo.h" |
| 21 | +#include "flang/Parser/dump-parse-tree.h" |
| 22 | +#include "flang/Parser/parse-tree.h" |
| 23 | +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
| 24 | + |
| 25 | +namespace fir { |
| 26 | +class FirOpBuilder; |
| 27 | +} // namespace fir |
| 28 | + |
| 29 | +namespace Fortran { |
| 30 | +namespace lower { |
| 31 | +namespace omp { |
| 32 | + |
| 33 | +/// Class that handles the processing of OpenMP clauses. |
| 34 | +/// |
| 35 | +/// Its `process<ClauseName>()` methods perform MLIR code generation for their |
| 36 | +/// corresponding clause if it is present in the clause list. Otherwise, they |
| 37 | +/// will return `false` to signal that the clause was not found. |
| 38 | +/// |
| 39 | +/// The intended use is of this class is to move clause processing outside of |
| 40 | +/// construct processing, since the same clauses can appear attached to |
| 41 | +/// different constructs and constructs can be combined, so that code |
| 42 | +/// duplication is minimized. |
| 43 | +/// |
| 44 | +/// Each construct-lowering function only calls the `process<ClauseName>()` |
| 45 | +/// methods that relate to clauses that can impact the lowering of that |
| 46 | +/// construct. |
| 47 | +class ClauseProcessor { |
| 48 | + using ClauseTy = Fortran::parser::OmpClause; |
| 49 | + |
| 50 | +public: |
| 51 | + ClauseProcessor(Fortran::lower::AbstractConverter &converter, |
| 52 | + Fortran::semantics::SemanticsContext &semaCtx, |
| 53 | + const Fortran::parser::OmpClauseList &clauses) |
| 54 | + : converter(converter), semaCtx(semaCtx), clauses(clauses) {} |
| 55 | + |
| 56 | + // 'Unique' clauses: They can appear at most once in the clause list. |
| 57 | + bool |
| 58 | + processCollapse(mlir::Location currentLocation, |
| 59 | + Fortran::lower::pft::Evaluation &eval, |
| 60 | + llvm::SmallVectorImpl<mlir::Value> &lowerBound, |
| 61 | + llvm::SmallVectorImpl<mlir::Value> &upperBound, |
| 62 | + llvm::SmallVectorImpl<mlir::Value> &step, |
| 63 | + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv, |
| 64 | + std::size_t &loopVarTypeSize) const; |
| 65 | + bool processDefault() const; |
| 66 | + bool processDevice(Fortran::lower::StatementContext &stmtCtx, |
| 67 | + mlir::Value &result) const; |
| 68 | + bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const; |
| 69 | + bool processFinal(Fortran::lower::StatementContext &stmtCtx, |
| 70 | + mlir::Value &result) const; |
| 71 | + bool processHint(mlir::IntegerAttr &result) const; |
| 72 | + bool processMergeable(mlir::UnitAttr &result) const; |
| 73 | + bool processNowait(mlir::UnitAttr &result) const; |
| 74 | + bool processNumTeams(Fortran::lower::StatementContext &stmtCtx, |
| 75 | + mlir::Value &result) const; |
| 76 | + bool processNumThreads(Fortran::lower::StatementContext &stmtCtx, |
| 77 | + mlir::Value &result) const; |
| 78 | + bool processOrdered(mlir::IntegerAttr &result) const; |
| 79 | + bool processPriority(Fortran::lower::StatementContext &stmtCtx, |
| 80 | + mlir::Value &result) const; |
| 81 | + bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const; |
| 82 | + bool processSafelen(mlir::IntegerAttr &result) const; |
| 83 | + bool processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr, |
| 84 | + mlir::omp::ScheduleModifierAttr &modifierAttr, |
| 85 | + mlir::UnitAttr &simdModifierAttr) const; |
| 86 | + bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx, |
| 87 | + mlir::Value &result) const; |
| 88 | + bool processSimdlen(mlir::IntegerAttr &result) const; |
| 89 | + bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx, |
| 90 | + mlir::Value &result) const; |
| 91 | + bool processUntied(mlir::UnitAttr &result) const; |
| 92 | + |
| 93 | + // 'Repeatable' clauses: They can appear multiple times in the clause list. |
| 94 | + bool |
| 95 | + processAllocate(llvm::SmallVectorImpl<mlir::Value> &allocatorOperands, |
| 96 | + llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const; |
| 97 | + bool processCopyin() const; |
| 98 | + bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands, |
| 99 | + llvm::SmallVectorImpl<mlir::Value> &dependOperands) const; |
| 100 | + bool |
| 101 | + processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; |
| 102 | + bool |
| 103 | + processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, |
| 104 | + mlir::Value &result) const; |
| 105 | + bool |
| 106 | + processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; |
| 107 | + |
| 108 | + // This method is used to process a map clause. |
| 109 | + // The optional parameters - mapSymTypes, mapSymLocs & mapSymbols are used to |
| 110 | + // store the original type, location and Fortran symbol for the map operands. |
| 111 | + // They may be used later on to create the block_arguments for some of the |
| 112 | + // target directives that require it. |
| 113 | + bool processMap(mlir::Location currentLocation, |
| 114 | + const llvm::omp::Directive &directive, |
| 115 | + Fortran::lower::StatementContext &stmtCtx, |
| 116 | + llvm::SmallVectorImpl<mlir::Value> &mapOperands, |
| 117 | + llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr, |
| 118 | + llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr, |
| 119 | + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> |
| 120 | + *mapSymbols = nullptr) const; |
| 121 | + bool |
| 122 | + processReduction(mlir::Location currentLocation, |
| 123 | + llvm::SmallVectorImpl<mlir::Value> &reductionVars, |
| 124 | + llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, |
| 125 | + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> |
| 126 | + *reductionSymbols = nullptr) const; |
| 127 | + bool processSectionsReduction(mlir::Location currentLocation) const; |
| 128 | + bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; |
| 129 | + bool |
| 130 | + processUseDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands, |
| 131 | + llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, |
| 132 | + llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, |
| 133 | + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> |
| 134 | + &useDeviceSymbols) const; |
| 135 | + bool |
| 136 | + processUseDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands, |
| 137 | + llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, |
| 138 | + llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, |
| 139 | + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> |
| 140 | + &useDeviceSymbols) const; |
| 141 | + |
| 142 | + template <typename T> |
| 143 | + bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx, |
| 144 | + llvm::SmallVectorImpl<mlir::Value> &mapOperands); |
| 145 | + |
| 146 | + // Call this method for these clauses that should be supported but are not |
| 147 | + // implemented yet. It triggers a compilation error if any of the given |
| 148 | + // clauses is found. |
| 149 | + template <typename... Ts> |
| 150 | + void processTODO(mlir::Location currentLocation, |
| 151 | + llvm::omp::Directive directive) const; |
| 152 | + |
| 153 | +private: |
| 154 | + using ClauseIterator = std::list<ClauseTy>::const_iterator; |
| 155 | + |
| 156 | + /// Utility to find a clause within a range in the clause list. |
| 157 | + template <typename T> |
| 158 | + static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end); |
| 159 | + |
| 160 | + /// Return the first instance of the given clause found in the clause list or |
| 161 | + /// `nullptr` if not present. If more than one instance is expected, use |
| 162 | + /// `findRepeatableClause` instead. |
| 163 | + template <typename T> |
| 164 | + const T * |
| 165 | + findUniqueClause(const Fortran::parser::CharBlock **source = nullptr) const; |
| 166 | + |
| 167 | + /// Call `callbackFn` for each occurrence of the given clause. Return `true` |
| 168 | + /// if at least one instance was found. |
| 169 | + template <typename T> |
| 170 | + bool findRepeatableClause( |
| 171 | + std::function<void(const T *, const Fortran::parser::CharBlock &source)> |
| 172 | + callbackFn) const; |
| 173 | + |
| 174 | + /// Set the `result` to a new `mlir::UnitAttr` if the clause is present. |
| 175 | + template <typename T> |
| 176 | + bool markClauseOccurrence(mlir::UnitAttr &result) const; |
| 177 | + |
| 178 | + Fortran::lower::AbstractConverter &converter; |
| 179 | + Fortran::semantics::SemanticsContext &semaCtx; |
| 180 | + const Fortran::parser::OmpClauseList &clauses; |
| 181 | +}; |
| 182 | + |
| 183 | +template <typename T> |
| 184 | +bool ClauseProcessor::processMotionClauses( |
| 185 | + Fortran::lower::StatementContext &stmtCtx, |
| 186 | + llvm::SmallVectorImpl<mlir::Value> &mapOperands) { |
| 187 | + return findRepeatableClause<T>( |
| 188 | + [&](const T *motionClause, const Fortran::parser::CharBlock &source) { |
| 189 | + mlir::Location clauseLocation = converter.genLocation(source); |
| 190 | + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); |
| 191 | + |
| 192 | + static_assert(std::is_same_v<T, ClauseProcessor::ClauseTy::To> || |
| 193 | + std::is_same_v<T, ClauseProcessor::ClauseTy::From>); |
| 194 | + |
| 195 | + // TODO Support motion modifiers: present, mapper, iterator. |
| 196 | + constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = |
| 197 | + std::is_same_v<T, ClauseProcessor::ClauseTy::To> |
| 198 | + ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
| 199 | + : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; |
| 200 | + |
| 201 | + for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) { |
| 202 | + llvm::SmallVector<mlir::Value> bounds; |
| 203 | + std::stringstream asFortran; |
| 204 | + Fortran::lower::AddrAndBoundsInfo info = |
| 205 | + Fortran::lower::gatherDataOperandAddrAndBounds< |
| 206 | + Fortran::parser::OmpObject, mlir::omp::DataBoundsOp, |
| 207 | + mlir::omp::DataBoundsType>( |
| 208 | + converter, firOpBuilder, semaCtx, stmtCtx, ompObject, |
| 209 | + clauseLocation, asFortran, bounds, treatIndexAsSection); |
| 210 | + |
| 211 | + auto origSymbol = |
| 212 | + converter.getSymbolAddress(*getOmpObjectSymbol(ompObject)); |
| 213 | + mlir::Value symAddr = info.addr; |
| 214 | + if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) |
| 215 | + symAddr = origSymbol; |
| 216 | + |
| 217 | + // Explicit map captures are captured ByRef by default, |
| 218 | + // optimisation passes may alter this to ByCopy or other capture |
| 219 | + // types to optimise |
| 220 | + mlir::Value mapOp = createMapInfoOp( |
| 221 | + firOpBuilder, clauseLocation, symAddr, mlir::Value{}, |
| 222 | + asFortran.str(), bounds, {}, |
| 223 | + static_cast< |
| 224 | + std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( |
| 225 | + mapTypeBits), |
| 226 | + mlir::omp::VariableCaptureKind::ByRef, symAddr.getType()); |
| 227 | + |
| 228 | + mapOperands.push_back(mapOp); |
| 229 | + } |
| 230 | + }); |
| 231 | +} |
| 232 | + |
| 233 | +template <typename... Ts> |
| 234 | +void ClauseProcessor::processTODO(mlir::Location currentLocation, |
| 235 | + llvm::omp::Directive directive) const { |
| 236 | + auto checkUnhandledClause = [&](const auto *x) { |
| 237 | + if (!x) |
| 238 | + return; |
| 239 | + TODO(currentLocation, |
| 240 | + "Unhandled clause " + |
| 241 | + llvm::StringRef(Fortran::parser::ParseTreeDumper::GetNodeName(*x)) |
| 242 | + .upper() + |
| 243 | + " in " + llvm::omp::getOpenMPDirectiveName(directive).upper() + |
| 244 | + " construct"); |
| 245 | + }; |
| 246 | + |
| 247 | + for (ClauseIterator it = clauses.v.begin(); it != clauses.v.end(); ++it) |
| 248 | + (checkUnhandledClause(std::get_if<Ts>(&it->u)), ...); |
| 249 | +} |
| 250 | + |
| 251 | +template <typename T> |
| 252 | +ClauseProcessor::ClauseIterator |
| 253 | +ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) { |
| 254 | + for (ClauseIterator it = begin; it != end; ++it) { |
| 255 | + if (std::get_if<T>(&it->u)) |
| 256 | + return it; |
| 257 | + } |
| 258 | + |
| 259 | + return end; |
| 260 | +} |
| 261 | + |
| 262 | +template <typename T> |
| 263 | +const T *ClauseProcessor::findUniqueClause( |
| 264 | + const Fortran::parser::CharBlock **source) const { |
| 265 | + ClauseIterator it = findClause<T>(clauses.v.begin(), clauses.v.end()); |
| 266 | + if (it != clauses.v.end()) { |
| 267 | + if (source) |
| 268 | + *source = &it->source; |
| 269 | + return &std::get<T>(it->u); |
| 270 | + } |
| 271 | + return nullptr; |
| 272 | +} |
| 273 | + |
| 274 | +template <typename T> |
| 275 | +bool ClauseProcessor::findRepeatableClause( |
| 276 | + std::function<void(const T *, const Fortran::parser::CharBlock &source)> |
| 277 | + callbackFn) const { |
| 278 | + bool found = false; |
| 279 | + ClauseIterator nextIt, endIt = clauses.v.end(); |
| 280 | + for (ClauseIterator it = clauses.v.begin(); it != endIt; it = nextIt) { |
| 281 | + nextIt = findClause<T>(it, endIt); |
| 282 | + |
| 283 | + if (nextIt != endIt) { |
| 284 | + callbackFn(&std::get<T>(nextIt->u), nextIt->source); |
| 285 | + found = true; |
| 286 | + ++nextIt; |
| 287 | + } |
| 288 | + } |
| 289 | + return found; |
| 290 | +} |
| 291 | + |
| 292 | +template <typename T> |
| 293 | +bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const { |
| 294 | + if (findUniqueClause<T>()) { |
| 295 | + result = converter.getFirOpBuilder().getUnitAttr(); |
| 296 | + return true; |
| 297 | + } |
| 298 | + return false; |
| 299 | +} |
| 300 | + |
| 301 | +} // namespace omp |
| 302 | +} // namespace lower |
| 303 | +} // namespace Fortran |
| 304 | + |
| 305 | +#endif // FORTRAN_LOWER_CLAUASEPROCESSOR_H |
0 commit comments