Skip to content

Commit 435e850

Browse files
committed
[Flang][OpenMP][MLIR] Initial derived type member map support
This patch is one in a series of four patches that seeks to refactor slightly and extend the current record type map support that was put in place for Fortran's descriptor types to handle explicit member mapping for record types at a single level of depth. For example, the below case where two members of a Fortran derived type are mapped explicitly: '''' type :: scalar_and_array real(4) :: real integer(4) :: array(10) integer(4) :: int end type scalar_and_array type(scalar_and_array) :: scalar_arr !$omp target map(tofrom: scalar_arr%int, scalar_arr%real) '''' Current cases of derived type mapping left for future work are: > explicit member mapping of nested members (e.g. two layers of record types where we explicitly map a member from the internal record type) > Fortran's automagical mapping of all elements and nested elements of a derived type > explicit member mapping of a derived type and then constituient members (redundant in Fortran due to former case but still legal as far as I am aware) > explicit member mapping of a record type (may be handled reasonably, just not fully tested in this iteration) > explicit member mapping for Fortran allocatable types (a variation of nested record types) This patch seeks to support this by extending the Flang-new OpenMP lowering to support generation of this newly required information, creating the neccessary parent <-to-> member map_info links, calculating the member indices and setting if it's a partial map. The OMPDescriptorMapInfoGen pass has also been generalized into a map finalization phase, now named OMPMapInfoFinalization. This pass was extended to support the insertion of member maps into the BlockArg and MapOperands of relevant map carrying operations. Similar to the method in which descriptor types are expanded and constituient members inserted. Pull Request: #82853
1 parent 462435f commit 435e850

22 files changed

+1223
-283
lines changed

