Skip to content

[MLIR][OpenMP] Add Lowering support for OpenMP custom mappers in map clause #121001

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -969,8 +969,11 @@ void ClauseProcessor::processMapObjects(
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const {
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
llvm::StringRef mapperIdNameRef) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::FlatSymbolRefAttr mapperId;
std::string mapperIdName = mapperIdNameRef.str();

for (const omp::Object &object : objects) {
llvm::SmallVector<mlir::Value> bounds;
Expand Down Expand Up @@ -1003,6 +1006,20 @@ void ClauseProcessor::processMapObjects(
}
}

if (!mapperIdName.empty()) {
if (mapperIdName == "default") {
auto &typeSpec = object.sym()->owner().IsDerivedType()
? *object.sym()->owner().derivedTypeSpec()
: object.sym()->GetType()->derivedTypeSpec();
mapperIdName = typeSpec.name().ToString() + ".default";
mapperIdName = converter.mangleName(mapperIdName, *typeSpec.GetScope());
}
assert(converter.getModuleOp().lookupSymbol(mapperIdName) &&
"mapper not found");
mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
mapperIdName);
mapperIdName.clear();
}
// Explicit map captures are captured ByRef by default,
// optimisation passes may alter this to ByCopy or other capture
// types to optimise
Expand All @@ -1016,7 +1033,8 @@ void ClauseProcessor::processMapObjects(
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());
mlir::omp::VariableCaptureKind::ByRef, baseOp.getType(),
/*partialMap=*/false, mapperId);

if (parentObj.has_value()) {
parentMemberIndices[parentObj.value()].addChildIndexAndMapToParent(
Expand Down Expand Up @@ -1047,6 +1065,7 @@ bool ClauseProcessor::processMap(
const auto &[mapType, typeMods, mappers, iterator, objects] = clause.t;
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
std::string mapperIdName;
// If the map type is specified, then process it else Tofrom is the
// default.
Map::MapType type = mapType.value_or(Map::MapType::Tofrom);
Expand Down Expand Up @@ -1090,13 +1109,17 @@ bool ClauseProcessor::processMap(
"Support for iterator modifiers is not implemented yet");
}
if (mappers) {
TODO(currentLocation,
"Support for mapper modifiers is not implemented yet");
assert(mappers->size() == 1 && "more than one mapper");
mapperIdName = mappers->front().v.id().symbol->name().ToString();
if (mapperIdName != "default")
mapperIdName = converter.mangleName(
mapperIdName, mappers->front().v.id().symbol->owner());
}

processMapObjects(stmtCtx, clauseLocation,
std::get<omp::ObjectList>(clause.t), mapTypeBits,
parentMemberIndices, result.mapVars, *ptrMapSyms);
parentMemberIndices, result.mapVars, *ptrMapSyms,
mapperIdName);
};

bool clauseFound = findRepeatableClause<omp::clause::Map>(process);
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ class ClauseProcessor {
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const;
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
llvm::StringRef mapperIdNameRef = "") const;

lower::AbstractConverter &converter;
semantics::SemanticsContext &semaCtx;
Expand Down
16 changes: 0 additions & 16 deletions flang/test/Lower/OpenMP/Todo/map-mapper.f90

This file was deleted.

60 changes: 60 additions & 0 deletions flang/test/Lower/OpenMP/declare-mapper.f90
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
! RUN: split-file %s %t
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-1.f90 -o - | FileCheck %t/omp-declare-mapper-1.f90
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-2.f90 -o - | FileCheck %t/omp-declare-mapper-2.f90
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-3.f90 -o - | FileCheck %t/omp-declare-mapper-3.f90

!--- omp-declare-mapper-1.f90
subroutine declare_mapper_1
Expand Down Expand Up @@ -83,3 +84,62 @@ subroutine declare_mapper_2
!CHECK: }
!$omp declare mapper (my_mapper : my_type2 :: v) map (v%arr) map (alloc : v%temp)
end subroutine declare_mapper_2

