Commit aec735c

[Flang][OpenMP][MLIR] Fix common block mapping for regular and declare target link (#91829)
This PR attempts to fix common block mapping, both for regular mapping of these types and for common blocks that have been marked as "declare target link". It should allow correct mapping of both the members of a common block and the full common block via its block symbol.

The main changes are adjustments to the Fortran OpenMP lowering to HLFIR/FIR, to the lowering of the LLVM+OpenMP dialect to LLVM-IR, and to the way we handle target kernel map argument rebinding inside the OMPIRBuilder.

The Fortran OpenMP lowering has two changes. The first prevents the implicit capture of common block members when the common block symbol itself has been mapped. The second creates intermediate member accesses inside the target region to be used in place of those external to the target region, which prevents external uses from breaking the IsolatedFromAbove pact.

In the latter case, there was an adjustment to the size calculation for types to better handle cases where we pass an array as the type of a map (as opposed to the bounds and the type of the element), which occurs in the case of common blocks. There is also an adjustment to how handleDeclareTargetMapVar renames declare target symbols in the module to the reference pointer: it now only applies to uses within the kernel that is currently being generated. We also replace constants with instructions where necessary, as we cannot replace constants with our reference pointer (non-constants and constants do not mix nicely).

In the OpenMPIRBuilder, global symbol rebinding to kernel arguments is now deferred until all other arguments have been rebound. This makes sure we do not replace uses that may refer to the global (e.g. a GEP) but are themselves a separate argument that still needs to be bound.

Currently "declare target to" still needs some work, but this may be the case for all types in conjunction with "declare target to" at the moment.
1 parent fef144c commit aec735c

13 files changed, +703 -25 lines changed
flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 44 additions & 0 deletions
@@ -776,6 +776,33 @@ static void genBodyOfTargetDataOp(
   }
 }

+// This generates intermediate common block member accesses within a region
+// and then rebinds the members symbol to the intermediate accessors we have
+// generated so that subsequent code generation will utilise these instead.
+//
+// When the scope changes, the bindings to the intermediate accessors should
+// be dropped in place of the original symbol bindings.
+//
+// This is for utilisation with TargetOp.
+static void genIntermediateCommonBlockAccessors(
+    Fortran::lower::AbstractConverter &converter,
+    const mlir::Location &currentLocation, mlir::Region &region,
+    llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSyms) {
+  for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) {
+    if (auto *details =
+            argSymbol->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
+      for (auto obj : details->objects()) {
+        auto targetCBMemberBind = Fortran::lower::genCommonBlockMember(
+            converter, currentLocation, *obj, region.getArgument(argIndex));
+        fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*obj);
+        fir::ExtendedValue targetCBExv =
+            getExtendedValue(sexv, targetCBMemberBind);
+        converter.bindSymbol(*obj, targetCBExv);
+      }
+    }
+  }
+}
+
 // This functions creates a block for the body of the targetOp's region. It adds
 // all the symbols present in mapSymbols as block arguments to this block.
 static void
@@ -955,6 +982,16 @@ genBodyOfTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   // Create the insertion point after the marker.
   firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());

+  // If we map a common block using it's symbol e.g. map(tofrom: /common_block/)
+  // and accessing it's members within the target region, there is a large
+  // chance we will end up with uses external to the region accessing the common
+  // resolve these, we do so by generating new common block member accesses
+  // within the region, binding them to the member symbol for the scope of the
+  // region so that subsequent code generation within the region will utilise
+  // our new member accesses we have created.
+  genIntermediateCommonBlockAccessors(converter, currentLocation, region,
+                                      mapSyms);
+
   if (ConstructQueue::iterator next = std::next(item); next != queue.end()) {
     genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
                    next);
@@ -1670,6 +1707,13 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
     if (dsp.getAllSymbolsToPrivatize().contains(&sym))
       return;

+    // if the symbol is part of an already mapped common block, do not make a
+    // map for it.
+    if (const Fortran::semantics::Symbol *common =
+            Fortran::semantics::FindCommonBlockContaining(sym.GetUltimate()))
+      if (llvm::find(mapSyms, common) != mapSyms.end())
+        return;
+
     if (llvm::find(mapSyms, &sym) == mapSyms.end()) {
       mlir::Value baseOp = converter.getSymbolAddress(sym);
       if (!baseOp)
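
The implicit-capture check added in the last hunk can be illustrated outside of flang. The following is a minimal sketch, not the flang implementation: `Symbol` and `needsImplicitMap` are hypothetical stand-ins for `Fortran::semantics::Symbol` and the logic around `FindCommonBlockContaining`, and the only assumption is that each symbol knows which common block (if any) contains it.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a semantics symbol; not the flang type.
struct Symbol {
  std::string name;
  const Symbol *commonBlock; // containing common block, or nullptr
};

// Returns true when `sym` still needs its own implicit map entry.
bool needsImplicitMap(const Symbol &sym,
                      const std::vector<const Symbol *> &mapSyms) {
  assert(!sym.name.empty() && "symbols in this sketch are always named");
  // A member of an already mapped common block is covered by the block's map.
  if (sym.commonBlock &&
      std::find(mapSyms.begin(), mapSyms.end(), sym.commonBlock) !=
          mapSyms.end())
    return false;
  // Otherwise it needs a map only if it was not mapped explicitly.
  return std::find(mapSyms.begin(), mapSyms.end(), &sym) == mapSyms.end();
}
```

With `map(tofrom: /var_common/)` the members `var1` and `var2` get no implicit map of their own; with `map(tofrom: var2)` only `var1` is captured implicitly, which matches the `map_mix_of_members` test in this commit.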

flang/test/Fir/convert-to-llvm-openmp-and-fir.fir

Lines changed: 74 additions & 0 deletions
@@ -1006,3 +1006,77 @@ func.func @omp_map_info_nested_derived_type_explicit_member_conversion(%arg0 : !
 }

 // -----
+
+// CHECK-LABEL: llvm.func @omp_map_common_block_using_common_block_symbol
+
+// CHECK: %[[ADDR_OF:.*]] = llvm.mlir.addressof @var_common_ : !llvm.ptr
+// CHECK: %[[CB_MAP:.*]] = omp.map.info var_ptr(%[[ADDR_OF]] : !llvm.ptr, !llvm.array<8 x i8>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "var_common"}
+// CHECK: omp.target map_entries(%[[CB_MAP]] -> %[[ARG0:.*]] : !llvm.ptr) {
+// CHECK: ^bb0(%[[ARG0]]: !llvm.ptr):
+// CHECK: %[[VAR_2_OFFSET:.*]] = llvm.mlir.constant(4 : index) : i64
+// CHECK: %[[VAR_1_OFFSET:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %{{.*}} = llvm.getelementptr %[[ARG0]][%[[VAR_1_OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK: %{{.*}} = llvm.getelementptr %[[ARG0]][%[[VAR_2_OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+
+func.func @omp_map_common_block_using_common_block_symbol() {
+  %0 = fir.address_of(@var_common_) : !fir.ref<!fir.array<8xi8>>
+  %1 = omp.map.info var_ptr(%0 : !fir.ref<!fir.array<8xi8>>, !fir.array<8xi8>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.array<8xi8>> {name = "var_common"}
+  omp.target map_entries(%1 -> %arg0 : !fir.ref<!fir.array<8xi8>>) {
+  ^bb0(%arg0: !fir.ref<!fir.array<8xi8>>):
+    %c4 = arith.constant 4 : index
+    %c0 = arith.constant 0 : index
+    %c20_i32 = arith.constant 20 : i32
+    %2 = fir.convert %arg0 : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
+    %3 = fir.coordinate_of %2, %c0 : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
+    %4 = fir.convert %3 : (!fir.ref<i8>) -> !fir.ref<i32>
+    %5 = fir.convert %arg0 : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
+    %6 = fir.coordinate_of %5, %c4 : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
+    %7 = fir.convert %6 : (!fir.ref<i8>) -> !fir.ref<i32>
+    %8 = fir.load %4 : !fir.ref<i32>
+    %9 = arith.addi %8, %c20_i32 : i32
+    fir.store %9 to %7 : !fir.ref<i32>
+    omp.terminator
+  }
+  return
+}
+
+fir.global common @var_common_(dense<0> : vector<8xi8>) {alignment = 4 : i64} : !fir.array<8xi8>
+
+// -----
+
+// CHECK-LABEL: llvm.func @omp_map_common_block_using_common_block_members
+
+// CHECK: %[[VAR_2_OFFSET:.*]] = llvm.mlir.constant(4 : index) : i64
+// CHECK: %[[VAR_1_OFFSET:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[ADDR_OF:.*]] = llvm.mlir.addressof @var_common_ : !llvm.ptr
+// CHECK: %[[VAR_1_CB_GEP:.*]] = llvm.getelementptr %[[ADDR_OF]][%[[VAR_1_OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK: %[[VAR_2_CB_GEP:.*]] = llvm.getelementptr %[[ADDR_OF]][%[[VAR_2_OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK: %[[MAP_CB_VAR_1:.*]] = omp.map.info var_ptr(%[[VAR_1_CB_GEP]] : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "var1"}
+// CHECK: %[[MAP_CB_VAR_2:.*]] = omp.map.info var_ptr(%[[VAR_2_CB_GEP]] : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "var2"}
+// CHECK: omp.target map_entries(%[[MAP_CB_VAR_1]] -> %[[ARG0:.*]], %[[MAP_CB_VAR_2]] -> %[[ARG1:.*]] : !llvm.ptr, !llvm.ptr) {
+// CHECK: ^bb0(%[[ARG0]]: !llvm.ptr, %[[ARG1]]: !llvm.ptr):
+
+func.func @omp_map_common_block_using_common_block_members() {
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %0 = fir.address_of(@var_common_) : !fir.ref<!fir.array<8xi8>>
+  %1 = fir.convert %0 : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
+  %2 = fir.coordinate_of %1, %c0 : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
+  %3 = fir.convert %2 : (!fir.ref<i8>) -> !fir.ref<i32>
+  %4 = fir.convert %0 : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
+  %5 = fir.coordinate_of %4, %c4 : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
+  %6 = fir.convert %5 : (!fir.ref<i8>) -> !fir.ref<i32>
+  %7 = omp.map.info var_ptr(%3 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "var1"}
+  %8 = omp.map.info var_ptr(%6 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "var2"}
+  omp.target map_entries(%7 -> %arg0, %8 -> %arg1 : !fir.ref<i32>, !fir.ref<i32>) {
+  ^bb0(%arg0: !fir.ref<i32>, %arg1: !fir.ref<i32>):
+    %c10_i32 = arith.constant 10 : i32
+    %9 = fir.load %arg0 : !fir.ref<i32>
+    %10 = arith.muli %9, %c10_i32 : i32
+    fir.store %10 to %arg1 : !fir.ref<i32>
+    omp.terminator
+  }
+  return
+}
+
+fir.global common @var_common_(dense<0> : vector<8xi8>) {alignment = 4 : i64} : !fir.array<8xi8>
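
In the tests above the common block is an 8-byte `i8` array, and its members are reached through byte offsets 0 and 4. A rough sketch of how such offsets could be derived follows. It assumes members are placed in declaration order at their natural alignment; the real layout is computed by flang's semantics, and `Member`, `Layout` and `layOutCommonBlock` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical description of one common block member.
struct Member {
  std::size_t size;  // size in bytes
  std::size_t align; // required alignment in bytes (a power of two)
};

struct Layout {
  std::vector<std::size_t> offsets; // byte offset of each member
  std::size_t totalSize = 0;        // size of the backing i8 array
};

// Place members one after another, rounding each offset up to the member's
// alignment. This is an assumed layout rule, not flang's actual algorithm.
Layout layOutCommonBlock(const std::vector<Member> &members) {
  Layout layout;
  std::size_t offset = 0;
  for (const Member &m : members) {
    assert(m.align != 0 && (m.align & (m.align - 1)) == 0);
    offset = (offset + m.align - 1) & ~(m.align - 1);
    layout.offsets.push_back(offset);
    offset += m.size;
  }
  layout.totalSize = offset;
  return layout;
}
```

For two default integers (4 bytes each) this gives offsets 0 and 4 and a total size of 8, which is the shape of `@var_common_` in the test. Mapping the whole block therefore maps the array type itself, which is why the size calculation had to handle an array passed as the map type.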

flang/test/Integration/OpenMP/map-types-and-sizes.f90

Lines changed: 41 additions & 0 deletions
@@ -231,6 +231,31 @@ subroutine mapType_char
   !$omp end target
 end subroutine mapType_char

+!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [1 x i64] [i64 8]
+!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [1 x i64] [i64 35]
+subroutine mapType_common_block
+  implicit none
+  common /var_common/ var1, var2
+  integer :: var1, var2
+!$omp target map(tofrom: /var_common/)
+  var1 = var1 + 20
+  var2 = var2 + 30
+!$omp end target
+end subroutine mapType_common_block
+
+!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [2 x i64] [i64 4, i64 4]
+!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [2 x i64] [i64 35, i64 35]
+subroutine mapType_common_block_members
+  implicit none
+  common /var_common/ var1, var2
+  integer :: var1, var2
+
+!$omp target map(tofrom: var1, var2)
+  var2 = var1
+!$omp end target
+end subroutine mapType_common_block_members
+
+
 !CHECK-LABEL: define {{.*}} @{{.*}}maptype_ptr_explicit_{{.*}}
 !CHECK: %[[ALLOCA:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, i64 1, align 8
 !CHECK: %[[ALLOCA_GEP:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[ALLOCA]], i32 1
@@ -346,3 +371,19 @@ end subroutine mapType_char
 !CHECK: store ptr %[[ALLOCA]], ptr %[[BASE_PTR_ARR]], align 8
 !CHECK: %[[OFFLOAD_PTR_ARR:.*]] = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0
 !CHECK: store ptr %[[ARR_OFF]], ptr %[[OFFLOAD_PTR_ARR]], align 8
+
+!CHECK-LABEL: define {{.*}} @{{.*}}maptype_common_block_{{.*}}
+!CHECK: %[[BASE_PTR_ARR:.*]] = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
+!CHECK: store ptr @var_common_, ptr %[[BASE_PTR_ARR]], align 8
+!CHECK: %[[OFFLOAD_PTR_ARR:.*]] = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0
+!CHECK: store ptr @var_common_, ptr %[[OFFLOAD_PTR_ARR]], align 8
+
+!CHECK-LABEL: define {{.*}} @{{.*}}maptype_common_block_members_{{.*}}
+!CHECK: %[[BASE_PTR_ARR:.*]] = getelementptr inbounds [2 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
+!CHECK: store ptr @var_common_, ptr %[[BASE_PTR_ARR]], align 8
+!CHECK: %[[OFFLOAD_PTR_ARR:.*]] = getelementptr inbounds [2 x ptr], ptr %.offload_ptrs, i32 0, i32 0
+!CHECK: store ptr @var_common_, ptr %[[OFFLOAD_PTR_ARR]], align 8
+!CHECK: %[[BASE_PTR_ARR_1:.*]] = getelementptr inbounds [2 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
+!CHECK: store ptr getelementptr (i8, ptr @var_common_, i64 4), ptr %[[BASE_PTR_ARR_1]], align 8
+!CHECK: %[[OFFLOAD_PTR_ARR_1:.*]] = getelementptr inbounds [2 x ptr], ptr %.offload_ptrs, i32 0, i32 1
+!CHECK: store ptr getelementptr (i8, ptr @var_common_, i64 4), ptr %[[OFFLOAD_PTR_ARR_1]], align 8
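
The value 35 checked in `@.offload_maptypes` is a bitmask. As a small worked example, the sketch below rebuilds it from the three flags involved. The constant names are local to this sketch; the bit values are those of the `TO`, `FROM` and `TARGET_PARAM` offload map-type flags.

```cpp
#include <cassert>
#include <cstdint>

// Local names for the relevant offload map-type bits.
constexpr std::uint64_t MapTo = 0x01;          // copy host -> device on entry
constexpr std::uint64_t MapFrom = 0x02;        // copy device -> host on exit
constexpr std::uint64_t MapTargetParam = 0x20; // passed as a kernel argument

// A tofrom map that is also a kernel argument.
constexpr std::uint64_t tofromKernelArg() {
  return MapTo | MapFrom | MapTargetParam;
}

static_assert(tofromKernelArg() == 35, "0x01 | 0x02 | 0x20 == 35");
```

So `[i64 35]` with size `[i64 8]` describes the whole common block mapped `tofrom` as one kernel argument, while `[i64 35, i64 35]` with sizes `[i64 4, i64 4]` describes the two members mapped individually.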
Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+!CHECK: fir.global common @var_common_(dense<0> : vector<8xi8>) {{.*}} : !fir.array<8xi8>
+!CHECK: fir.global common @var_common_link_(dense<0> : vector<8xi8>) {{{.*}} omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (link)>} : !fir.array<8xi8>
+
+!CHECK-LABEL: func.func @_QPmap_full_block
+!CHECK: %[[CB_ADDR:.*]] = fir.address_of(@var_common_) : !fir.ref<!fir.array<8xi8>>
+!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[CB_ADDR]] : !fir.ref<!fir.array<8xi8>>, !fir.array<8xi8>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.array<8xi8>> {name = "var_common"}
+!CHECK: omp.target map_entries(%[[MAP]] -> %[[MAP_ARG:.*]] : !fir.ref<!fir.array<8xi8>>) {
+!CHECK: ^bb0(%[[MAP_ARG]]: !fir.ref<!fir.array<8xi8>>):
+!CHECK: %[[CONV:.*]] = fir.convert %[[MAP_ARG]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
+!CHECK: %[[INDEX:.*]] = arith.constant 0 : index
+!CHECK: %[[COORD:.*]] = fir.coordinate_of %[[CONV]], %[[INDEX]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
+!CHECK: %[[CONV2:.*]] = fir.convert %[[COORD]] : (!fir.ref<i8>) -> !fir.ref<i32>
+!CHECK: %[[CB_MEMBER_1:.*]]:2 = hlfir.declare %[[CONV2]] {uniq_name = "_QFmap_full_blockEvar1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[CONV3:.*]] = fir.convert %[[MAP_ARG]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
+!CHECK: %[[INDEX2:.*]] = arith.constant 4 : index
+!CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[CONV3]], %[[INDEX2]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
+!CHECK: %[[CONV4:.*]] = fir.convert %[[COORD2]] : (!fir.ref<i8>) -> !fir.ref<i32>
+!CHECK: %[[CB_MEMBER_2:.*]]:2 = hlfir.declare %[[CONV4]] {uniq_name = "_QFmap_full_blockEvar2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine map_full_block
+  implicit none
+  common /var_common/ var1, var2
+  integer :: var1, var2
+!$omp target map(tofrom: /var_common/)
+  var1 = var1 + 20
+  var2 = var2 + 30
+!$omp end target
+end
+
+!CHECK-LABEL: @_QPmap_mix_of_members
+!CHECK: %[[COMMON_BLOCK:.*]] = fir.address_of(@var_common_) : !fir.ref<!fir.array<8xi8>>
+!CHECK: %[[CB_CONV:.*]] = fir.convert %[[COMMON_BLOCK]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
+!CHECK: %[[INDEX:.*]] = arith.constant 0 : index
+!CHECK: %[[COORD:.*]] = fir.coordinate_of %[[CB_CONV]], %[[INDEX]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
+!CHECK: %[[CONV:.*]] = fir.convert %[[COORD]] : (!fir.ref<i8>) -> !fir.ref<i32>
+!CHECK: %[[CB_MEMBER_1:.*]]:2 = hlfir.declare %[[CONV]] {uniq_name = "_QFmap_mix_of_membersEvar1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[CB_CONV:.*]] = fir.convert %[[COMMON_BLOCK]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
+!CHECK: %[[INDEX:.*]] = arith.constant 4 : index
+!CHECK: %[[COORD:.*]] = fir.coordinate_of %[[CB_CONV]], %[[INDEX]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
+!CHECK: %[[CONV:.*]] = fir.convert %[[COORD]] : (!fir.ref<i8>) -> !fir.ref<i32>
+!CHECK: %[[CB_MEMBER_2:.*]]:2 = hlfir.declare %[[CONV]] {uniq_name = "_QFmap_mix_of_membersEvar2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[MAP_EXP:.*]] = omp.map.info var_ptr(%[[CB_MEMBER_2]]#0 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "var2"}
+!CHECK: %[[MAP_IMP:.*]] = omp.map.info var_ptr(%[[CB_MEMBER_1]]#1 : !fir.ref<i32>, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "var1"}
+!CHECK: omp.target map_entries(%[[MAP_EXP]] -> %[[ARG_EXP:.*]], %[[MAP_IMP]] -> %[[ARG_IMP:.*]] : !fir.ref<i32>, !fir.ref<i32>) {
+!CHECK: ^bb0(%[[ARG_EXP]]: !fir.ref<i32>, %[[ARG_IMP]]: !fir.ref<i32>):
+!CHECK: %[[EXP_MEMBER:.*]]:2 = hlfir.declare %[[ARG_EXP]] {uniq_name = "_QFmap_mix_of_membersEvar2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[IMP_MEMBER:.*]]:2 = hlfir.declare %[[ARG_IMP]] {uniq_name = "_QFmap_mix_of_membersEvar1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine map_mix_of_members
+  implicit none
+  common /var_common/ var1, var2
+  integer :: var1, var2
+
+!$omp target map(tofrom: var2)
+  var2 = var1
+!$omp end target
+end
+
+!CHECK-LABEL: @_QQmain
+!CHECK: %[[DECL_TAR_CB:.*]] = fir.address_of(@var_common_link_) : !fir.ref<!fir.array<8xi8>>
+!CHECK: %[[MAP_DECL_TAR_CB:.*]] = omp.map.info var_ptr(%[[DECL_TAR_CB]] : !fir.ref<!fir.array<8xi8>>, !fir.array<8xi8>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.array<8xi8>> {name = "var_common_link"}
+!CHECK: omp.target map_entries(%[[MAP_DECL_TAR_CB]] -> %[[MAP_DECL_TAR_ARG:.*]] : !fir.ref<!fir.array<8xi8>>) {
+!CHECK: ^bb0(%[[MAP_DECL_TAR_ARG]]: !fir.ref<!fir.array<8xi8>>):
+!CHECK: %[[CONV:.*]] = fir.convert %[[MAP_DECL_TAR_ARG]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
+!CHECK: %[[INDEX:.*]] = arith.constant 0 : index
+!CHECK: %[[COORD:.*]] = fir.coordinate_of %[[CONV]], %[[INDEX]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
+!CHECK: %[[CONV:.*]] = fir.convert %[[COORD]] : (!fir.ref<i8>) -> !fir.ref<i32>
+!CHECK: %[[MEMBER_ONE:.*]]:2 = hlfir.declare %[[CONV]] {uniq_name = "_QFElink1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[CONV:.*]] = fir.convert %[[MAP_DECL_TAR_ARG]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
+!CHECK: %[[INDEX:.*]] = arith.constant 4 : index
+!CHECK: %[[COORD:.*]] = fir.coordinate_of %[[CONV]], %[[INDEX]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
+!CHECK: %[[CONV:.*]] = fir.convert %[[COORD]] : (!fir.ref<i8>) -> !fir.ref<i32>
+!CHECK: %[[MEMBER_TWO:.*]]:2 = hlfir.declare %[[CONV]] {uniq_name = "_QFElink2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+program main
+  implicit none
+  common /var_common_link/ link1, link2
+  integer :: link1, link2
+  !$omp declare target link(/var_common_link/)
+
+!$omp target map(tofrom: /var_common_link/)
+  link1 = link2 + 20
+!$omp end target
+end program

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 42 additions & 9 deletions
@@ -5164,15 +5164,7 @@ static Function *createOutlinedFunction(
           ? make_range(Func->arg_begin() + 1, Func->arg_end())
           : Func->args();

-  // Rewrite uses of input valus to parameters.
-  for (auto InArg : zip(Inputs, ArgRange)) {
-    Value *Input = std::get<0>(InArg);
-    Argument &Arg = std::get<1>(InArg);
-    Value *InputCopy = nullptr;
-
-    Builder.restoreIP(
-        ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP()));
-
+  auto ReplaceValue = [](Value *Input, Value *InputCopy, Function *Func) {
     // Things like GEP's can come in the form of Constants. Constants and
     // ConstantExpr's do not have access to the knowledge of what they're
     // contained in, so we must dig a little to find an instruction so we
@@ -5198,8 +5190,49 @@ static Function *createOutlinedFunction(
       if (auto *Instr = dyn_cast<Instruction>(User))
         if (Instr->getFunction() == Func)
           Instr->replaceUsesOfWith(Input, InputCopy);
+  };
+
+  SmallVector<std::pair<Value *, Value *>> DeferredReplacement;
+
+  // Rewrite uses of input valus to parameters.
+  for (auto InArg : zip(Inputs, ArgRange)) {
+    Value *Input = std::get<0>(InArg);
+    Argument &Arg = std::get<1>(InArg);
+    Value *InputCopy = nullptr;
+
+    Builder.restoreIP(
+        ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP()));
+
+    // In certain cases a Global may be set up for replacement, however, this
+    // Global may be used in multiple arguments to the kernel, just segmented
+    // apart, for example, if we have a global array, that is sectioned into
+    // multiple mappings (technically not legal in OpenMP, but there is a case
+    // in Fortran for Common Blocks where this is neccesary), we will end up
+    // with GEP's into this array inside the kernel, that refer to the Global
+    // but are technically seperate arguments to the kernel for all intents and
+    // purposes. If we have mapped a segment that requires a GEP into the 0-th
+    // index, it will fold into an referal to the Global, if we then encounter
+    // this folded GEP during replacement all of the references to the
+    // Global in the kernel will be replaced with the argument we have generated
+    // that corresponds to it, including any other GEP's that refer to the
+    // Global that may be other arguments. This will invalidate all of the other
+    // preceding mapped arguments that refer to the same global that may be
+    // seperate segments. To prevent this, we defer global processing until all
+    // other processing has been performed.
+    if (llvm::isa<llvm::GlobalValue>(std::get<0>(InArg)) ||
+        llvm::isa<llvm::GlobalObject>(std::get<0>(InArg)) ||
+        llvm::isa<llvm::GlobalVariable>(std::get<0>(InArg))) {
+      DeferredReplacement.push_back(std::make_pair(Input, InputCopy));
+      continue;
+    }
+
+    ReplaceValue(Input, InputCopy, Func);
   }

+  // Replace all of our deferred Input values, currently just Globals.
+  for (auto Deferred : DeferredReplacement)
+    ReplaceValue(std::get<0>(Deferred), std::get<1>(Deferred), Func);
+
   // Restore insert point.
   Builder.restoreIP(OldInsertPoint);
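
The ordering problem that the deferred replacement solves can be reproduced with a toy model. The sketch below is not the OMPIRBuilder code: it represents uses as plain strings, treats any value starting with `@` as a global, and uses hypothetical helpers (`replaceUses`, `isGlobal`, `rebind`). It only illustrates why inputs that are globals are rebound last.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

using Binding = std::pair<std::string, std::string>; // {input, kernel argument}

// Replace every occurrence of `from` inside each use with `to`.
void replaceUses(std::vector<std::string> &uses, const std::string &from,
                 const std::string &to) {
  assert(!from.empty());
  for (std::string &use : uses)
    for (std::size_t pos = use.find(from); pos != std::string::npos;
         pos = use.find(from, pos + to.size()))
      use.replace(pos, from.size(), to);
}

// Toy convention: globals are spelled with a leading '@'.
bool isGlobal(const std::string &value) {
  return !value.empty() && value[0] == '@';
}

// Rebind inputs to kernel arguments, deferring globals until the end so that
// expressions mentioning the global (e.g. a GEP) are matched first.
std::vector<std::string> rebind(std::vector<std::string> uses,
                                const std::vector<Binding> &bindings) {
  std::vector<Binding> deferred;
  for (const Binding &b : bindings) {
    if (isGlobal(b.first)) {
      deferred.push_back(b);
      continue;
    }
    replaceUses(uses, b.first, b.second);
  }
  for (const Binding &b : deferred)
    replaceUses(uses, b.first, b.second);
  return uses;
}
```

Given the uses `load @g` and `store gep(@g, 4)` with bindings `@g -> %arg0` and `gep(@g, 4) -> %arg1`, rebinding in list order would rewrite the second use to `store gep(%arg0, 4)` and the `%arg1` binding would never match. Deferring the global yields `load %arg0` and `store %arg1`.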
