Skip to content

[MLIR][OpenMP] Add OMP Mapper field to MapInfoOp #120994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flang/lib/Lower/OpenMP/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
llvm::ArrayRef<mlir::Value> members,
mlir::ArrayAttr membersIndex, uint64_t mapType,
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
bool partialMap) {
bool partialMap, mlir::FlatSymbolRefAttr mapperId) {
if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
retTy = baseAddr.getType();
Expand All @@ -144,6 +144,7 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds,
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
mapperId,
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
builder.getStringAttr(name), builder.getBoolAttr(partialMap));
return op;
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Lower/OpenMP/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
llvm::ArrayRef<mlir::Value> members,
mlir::ArrayAttr membersIndex, uint64_t mapType,
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
bool partialMap = false);
bool partialMap = false,
mlir::FlatSymbolRefAttr mapperId = mlir::FlatSymbolRefAttr());

void insertChildMapInfoIntoParent(
Fortran::lower::AbstractConverter &converter,
Expand Down
5 changes: 4 additions & 1 deletion flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ class MapInfoFinalizationPass
/*members=*/mlir::SmallVector<mlir::Value>{},
/*membersIndex=*/mlir::ArrayAttr{}, bounds,
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
/*mapperId*/ mlir::FlatSymbolRefAttr(),
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
mlir::omp::VariableCaptureKind::ByRef),
/*name=*/builder.getStringAttr(""),
Expand Down Expand Up @@ -329,7 +330,8 @@ class MapInfoFinalizationPass
builder.getIntegerAttr(
builder.getIntegerType(64, false),
getDescriptorMapType(op.getMapType().value_or(0), target)),
op.getMapCaptureTypeAttr(), op.getNameAttr(),
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getMapCaptureTypeAttr(),
op.getNameAttr(),
/*partial_map=*/builder.getBoolAttr(false));
op.replaceAllUsesWith(newDescParentMapOp.getResult());
op->erase();
Expand Down Expand Up @@ -623,6 +625,7 @@ class MapInfoFinalizationPass
/*members=*/mlir::ValueRange{},
/*members_index=*/mlir::ArrayAttr{},
/*bounds=*/bounds, op.getMapTypeAttr(),
/*mapperId*/ mlir::FlatSymbolRefAttr(),
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
mlir::omp::VariableCaptureKind::ByRef),
builder.getStringAttr(op.getNameAttr().strref() + "." +
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class MapsForPrivatizedSymbolsPass
/*bounds=*/ValueRange{},
builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
mapTypeTo),
/*mapperId*/ mlir::FlatSymbolRefAttr(),
builder.getAttr<omp::VariableCaptureKindAttr>(
omp::VariableCaptureKind::ByRef),
StringAttr(), builder.getBoolAttr(false));
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
OptionalAttr<IndexListArrayAttr>:$members_index,
Variadic<OpenMP_MapBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
OptionalAttr<UI64Attr>:$map_type,
OptionalAttr<FlatSymbolRefAttr>:$mapper_id,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a description for this argument.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it can only be the name of a declare mapper then it might be good to add that to the verifier.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay, I was on vacation.

I've updated the description and also added a verifier check.

OptionalAttr<VariableCaptureKindAttr>:$map_capture_type,
OptionalAttr<StrAttr>:$name,
DefaultValuedAttr<BoolAttr, "false">:$partial_map);
Expand Down Expand Up @@ -1076,6 +1077,8 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
- 'map_type': OpenMP map type for this map capture, for example: from, to and
always. It's a bitfield composed of the OpenMP runtime flags stored in
OpenMPOffloadMappingFlags.
- 'mapper_id': OpenMP mapper map type modifier for this map capture. It's used to
specify a user defined mapper to be used for mapping.
- 'map_capture_type': Capture type for the variable e.g. this, byref, byvalue, byvla
this can affect how the variable is lowered.
- `name`: Holds the name of variable as specified in user clause (including bounds).
Expand All @@ -1087,6 +1090,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
`var_ptr` `(` $var_ptr `:` type($var_ptr) `,` $var_type `)`
oilist(
`var_ptr_ptr` `(` $var_ptr_ptr `:` type($var_ptr_ptr) `)`
| `mapper` `(` $mapper_id `)`
| `map_clauses` `(` custom<MapClause>($map_type) `)`
| `capture` `(` custom<CaptureType>($map_capture_type) `)`
| `members` `(` $members `:` custom<MembersIndex>($members_index) `:` type($members) `)`
Expand Down
8 changes: 7 additions & 1 deletion mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1632,7 +1632,13 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {

to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
}
} else {

if (mapInfoOp.getMapperId() &&
!SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
mapInfoOp, mapInfoOp.getMapperIdAttr())) {
return emitError(op->getLoc(), "invalid mapper id");
}
} else if (!isa<DeclareMapperInfoOp>(op)) {
emitError(op->getLoc(), "map argument is not a map entry operation");
}
}
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2849,3 +2849,13 @@ func.func @missing_workshare(%idx : index) {
^bb0(%arg0: !llvm.ptr):
omp.terminator
}

// -----
llvm.func @invalid_mapper(%0 : !llvm.ptr) {
%1 = omp.map.info var_ptr(%0 : !llvm.ptr, !llvm.struct<"my_type", (i32)>) mapper(@my_mapper) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
// expected-error @below {{invalid mapper id}}
omp.target_data map_entries(%1 : !llvm.ptr) {
omp.terminator
}
llvm.return
}
4 changes: 2 additions & 2 deletions mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2546,13 +2546,13 @@ func.func @omp_targets_with_map_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()
// CHECK: %[[C_12:.*]] = llvm.mlir.constant(2 : index) : i64
// CHECK: %[[C_13:.*]] = llvm.mlir.constant(2 : index) : i64
// CHECK: %[[BOUNDS1:.*]] = omp.map.bounds lower_bound(%[[C_11]] : i64) upper_bound(%[[C_10]] : i64) stride(%[[C_12]] : i64) start_idx(%[[C_13]] : i64)
// CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%[[BOUNDS1]]) -> !llvm.ptr {name = ""}
// CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.array<10 x i32>) mapper(@my_mapper) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%[[BOUNDS1]]) -> !llvm.ptr {name = ""}
%6 = llvm.mlir.constant(9 : index) : i64
%7 = llvm.mlir.constant(1 : index) : i64
%8 = llvm.mlir.constant(2 : index) : i64
%9 = llvm.mlir.constant(2 : index) : i64
%10 = omp.map.bounds lower_bound(%7 : i64) upper_bound(%6 : i64) stride(%8 : i64) start_idx(%9 : i64)
%mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%10) -> !llvm.ptr {name = ""}
%mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.array<10 x i32>) mapper(@my_mapper) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%10) -> !llvm.ptr {name = ""}

// CHECK: omp.target map_entries(%[[MAP0]] -> {{.*}}, %[[MAP1]] -> {{.*}} : !llvm.ptr, !llvm.ptr)
omp.target map_entries(%mapv1 -> %arg2, %mapv2 -> %arg3 : !llvm.ptr, !llvm.ptr) {
Expand Down