Skip to content

Commit 8d8470f

Browse files
committed
[MLIR][OpenMP] Add OMP Mapper field to MapInfoOp (llvm#120994)
This patch adds the mapper field to the omp.map.info op. Depends on llvm#117046.
1 parent 6820bf7 commit 8d8470f

File tree

9 files changed

+35
-7
lines changed

9 files changed

+35
-7
lines changed

flang/include/flang/Lower/OpenMP/Utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
116116
llvm::ArrayRef<mlir::Value> members,
117117
mlir::ArrayAttr membersIndex, uint64_t mapType,
118118
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
119-
bool partialMap = false);
119+
bool partialMap = false,
120+
mlir::FlatSymbolRefAttr mapperId = mlir::FlatSymbolRefAttr());
120121

121122
void insertChildMapInfoIntoParent(
122123
Fortran::lower::AbstractConverter &converter,

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
130130
llvm::ArrayRef<mlir::Value> members,
131131
mlir::ArrayAttr membersIndex, uint64_t mapType,
132132
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
133-
bool partialMap) {
133+
bool partialMap, mlir::FlatSymbolRefAttr mapperId) {
134134
if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
135135
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
136136
retTy = baseAddr.getType();
@@ -149,6 +149,7 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
149149
mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
150150
loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds,
151151
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
152+
mapperId,
152153
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
153154
builder.getStringAttr(name), builder.getBoolAttr(partialMap));
154155
return op;

flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ mlir::omp::MapInfoOp createMapInfoOp(
5151
mlir::Value varPtrPtr, std::string name, llvm::ArrayRef<mlir::Value> bounds,
5252
llvm::ArrayRef<mlir::Value> members, mlir::ArrayAttr membersIndex,
5353
uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType,
54-
mlir::Type retTy, bool partialMap = false) {
54+
mlir::Type retTy, bool partialMap = false,
55+
mlir::FlatSymbolRefAttr mapperId = mlir::FlatSymbolRefAttr()) {
5556
if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
5657
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
5758
retTy = baseAddr.getType();
@@ -70,6 +71,7 @@ mlir::omp::MapInfoOp createMapInfoOp(
7071
mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
7172
loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds,
7273
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
74+
mapperId,
7375
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
7476
builder.getStringAttr(name), builder.getBoolAttr(partialMap));
7577

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ class MapInfoFinalizationPass
184184
/*members=*/mlir::SmallVector<mlir::Value>{},
185185
/*membersIndex=*/mlir::ArrayAttr{}, bounds,
186186
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
187+
/*mapperId*/ mlir::FlatSymbolRefAttr(),
187188
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
188189
mlir::omp::VariableCaptureKind::ByRef),
189190
/*name=*/builder.getStringAttr(""),
@@ -331,7 +332,8 @@ class MapInfoFinalizationPass
331332
builder.getIntegerAttr(
332333
builder.getIntegerType(64, false),
333334
getDescriptorMapType(op.getMapType().value_or(0), target)),
334-
op.getMapCaptureTypeAttr(), op.getNameAttr(),
335+
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getMapCaptureTypeAttr(),
336+
op.getNameAttr(),
335337
/*partial_map=*/builder.getBoolAttr(false));
336338
op.replaceAllUsesWith(newDescParentMapOp.getResult());
337339
op->erase();
@@ -629,6 +631,7 @@ class MapInfoFinalizationPass
629631
// /*members=*/mlir::ValueRange{},
630632
// /*members_index=*/mlir::ArrayAttr{},
631633
// /*bounds=*/bounds, op.getMapTypeAttr(),
634+
// /*mapperId*/ mlir::FlatSymbolRefAttr(),
632635
// builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
633636
// mlir::omp::VariableCaptureKind::ByRef),
634637
// builder.getStringAttr(op.getNameAttr().strref() + "." +

flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class MapsForPrivatizedSymbolsPass
9191
/*bounds=*/ValueRange{},
9292
builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
9393
mapTypeTo),
94+
/*mapperId*/ mlir::FlatSymbolRefAttr(),
9495
builder.getAttr<omp::VariableCaptureKindAttr>(
9596
omp::VariableCaptureKind::ByRef),
9697
StringAttr(), builder.getBoolAttr(false));

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
10231023
OptionalAttr<IndexListArrayAttr>:$members_index,
10241024
Variadic<OpenMP_MapBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
10251025
OptionalAttr<UI64Attr>:$map_type,
1026+
OptionalAttr<FlatSymbolRefAttr>:$mapper_id,
10261027
OptionalAttr<VariableCaptureKindAttr>:$map_capture_type,
10271028
OptionalAttr<StrAttr>:$name,
10281029
DefaultValuedAttr<BoolAttr, "false">:$partial_map);
@@ -1076,6 +1077,8 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
10761077
- 'map_type': OpenMP map type for this map capture, for example: from, to and
10771078
always. It's a bitfield composed of the OpenMP runtime flags stored in
10781079
OpenMPOffloadMappingFlags.
1080+
- 'mapper_id': OpenMP mapper map type modifier for this map capture. It's used to
1081+
specify a user defined mapper to be used for mapping.
10791082
- 'map_capture_type': Capture type for the variable e.g. this, byref, byvalue, byvla
10801083
this can affect how the variable is lowered.
10811084
- `name`: Holds the name of variable as specified in user clause (including bounds).
@@ -1087,6 +1090,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
10871090
`var_ptr` `(` $var_ptr `:` type($var_ptr) `,` $var_type `)`
10881091
oilist(
10891092
`var_ptr_ptr` `(` $var_ptr_ptr `:` type($var_ptr_ptr) `)`
1093+
| `mapper` `(` $mapper_id `)`
10901094
| `map_clauses` `(` custom<MapClause>($map_type) `)`
10911095
| `capture` `(` custom<CaptureType>($map_capture_type) `)`
10921096
| `members` `(` $members `:` custom<MembersIndex>($members_index) `:` type($members) `)`

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1639,7 +1639,13 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
16391639

16401640
to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
16411641
}
1642-
} else {
1642+
1643+
if (mapInfoOp.getMapperId() &&
1644+
!SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
1645+
mapInfoOp, mapInfoOp.getMapperIdAttr())) {
1646+
return emitError(op->getLoc(), "invalid mapper id");
1647+
}
1648+
} else if (!isa<DeclareMapperInfoOp>(op)) {
16431649
emitError(op->getLoc(), "map argument is not a map entry operation");
16441650
}
16451651
}

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2843,3 +2843,13 @@ func.func @missing_workshare(%idx : index) {
28432843
^bb0(%arg0: !llvm.ptr):
28442844
omp.terminator
28452845
}
2846+
2847+
// -----
2848+
llvm.func @invalid_mapper(%0 : !llvm.ptr) {
2849+
%1 = omp.map.info var_ptr(%0 : !llvm.ptr, !llvm.struct<"my_type", (i32)>) mapper(@my_mapper) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
2850+
// expected-error @below {{invalid mapper id}}
2851+
omp.target_data map_entries(%1 : !llvm.ptr) {
2852+
omp.terminator
2853+
}
2854+
llvm.return
2855+
}

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2546,13 +2546,13 @@ func.func @omp_targets_with_map_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()
25462546
// CHECK: %[[C_12:.*]] = llvm.mlir.constant(2 : index) : i64
25472547
// CHECK: %[[C_13:.*]] = llvm.mlir.constant(2 : index) : i64
25482548
// CHECK: %[[BOUNDS1:.*]] = omp.map.bounds lower_bound(%[[C_11]] : i64) upper_bound(%[[C_10]] : i64) stride(%[[C_12]] : i64) start_idx(%[[C_13]] : i64)
2549-
// CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%[[BOUNDS1]]) -> !llvm.ptr {name = ""}
2549+
// CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.array<10 x i32>) mapper(@my_mapper) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%[[BOUNDS1]]) -> !llvm.ptr {name = ""}
25502550
%6 = llvm.mlir.constant(9 : index) : i64
25512551
%7 = llvm.mlir.constant(1 : index) : i64
25522552
%8 = llvm.mlir.constant(2 : index) : i64
25532553
%9 = llvm.mlir.constant(2 : index) : i64
25542554
%10 = omp.map.bounds lower_bound(%7 : i64) upper_bound(%6 : i64) stride(%8 : i64) start_idx(%9 : i64)
2555-
%mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%10) -> !llvm.ptr {name = ""}
2555+
%mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.array<10 x i32>) mapper(@my_mapper) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%10) -> !llvm.ptr {name = ""}
25562556

25572557
// CHECK: omp.target map_entries(%[[MAP0]] -> {{.*}}, %[[MAP1]] -> {{.*}} : !llvm.ptr, !llvm.ptr)
25582558
omp.target map_entries(%mapv1 -> %arg2, %mapv2 -> %arg3 : !llvm.ptr, !llvm.ptr) {

0 commit comments

Comments
 (0)