Skip to content

Commit 886b2ed

Browse files
authored
[MLIR][OpenMP] Add Lowering support for OpenMP custom mappers in map clause (#121001)
Add Lowering support for OpenMP mapper field in mapInfoOp. NOTE: This patch only supports explicit mapper lowering. I'll add a separate PR soon which handles implicit default mapper recognition. Depends on #120994.
1 parent ee17955 commit 886b2ed

File tree

5 files changed

+120
-22
lines changed

5 files changed

+120
-22
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -969,8 +969,11 @@ void ClauseProcessor::processMapObjects(
969969
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
970970
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
971971
llvm::SmallVectorImpl<mlir::Value> &mapVars,
972-
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const {
972+
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
973+
llvm::StringRef mapperIdNameRef) const {
973974
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
975+
mlir::FlatSymbolRefAttr mapperId;
976+
std::string mapperIdName = mapperIdNameRef.str();
974977

975978
for (const omp::Object &object : objects) {
976979
llvm::SmallVector<mlir::Value> bounds;
@@ -1003,6 +1006,20 @@ void ClauseProcessor::processMapObjects(
10031006
}
10041007
}
10051008

1009+
if (!mapperIdName.empty()) {
1010+
if (mapperIdName == "default") {
1011+
auto &typeSpec = object.sym()->owner().IsDerivedType()
1012+
? *object.sym()->owner().derivedTypeSpec()
1013+
: object.sym()->GetType()->derivedTypeSpec();
1014+
mapperIdName = typeSpec.name().ToString() + ".default";
1015+
mapperIdName = converter.mangleName(mapperIdName, *typeSpec.GetScope());
1016+
}
1017+
assert(converter.getModuleOp().lookupSymbol(mapperIdName) &&
1018+
"mapper not found");
1019+
mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
1020+
mapperIdName);
1021+
mapperIdName.clear();
1022+
}
10061023
// Explicit map captures are captured ByRef by default,
10071024
// optimisation passes may alter this to ByCopy or other capture
10081025
// types to optimise
@@ -1016,7 +1033,8 @@ void ClauseProcessor::processMapObjects(
10161033
static_cast<
10171034
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
10181035
mapTypeBits),
1019-
mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());
1036+
mlir::omp::VariableCaptureKind::ByRef, baseOp.getType(),
1037+
/*partialMap=*/false, mapperId);
10201038

10211039
if (parentObj.has_value()) {
10221040
parentMemberIndices[parentObj.value()].addChildIndexAndMapToParent(
@@ -1047,6 +1065,7 @@ bool ClauseProcessor::processMap(
10471065
const auto &[mapType, typeMods, mappers, iterator, objects] = clause.t;
10481066
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
10491067
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1068+
std::string mapperIdName;
10501069
// If the map type is specified, then process it else Tofrom is the
10511070
// default.
10521071
Map::MapType type = mapType.value_or(Map::MapType::Tofrom);
@@ -1090,13 +1109,17 @@ bool ClauseProcessor::processMap(
10901109
"Support for iterator modifiers is not implemented yet");
10911110
}
10921111
if (mappers) {
1093-
TODO(currentLocation,
1094-
"Support for mapper modifiers is not implemented yet");
1112+
assert(mappers->size() == 1 && "more than one mapper");
1113+
mapperIdName = mappers->front().v.id().symbol->name().ToString();
1114+
if (mapperIdName != "default")
1115+
mapperIdName = converter.mangleName(
1116+
mapperIdName, mappers->front().v.id().symbol->owner());
10951117
}
10961118

10971119
processMapObjects(stmtCtx, clauseLocation,
10981120
std::get<omp::ObjectList>(clause.t), mapTypeBits,
1099-
parentMemberIndices, result.mapVars, *ptrMapSyms);
1121+
parentMemberIndices, result.mapVars, *ptrMapSyms,
1122+
mapperIdName);
11001123
};
11011124

11021125
bool clauseFound = findRepeatableClause<omp::clause::Map>(process);

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,8 @@ class ClauseProcessor {
175175
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
176176
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
177177
llvm::SmallVectorImpl<mlir::Value> &mapVars,
178-
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const;
178+
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
179+
llvm::StringRef mapperIdNameRef = "") const;
179180

180181
lower::AbstractConverter &converter;
181182
semantics::SemanticsContext &semaCtx;

flang/test/Lower/OpenMP/Todo/map-mapper.f90

Lines changed: 0 additions & 16 deletions
This file was deleted.