!--- omp-declare-mapper-3.f90
subroutine declare_mapper_3
type my_type
integer :: num_vals
integer, allocatable :: values(:)
end type

type my_type2
type(my_type) :: my_type_var
real, dimension(250) :: arr
end type

!CHECK: omp.declare_mapper @[[MY_TYPE_MAPPER2:_QQFdeclare_mapper_3my_mapper2]] : [[MY_TYPE2:!fir\.type<_QFdeclare_mapper_3Tmy_type2\{my_type_var:!fir\.type<_QFdeclare_mapper_3Tmy_type\{num_vals:i32,values:!fir\.box<!fir\.heap<!fir\.array<\?xi32>>>}>,arr:!fir\.array<250xf32>}>]] {
!CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<[[MY_TYPE2]]>):
!CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFdeclare_mapper_3Ev"} : (!fir.ref<[[MY_TYPE2]]>) -> (!fir.ref<[[MY_TYPE2]]>, !fir.ref<[[MY_TYPE2]]>)
!CHECK: %[[VAL_2:.*]] = hlfir.designate %[[VAL_1]]#0{"my_type_var"} : (!fir.ref<[[MY_TYPE2]]>) -> !fir.ref<[[MY_TYPE:!fir\.type<_QFdeclare_mapper_3Tmy_type\{num_vals:i32,values:!fir\.box<!fir\.heap<!fir\.array<\?xi32>>>}>]]>
!CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) mapper(@[[MY_TYPE_MAPPER:_QQFdeclare_mapper_3my_mapper]]) map_clauses(tofrom) capture(ByRef) -> !fir.ref<[[MY_TYPE]]> {name = "v%[[VAL_4:.*]]"}
!CHECK: %[[VAL_5:.*]] = arith.constant 250 : index
!CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1>
!CHECK: %[[VAL_7:.*]] = hlfir.designate %[[VAL_1]]#0{"arr"} shape %[[VAL_6]] : (!fir.ref<[[MY_TYPE2]]>, !fir.shape<1>) -> !fir.ref<!fir.array<250xf32>>
!CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
!CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
!CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_5]], %[[VAL_8]] : index
!CHECK: %[[VAL_11:.*]] = omp.map.bounds lower_bound(%[[VAL_9]] : index) upper_bound(%[[VAL_10]] : index) extent(%[[VAL_5]] : index) stride(%[[VAL_8]] : index) start_idx(%[[VAL_8]] : index)
!CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref<!fir.array<250xf32>>, !fir.array<250xf32>) map_clauses(tofrom) capture(ByRef) bounds(%[[VAL_11]]) -> !fir.ref<!fir.array<250xf32>> {name = "v%[[VAL_13:.*]]"}
!CHECK: %[[VAL_14:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE2]]>, [[MY_TYPE2]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_12]] : [0], [1] : !fir.ref<[[MY_TYPE]]>, !fir.ref<!fir.array<250xf32>>) -> !fir.ref<[[MY_TYPE2]]> {name = "v", partial_map = true}
!CHECK: omp.declare_mapper.info map_entries(%[[VAL_14]], %[[VAL_3]], %[[VAL_12]] : !fir.ref<[[MY_TYPE2]]>, !fir.ref<[[MY_TYPE]]>, !fir.ref<!fir.array<250xf32>>)
!CHECK: }