flang/docs/OpenMP-descriptor-management.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ Currently, Flang will lower these descriptor types in the OpenMP lowering (lower
4444
to all other map types, generating an omp.MapInfoOp containing relevant information required for lowering
4545
the OpenMP dialect to LLVM-IR during the final stages of the MLIR lowering. However, after
4646
the lowering to FIR/HLFIR has been performed an OpenMP dialect specific pass for Fortran,
47-
`OMPDescriptorMapInfoGenPass` (Optimizer/OMPDescriptorMapInfoGen.cpp) will expand the
47+
`OMPMapInfoFinalizationPass` (Optimizer/OMPMapInfoFinalization.cpp) will expand the
4848
`omp.MapInfoOp`'s containing descriptors (which currently will be a `BoxType` or `BoxAddrOp`) into multiple
4949
mappings, with one extra per pointer member in the descriptor that is supported on top of the original
5050
descriptor map operation. These pointers members are linked to the parent descriptor by adding them to
@@ -53,7 +53,7 @@ owning operation's (`omp.TargetOp`, `omp.TargetDataOp` etc.) map operand list an
5353
operation is `IsolatedFromAbove`, it also inserts them as `BlockArgs` to canonicalize the mappings and
5454
simplify lowering.
5555
56-
An example transformation by the `OMPDescriptorMapInfoGenPass`:
56+
An example transformation by the `OMPMapInfoFinalizationPass`:
5757
5858
```
5959

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ std::unique_ptr<mlir::Pass> createAlgebraicSimplificationPass();
6868
std::unique_ptr<mlir::Pass>
6969
createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config);
7070

71-
std::unique_ptr<mlir::Pass> createOMPDescriptorMapInfoGenPass();
71+
std::unique_ptr<mlir::Pass> createOMPMapInfoFinalizationPass();
7272
std::unique_ptr<mlir::Pass> createOMPFunctionFilteringPass();
7373
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
7474
createOMPMarkDeclareTargetPass();

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,15 +321,15 @@ def LoopVersioning : Pass<"loop-versioning", "mlir::func::FuncOp"> {
321321
let dependentDialects = [ "fir::FIROpsDialect" ];
322322
}
323323

324-
def OMPDescriptorMapInfoGenPass
325-
: Pass<"omp-descriptor-map-info-gen", "mlir::func::FuncOp"> {
324+
def OMPMapInfoFinalizationPass
325+
: Pass<"omp-map-info-finalization", "mlir::func::FuncOp"> {
326326
let summary = "expands OpenMP MapInfo operations containing descriptors";
327327
let description = [{
328328
Expands MapInfo operations containing descriptor types into multiple
329329
MapInfo's for each pointer element in the descriptor that requires
330330
explicit individual mapping by the OpenMP runtime.
331331
}];
332-
let constructor = "::fir::createOMPDescriptorMapInfoGenPass()";
332+
let constructor = "::fir::createOMPMapInfoFinalizationPass()";
333333
let dependentDialects = ["mlir::omp::OpenMPDialect"];
334334
}
335335

flang/include/flang/Tools/CLOptions.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ inline void createHLFIRToFIRPassPipeline(
335335
/// rather than the host device.
336336
inline void createOpenMPFIRPassPipeline(
337337
mlir::PassManager &pm, bool isTargetDevice) {
338-
pm.addPass(fir::createOMPDescriptorMapInfoGenPass());
338+
pm.addPass(fir::createOMPMapInfoFinalizationPass());
339339
pm.addPass(fir::createOMPMarkDeclareTargetPass());
340340
if (isTargetDevice)
341341
pm.addPass(fir::createOMPFunctionFilteringPass());

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -814,38 +814,24 @@ bool ClauseProcessor::processLink(
814814
});
815815
}
816816

817-
mlir::omp::MapInfoOp
818-
createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
819-
mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
820-
llvm::ArrayRef<mlir::Value> bounds,
821-
llvm::ArrayRef<mlir::Value> members, uint64_t mapType,
822-
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
823-
bool isVal) {
824-
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
825-
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
826-
retTy = baseAddr.getType();
827-
}
828-
829-
mlir::TypeAttr varType = mlir::TypeAttr::get(
830-
llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());
831-
832-
mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
833-
loc, retTy, baseAddr, varType, varPtrPtr, members, bounds,
834-
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
835-
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
836-
builder.getStringAttr(name));
837-
838-
return op;
839-
}
840-
841817
bool ClauseProcessor::processMap(
842818
mlir::Location currentLocation, Fortran::lower::StatementContext &stmtCtx,
843819
mlir::omp::MapClauseOps &result,
844820
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms,
845821
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
846822
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
847823
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
848-
return findRepeatableClause<omp::clause::Map>(
824+
// We always require tracking of symbols, even if the caller does not,
825+
// so we create an optionally used local set of symbols when the mapSyms
826+
// argument is not present.
827+
llvm::SmallVector<const Fortran::semantics::Symbol *> localMapSyms;
828+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *ptrMapSyms =
829+
mapSyms ? mapSyms : &localMapSyms;
830+
std::map<const Fortran::semantics::Symbol *,
831+
llvm::SmallVector<OmpMapMemberIndicesData>>
832+
parentMemberIndices;
833+
834+
bool clauseFound = findRepeatableClause<omp::clause::Map>(
849835
[&](const omp::clause::Map &clause,
850836
const Fortran::parser::CharBlock &source) {
851837
using Map = omp::clause::Map;
@@ -910,24 +896,33 @@ bool ClauseProcessor::processMap(
910896
// Explicit map captures are captured ByRef by default,
911897
// optimisation passes may alter this to ByCopy or other capture
912898
// types to optimise
913-
mlir::Value mapOp = createMapInfoOp(
914-
firOpBuilder, clauseLocation, symAddr, mlir::Value{},
915-
asFortran.str(), bounds, {},
899+
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
900+
firOpBuilder, clauseLocation, symAddr,
901+
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
902+
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
916903
static_cast<
917904
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
918905
mapTypeBits),
919906
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
920907

921-
result.mapVars.push_back(mapOp);
922-
923-
if (mapSyms)
924-
mapSyms->push_back(object.id());
925-
if (mapSymLocs)
926-
mapSymLocs->push_back(symAddr.getLoc());
927-
if (mapSymTypes)
928-
mapSymTypes->push_back(symAddr.getType());
908+
if (object.id()->owner().IsDerivedType()) {
909+
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
910+
semaCtx);
911+
} else {
912+
result.mapVars.push_back(mapOp);
913+
ptrMapSyms->push_back(object.id());
914+
if (mapSymTypes)
915+
mapSymTypes->push_back(symAddr.getType());
916+
if (mapSymLocs)
917+
mapSymLocs->push_back(symAddr.getLoc());
918+
}
929919
}
930920
});
921+
922+
insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
923+
*ptrMapSyms, mapSymTypes, mapSymLocs);
924+
925+
return clauseFound;
931926
}
932927

933928
bool ClauseProcessor::processReduction(

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,12 @@ template <typename T>
185185
bool ClauseProcessor::processMotionClauses(
186186
Fortran::lower::StatementContext &stmtCtx,
187187
mlir::omp::MapClauseOps &result) {
188-
return findRepeatableClause<T>(
188+
std::map<const Fortran::semantics::Symbol *,
189+
llvm::SmallVector<OmpMapMemberIndicesData>>
190+
parentMemberIndices;
191+
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
192+
193+
bool clauseFound = findRepeatableClause<T>(
189194
[&](const T &clause, const Fortran::parser::CharBlock &source) {
190195
mlir::Location clauseLocation = converter.genLocation(source);
191196
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -203,6 +208,7 @@ bool ClauseProcessor::processMotionClauses(
203208
for (const omp::Object &object : objects) {
204209
llvm::SmallVector<mlir::Value> bounds;
205210
std::stringstream asFortran;
211+
206212
Fortran::lower::AddrAndBoundsInfo info =
207213
Fortran::lower::gatherDataOperandAddrAndBounds<
208214
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
@@ -218,17 +224,29 @@ bool ClauseProcessor::processMotionClauses(
218224
// Explicit map captures are captured ByRef by default,
219225
// optimisation passes may alter this to ByCopy or other capture
220226
// types to optimise
221-
mlir::Value mapOp = createMapInfoOp(
222-
firOpBuilder, clauseLocation, symAddr, mlir::Value{},
223-
asFortran.str(), bounds, {},
227+
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
228+
firOpBuilder, clauseLocation, symAddr,
229+
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
230+
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
224231
static_cast<
225232
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
226233
mapTypeBits),
227234
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
228235

229-
result.mapVars.push_back(mapOp);
236+
if (object.id()->owner().IsDerivedType()) {
237+
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
238+
semaCtx);
239+
} else {
240+
result.mapVars.push_back(mapOp);
241+
mapSymbols.push_back(object.id());
242+
}
230243
}
231244
});
245+
246+
insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
247+
mapSymbols,
248+
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr);
249+
return clauseFound;
232250
}
233251

234252
template <typename... Ts>

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -939,8 +939,10 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
939939
std::stringstream name;
940940
firOpBuilder.setInsertionPoint(targetOp);
941941
mlir::Value mapOp = createMapInfoOp(
942-
firOpBuilder, copyVal.getLoc(), copyVal, mlir::Value{}, name.str(),
943-
bounds, llvm::SmallVector<mlir::Value>{},
942+
firOpBuilder, copyVal.getLoc(), copyVal,
943+
/*varPtrPtr=*/mlir::Value{}, name.str(), bounds,
944+
/*members=*/llvm::SmallVector<mlir::Value>{},
945+
/*membersIndex=*/mlir::DenseIntElementsAttr{},
944946
static_cast<
945947
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
946948
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
@@ -1637,8 +1639,9 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
16371639
}
16381640

16391641
mlir::Value mapOp = createMapInfoOp(
1640-
firOpBuilder, baseOp.getLoc(), baseOp, mlir::Value{}, name.str(),
1641-
bounds, {},
1642+
firOpBuilder, baseOp.getLoc(), baseOp, /*varPtrPtr=*/mlir::Value{},
1643+
name.str(), bounds, /*members=*/{},
1644+
/*membersIndex=*/mlir::DenseIntElementsAttr{},
16421645
static_cast<
16431646
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
16441647
mapFlag),

0 commit comments

Comments
 (0)