flang/test/Lower/OpenMP/declare-mapper.f90

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
! RUN: split-file %s %t
44
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-1.f90 -o - | FileCheck %t/omp-declare-mapper-1.f90
55
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-2.f90 -o - | FileCheck %t/omp-declare-mapper-2.f90
6+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-3.f90 -o - | FileCheck %t/omp-declare-mapper-3.f90
67

78
!--- omp-declare-mapper-1.f90
89
subroutine declare_mapper_1
@@ -83,3 +84,62 @@ subroutine declare_mapper_2
8384
!CHECK: }
8485
!$omp declare mapper (my_mapper : my_type2 :: v) map (v%arr) map (alloc : v%temp)
8586
end subroutine declare_mapper_2
87+
88+
!--- omp-declare-mapper-3.f90
89+
subroutine declare_mapper_3
90+
type my_type
91+
integer :: num_vals
92+
integer, allocatable :: values(:)
93+
end type
94+
95+
type my_type2
96+
type(my_type) :: my_type_var
97+
real, dimension(250) :: arr
98+
end type
99+
100+
!CHECK: omp.declare_mapper @[[MY_TYPE_MAPPER2:_QQFdeclare_mapper_3my_mapper2]] : [[MY_TYPE2:!fir\.type<_QFdeclare_mapper_3Tmy_type2\{my_type_var:!fir\.type<_QFdeclare_mapper_3Tmy_type\{num_vals:i32,values:!fir\.box<!fir\.heap<!fir\.array<\?xi32>>>}>,arr:!fir\.array<250xf32>}>]] {
101+
!CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<[[MY_TYPE2]]>):
102+
!CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFdeclare_mapper_3Ev"} : (!fir.ref<[[MY_TYPE2]]>) -> (!fir.ref<[[MY_TYPE2]]>, !fir.ref<[[MY_TYPE2]]>)
103+
!CHECK: %[[VAL_2:.*]] = hlfir.designate %[[VAL_1]]#0{"my_type_var"} : (!fir.ref<[[MY_TYPE2]]>) -> !fir.ref<[[MY_TYPE:!fir\.type<_QFdeclare_mapper_3Tmy_type\{num_vals:i32,values:!fir\.box<!fir\.heap<!fir\.array<\?xi32>>>}>]]>
104+
!CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) mapper(@[[MY_TYPE_MAPPER:_QQFdeclare_mapper_3my_mapper]]) map_clauses(tofrom) capture(ByRef) -> !fir.ref<[[MY_TYPE]]> {name = "v%[[VAL_4:.*]]"}
105+
!CHECK: %[[VAL_5:.*]] = arith.constant 250 : index
106+
!CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1>
107+
!CHECK: %[[VAL_7:.*]] = hlfir.designate %[[VAL_1]]#0{"arr"} shape %[[VAL_6]] : (!fir.ref<[[MY_TYPE2]]>, !fir.shape<1>) -> !fir.ref<!fir.array<250xf32>>
108+
!CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
109+
!CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
110+
!CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_5]], %[[VAL_8]] : index
111+
!CHECK: %[[VAL_11:.*]] = omp.map.bounds lower_bound(%[[VAL_9]] : index) upper_bound(%[[VAL_10]] : index) extent(%[[VAL_5]] : index) stride(%[[VAL_8]] : index) start_idx(%[[VAL_8]] : index)
112+
!CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref<!fir.array<250xf32>>, !fir.array<250xf32>) map_clauses(tofrom) capture(ByRef) bounds(%[[VAL_11]]) -> !fir.ref<!fir.array<250xf32>> {name = "v%[[VAL_13:.*]]"}
113+
!CHECK: %[[VAL_14:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE2]]>, [[MY_TYPE2]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_12]] : [0], [1] : !fir.ref<[[MY_TYPE]]>, !fir.ref<!fir.array<250xf32>>) -> !fir.ref<[[MY_TYPE2]]> {name = "v", partial_map = true}
114+
!CHECK: omp.declare_mapper.info map_entries(%[[VAL_14]], %[[VAL_3]], %[[VAL_12]] : !fir.ref<[[MY_TYPE2]]>, !fir.ref<[[MY_TYPE]]>, !fir.ref<!fir.array<250xf32>>)
115+
!CHECK: }
116+
117+
!CHECK: omp.declare_mapper @[[MY_TYPE_MAPPER]] : [[MY_TYPE]] {
118+
!CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<[[MY_TYPE]]>):
119+
!CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFdeclare_mapper_3Evar"} : (!fir.ref<[[MY_TYPE]]>) -> (!fir.ref<[[MY_TYPE]]>, !fir.ref<[[MY_TYPE]]>)
120+
!CHECK: %[[VAL_2:.*]] = hlfir.designate %[[VAL_1]]#0{"values"} {fortran_attrs = #fir.var_attrs<allocatable>} : (!fir.ref<[[MY_TYPE]]>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
121+
!CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
122+
!CHECK: %[[VAL_4:.*]] = fir.box_addr %[[VAL_3]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
123+
!CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
124+
!CHECK: %[[VAL_6:.*]]:3 = fir.box_dims %[[VAL_3]], %[[VAL_5]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
125+
!CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
126+
!CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
127+
!CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
128+
!CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_9]], %[[VAL_6]]#0 : index
129+
!CHECK: %[[VAL_11:.*]] = hlfir.designate %[[VAL_1]]#0{"num_vals"} : (!fir.ref<[[MY_TYPE]]>) -> !fir.ref<i32>
130+
!CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_11]] : !fir.ref<i32>
131+
!CHECK: %[[VAL_13:.*]] = fir.convert %[[VAL_12]] : (i32) -> i64
132+
!CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (i64) -> index
133+
!CHECK: %[[VAL_15:.*]] = arith.subi %[[VAL_14]], %[[VAL_6]]#0 : index
134+
!CHECK: %[[VAL_16:.*]] = omp.map.bounds lower_bound(%[[VAL_10]] : index) upper_bound(%[[VAL_15]] : index) extent(%[[VAL_6]]#1 : index) stride(%[[VAL_8]] : index) start_idx(%[[VAL_6]]#0 : index)
135+
!CHECK: %[[VAL_17:.*]] = arith.constant 1 : index
136+
!CHECK: %[[VAL_18:.*]] = fir.coordinate_of %[[VAL_1]]#0, %[[VAL_17]] : (!fir.ref<[[MY_TYPE]]>, index) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
137+
!CHECK: %[[VAL_19:.*]] = fir.box_offset %[[VAL_18]] base_addr : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
138+
!CHECK: %[[VAL_20:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, i32) var_ptr_ptr(%[[VAL_19]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[VAL_16]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
139+
!CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.box<!fir.heap<!fir.array<?xi32>>>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {name = "var%[[VAL_22:.*]](1:var%[[VAL_23:.*]])"}
140+
!CHECK: %[[VAL_24:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_21]], %[[VAL_20]] : [1], [1, 0] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<[[MY_TYPE]]> {name = "var"}
141+
!CHECK: omp.declare_mapper.info map_entries(%[[VAL_24]], %[[VAL_21]], %[[VAL_20]] : !fir.ref<[[MY_TYPE]]>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>)
142+
!CHECK: }
143+
!$omp declare mapper (my_mapper : my_type :: var) map (var, var%values (1:var%num_vals))
144+
!$omp declare mapper (my_mapper2 : my_type2 :: v) map (mapper(my_mapper) : v%my_type_var) map (tofrom : v%arr)
145+
end subroutine declare_mapper_3
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %s -o - | FileCheck %s
2+
program p
3+
integer, parameter :: n = 256
4+
type t1
5+
integer :: x(256)
6+
end type t1
7+
8+
!$omp declare mapper(xx : t1 :: nn) map(to: nn, nn%x)
9+
!$omp declare mapper(t1 :: nn) map(from: nn)
10+
11+
!CHECK-LABEL: omp.declare_mapper @_QQFt1.default : !fir.type<_QFTt1{x:!fir.array<256xi32>}>
12+
!CHECK-LABEL: omp.declare_mapper @_QQFxx : !fir.type<_QFTt1{x:!fir.array<256xi32>}>
13+
14+
type(t1) :: a, b
15+
!CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%{{.*}} : {{.*}}, {{.*}}) mapper(@_QQFxx) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "a"}
16+
!CHECK: omp.target map_entries(%[[MAP_A]] -> %{{.*}}, %{{.*}} -> %{{.*}} : {{.*}}, {{.*}}) {
17+
!$omp target map(mapper(xx) : a)
18+
do i = 1, n
19+
a%x(i) = i
20+
end do
21+
!$omp end target
22+
23+
!CHECK: %[[MAP_B:.*]] = omp.map.info var_ptr(%{{.*}} : {{.*}}, {{.*}}) mapper(@_QQFt1.default) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "b"}
24+
!CHECK: omp.target map_entries(%[[MAP_B]] -> %{{.*}}, %{{.*}} -> %{{.*}} : {{.*}}, {{.*}}) {
25+
!$omp target map(mapper(default) : b)
26+
do i = 1, n
27+
b%x(i) = i
28+
end do
29+
!$omp end target
30+
end program p

0 commit comments

Comments
 (0)