!CHECK: omp.declare_mapper @[[MY_TYPE_MAPPER]] : [[MY_TYPE]] {
!CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<[[MY_TYPE]]>):
!CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFdeclare_mapper_3Evar"} : (!fir.ref<[[MY_TYPE]]>) -> (!fir.ref<[[MY_TYPE]]>, !fir.ref<[[MY_TYPE]]>)
!CHECK: %[[VAL_2:.*]] = hlfir.designate %[[VAL_1]]#0{"values"} {fortran_attrs = #fir.var_attrs<allocatable>} : (!fir.ref<[[MY_TYPE]]>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
!CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
!CHECK: %[[VAL_4:.*]] = fir.box_addr %[[VAL_3]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
!CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
!CHECK: %[[VAL_6:.*]]:3 = fir.box_dims %[[VAL_3]], %[[VAL_5]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
!CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
!CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
!CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
!CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_9]], %[[VAL_6]]#0 : index
!CHECK: %[[VAL_11:.*]] = hlfir.designate %[[VAL_1]]#0{"num_vals"} : (!fir.ref<[[MY_TYPE]]>) -> !fir.ref<i32>
!CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_11]] : !fir.ref<i32>
!CHECK: %[[VAL_13:.*]] = fir.convert %[[VAL_12]] : (i32) -> i64
!CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (i64) -> index
!CHECK: %[[VAL_15:.*]] = arith.subi %[[VAL_14]], %[[VAL_6]]#0 : index
!CHECK: %[[VAL_16:.*]] = omp.map.bounds lower_bound(%[[VAL_10]] : index) upper_bound(%[[VAL_15]] : index) extent(%[[VAL_6]]#1 : index) stride(%[[VAL_8]] : index) start_idx(%[[VAL_6]]#0 : index)
!CHECK: %[[VAL_17:.*]] = arith.constant 1 : index
!CHECK: %[[VAL_18:.*]] = fir.coordinate_of %[[VAL_1]]#0, %[[VAL_17]] : (!fir.ref<[[MY_TYPE]]>, index) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
!CHECK: %[[VAL_19:.*]] = fir.box_offset %[[VAL_18]] base_addr : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
!CHECK: %[[VAL_20:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, i32) var_ptr_ptr(%[[VAL_19]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[VAL_16]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
!CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.box<!fir.heap<!fir.array<?xi32>>>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {name = "var%[[VAL_22:.*]](1:var%[[VAL_23:.*]])"}
!CHECK: %[[VAL_24:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_21]], %[[VAL_20]] : [1], [1, 0] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<[[MY_TYPE]]> {name = "var"}
!CHECK: omp.declare_mapper.info map_entries(%[[VAL_24]], %[[VAL_21]], %[[VAL_20]] : !fir.ref<[[MY_TYPE]]>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>)
!CHECK: }
!$omp declare mapper (my_mapper : my_type :: var) map (var, var%values (1:var%num_vals))
!$omp declare mapper (my_mapper2 : my_type2 :: v) map (mapper(my_mapper) : v%my_type_var) map (tofrom : v%arr)
end subroutine declare_mapper_3
30 changes: 30 additions & 0 deletions flang/test/Lower/OpenMP/map-mapper.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %s -o - | FileCheck %s
program p
integer, parameter :: n = 256
type t1
integer :: x(256)
end type t1

!$omp declare mapper(xx : t1 :: nn) map(to: nn, nn%x)
!$omp declare mapper(t1 :: nn) map(from: nn)

!CHECK-LABEL: omp.declare_mapper @_QQFt1.default : !fir.type<_QFTt1{x:!fir.array<256xi32>}>
!CHECK-LABEL: omp.declare_mapper @_QQFxx : !fir.type<_QFTt1{x:!fir.array<256xi32>}>

type(t1) :: a, b
!CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%{{.*}} : {{.*}}, {{.*}}) mapper(@_QQFxx) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "a"}
!CHECK: omp.target map_entries(%[[MAP_A]] -> %{{.*}}, %{{.*}} -> %{{.*}} : {{.*}}, {{.*}}) {
!$omp target map(mapper(xx) : a)
do i = 1, n
a%x(i) = i
end do
!$omp end target

!CHECK: %[[MAP_B:.*]] = omp.map.info var_ptr(%{{.*}} : {{.*}}, {{.*}}) mapper(@_QQFt1.default) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "b"}
!CHECK: omp.target map_entries(%[[MAP_B]] -> %{{.*}}, %{{.*}} -> %{{.*}} : {{.*}}, {{.*}}) {
!$omp target map(mapper(default) : b)
do i = 1, n
b%x(i) = i
end do
!$omp end target
end program p