Skip to content

[Flang][OpenMP][MLIR] Fix common block mapping for regular and declare target link #91829

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,33 @@ static void genBodyOfTargetDataOp(
}
}

// Re-creates, inside `region`, the accesses to the members of every common
// block found in `mapSyms`, and rebinds each member symbol to the accessor
// generated here so that code lowered afterwards uses it instead of the
// member's previous binding.
//
// The block argument of `region` at position N is used as the base for the
// common block at position N of `mapSyms`; entries of `mapSyms` that are not
// common blocks are skipped.
//
// The rebinding is only meant to hold for the current scope: once the scope
// changes, the original symbol bindings are expected to take over again.
//
// Intended for use with TargetOp.
static void genIntermediateCommonBlockAccessors(
    Fortran::lower::AbstractConverter &converter,
    const mlir::Location &currentLocation, mlir::Region &region,
    llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSyms) {
  for (size_t symIdx = 0, numSyms = mapSyms.size(); symIdx != numSyms;
       ++symIdx) {
    const auto *commonBlock =
        mapSyms[symIdx]->detailsIf<Fortran::semantics::CommonBlockDetails>();
    // Only common blocks need their members re-materialised in the region.
    if (!commonBlock)
      continue;

    for (auto member : commonBlock->objects()) {
      // Build the member access on top of the region's block argument rather
      // than on the value the member was originally bound to.
      auto regionAccessor = Fortran::lower::genCommonBlockMember(
          converter, currentLocation, *member, region.getArgument(symIdx));
      // Keep the shape of the member's current extended value, but swap in
      // the accessor generated above before rebinding the symbol.
      fir::ExtendedValue previousExv =
          converter.getSymbolExtendedValue(*member);
      converter.bindSymbol(*member,
                           getExtendedValue(previousExv, regionAccessor));
    }
  }
}

// This function creates a block for the body of the targetOp's region. It adds
// all the symbols present in mapSymbols as block arguments to this block.
static void
Expand Down Expand Up @@ -955,6 +982,16 @@ genBodyOfTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
// Create the insertion point after the marker.
firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());

// If we map a common block using its symbol, e.g.
// map(tofrom: /common_block/), and access its members within the target
// region, there is a large chance we will end up with uses external to the
// region accessing the common block. To resolve these, we generate new
// common block member accesses within the region and bind them to the member
// symbols for the scope of the region, so that subsequent code generation
// within the region will utilise the new member accesses we have created.
genIntermediateCommonBlockAccessors(converter, currentLocation, region,
mapSyms);

if (ConstructQueue::iterator next = std::next(item); next != queue.end()) {
genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
next);
Expand Down Expand Up @@ -1670,6 +1707,13 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
if (dsp.getAllSymbolsToPrivatize().contains(&sym))
return;

// If the symbol is part of an already mapped common block, do not make a
// map for it.
if (const Fortran::semantics::Symbol *common =
Fortran::semantics::FindCommonBlockContaining(sym.GetUltimate()))
if (llvm::find(mapSyms, common) != mapSyms.end())
return;

if (llvm::find(mapSyms, &sym) == mapSyms.end()) {
mlir::Value baseOp = converter.getSymbolAddress(sym);
if (!baseOp)
Expand Down
74 changes: 74 additions & 0 deletions flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
Original file line number Diff line number Diff line change
Expand Up @@ -1006,3 +1006,77 @@ func.func @omp_map_info_nested_derived_type_explicit_member_conversion(%arg0 : !
}

// -----

// CHECK-LABEL: llvm.func @omp_map_common_block_using_common_block_symbol

// CHECK: %[[ADDR_OF:.*]] = llvm.mlir.addressof @var_common_ : !llvm.ptr
// CHECK: %[[CB_MAP:.*]] = omp.map.info var_ptr(%[[ADDR_OF]] : !llvm.ptr, !llvm.array<8 x i8>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "var_common"}
// CHECK: omp.target map_entries(%[[CB_MAP]] -> %[[ARG0:.*]] : !llvm.ptr) {
// CHECK: ^bb0(%[[ARG0]]: !llvm.ptr):
// CHECK: %[[VAR_2_OFFSET:.*]] = llvm.mlir.constant(4 : index) : i64
// CHECK: %[[VAR_1_OFFSET:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %{{.*}} = llvm.getelementptr %[[ARG0]][%[[VAR_1_OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// CHECK: %{{.*}} = llvm.getelementptr %[[ARG0]][%[[VAR_2_OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8

// Test input: the whole common block global @var_common_ (an 8-byte
// !fir.array<8xi8>) is mapped as a single tofrom/ByRef entry named
// "var_common". Both members are then addressed inside the target region
// through the region's block argument, at byte offsets 0 and 4.
func.func @omp_map_common_block_using_common_block_symbol() {
%0 = fir.address_of(@var_common_) : !fir.ref<!fir.array<8xi8>>
%1 = omp.map.info var_ptr(%0 : !fir.ref<!fir.array<8xi8>>, !fir.array<8xi8>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.array<8xi8>> {name = "var_common"}
omp.target map_entries(%1 -> %arg0 : !fir.ref<!fir.array<8xi8>>) {
^bb0(%arg0: !fir.ref<!fir.array<8xi8>>):
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%c20_i32 = arith.constant 20 : i32
// Member at byte offset 0, viewed as an i32.
%2 = fir.convert %arg0 : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
%3 = fir.coordinate_of %2, %c0 : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
%4 = fir.convert %3 : (!fir.ref<i8>) -> !fir.ref<i32>
// Member at byte offset 4, viewed as an i32.
%5 = fir.convert %arg0 : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
%6 = fir.coordinate_of %5, %c4 : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
%7 = fir.convert %6 : (!fir.ref<i8>) -> !fir.ref<i32>
// Load the first member, add 20, and store the result into the second member.
%8 = fir.load %4 : !fir.ref<i32>
%9 = arith.addi %8, %c20_i32 : i32
fir.store %9 to %7 : !fir.ref<i32>
omp.terminator
}
return
}

fir.global common @var_common_(dense<0> : vector<8xi8>) {alignment = 4 : i64} : !fir.array<8xi8>

// -----

// CHECK-LABEL: llvm.func @omp_map_common_block_using_common_block_members

// CHECK: %[[VAR_2_OFFSET:.*]] = llvm.mlir.constant(4 : index) : i64
// CHECK: %[[VAR_1_OFFSET:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[ADDR_OF:.*]] = llvm.mlir.addressof @var_common_ : !llvm.ptr
// CHECK: %[[VAR_1_CB_GEP:.*]] = llvm.getelementptr %[[ADDR_OF]][%[[VAR_1_OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// CHECK: %[[VAR_2_CB_GEP:.*]] = llvm.getelementptr %[[ADDR_OF]][%[[VAR_2_OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// CHECK: %[[MAP_CB_VAR_1:.*]] = omp.map.info var_ptr(%[[VAR_1_CB_GEP]] : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "var1"}
// CHECK: %[[MAP_CB_VAR_2:.*]] = omp.map.info var_ptr(%[[VAR_2_CB_GEP]] : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "var2"}
// CHECK: omp.target map_entries(%[[MAP_CB_VAR_1]] -> %[[ARG0:.*]], %[[MAP_CB_VAR_2]] -> %[[ARG1:.*]] : !llvm.ptr, !llvm.ptr) {
// CHECK: ^bb0(%[[ARG0]]: !llvm.ptr, %[[ARG1]]: !llvm.ptr):

// Test input: the two i32 members of the common block global @var_common_ are
// addressed outside the target region (byte offsets 0 and 4) and each one is
// mapped on its own as a tofrom/ByRef entry ("var1" and "var2"). The target
// region only uses the two resulting block arguments.
func.func @omp_map_common_block_using_common_block_members() {
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%0 = fir.address_of(@var_common_) : !fir.ref<!fir.array<8xi8>>
// Member at byte offset 0, viewed as an i32.
%1 = fir.convert %0 : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
%2 = fir.coordinate_of %1, %c0 : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
%3 = fir.convert %2 : (!fir.ref<i8>) -> !fir.ref<i32>
// Member at byte offset 4, viewed as an i32.
%4 = fir.convert %0 : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
%5 = fir.coordinate_of %4, %c4 : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
%6 = fir.convert %5 : (!fir.ref<i8>) -> !fir.ref<i32>
%7 = omp.map.info var_ptr(%3 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "var1"}
%8 = omp.map.info var_ptr(%6 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "var2"}
omp.target map_entries(%7 -> %arg0, %8 -> %arg1 : !fir.ref<i32>, !fir.ref<i32>) {
^bb0(%arg0: !fir.ref<i32>, %arg1: !fir.ref<i32>):
// Load the first mapped member, multiply by 10, and store into the second.
%c10_i32 = arith.constant 10 : i32
%9 = fir.load %arg0 : !fir.ref<i32>
%10 = arith.muli %9, %c10_i32 : i32
fir.store %10 to %arg1 : !fir.ref<i32>
omp.terminator
}
return
}

fir.global common @var_common_(dense<0> : vector<8xi8>) {alignment = 4 : i64} : !fir.array<8xi8>
41 changes: 41 additions & 0 deletions flang/test/Integration/OpenMP/map-types-and-sizes.f90
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,31 @@ subroutine mapType_char
!$omp end target
end subroutine mapType_char

!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [1 x i64] [i64 8]
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [1 x i64] [i64 35]
! Test input: maps the whole common block /var_common/ (two integers, var1 and
! var2) with a single tofrom map clause naming the block, and updates both
! members inside the target region.
subroutine mapType_common_block
implicit none
common /var_common/ var1, var2
integer :: var1, var2
!$omp target map(tofrom: /var_common/)
var1 = var1 + 20
var2 = var2 + 30
!$omp end target
end subroutine mapType_common_block

!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [2 x i64] [i64 4, i64 4]
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [2 x i64] [i64 35, i64 35]
! Test input: maps the members of the common block /var_common/ individually
! (var1 and var2 are each named in the tofrom map clause, the block itself is
! not) and copies var1 into var2 inside the target region.
subroutine mapType_common_block_members
implicit none
common /var_common/ var1, var2
integer :: var1, var2

!$omp target map(tofrom: var1, var2)
var2 = var1
!$omp end target
end subroutine mapType_common_block_members


!CHECK-LABEL: define {{.*}} @{{.*}}maptype_ptr_explicit_{{.*}}
!CHECK: %[[ALLOCA:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, i64 1, align 8
!CHECK: %[[ALLOCA_GEP:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[ALLOCA]], i32 1
Expand Down Expand Up @@ -346,3 +371,19 @@ end subroutine mapType_char
!CHECK: store ptr %[[ALLOCA]], ptr %[[BASE_PTR_ARR]], align 8
!CHECK: %[[OFFLOAD_PTR_ARR:.*]] = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0
!CHECK: store ptr %[[ARR_OFF]], ptr %[[OFFLOAD_PTR_ARR]], align 8

!CHECK-LABEL: define {{.*}} @{{.*}}maptype_common_block_{{.*}}
!CHECK: %[[BASE_PTR_ARR:.*]] = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
!CHECK: store ptr @var_common_, ptr %[[BASE_PTR_ARR]], align 8
!CHECK: %[[OFFLOAD_PTR_ARR:.*]] = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0
!CHECK: store ptr @var_common_, ptr %[[OFFLOAD_PTR_ARR]], align 8

!CHECK-LABEL: define {{.*}} @{{.*}}maptype_common_block_members_{{.*}}
!CHECK: %[[BASE_PTR_ARR:.*]] = getelementptr inbounds [2 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
!CHECK: store ptr @var_common_, ptr %[[BASE_PTR_ARR]], align 8
!CHECK: %[[OFFLOAD_PTR_ARR:.*]] = getelementptr inbounds [2 x ptr], ptr %.offload_ptrs, i32 0, i32 0
!CHECK: store ptr @var_common_, ptr %[[OFFLOAD_PTR_ARR]], align 8
!CHECK: %[[BASE_PTR_ARR_1:.*]] = getelementptr inbounds [2 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
!CHECK: store ptr getelementptr (i8, ptr @var_common_, i64 4), ptr %[[BASE_PTR_ARR_1]], align 8
!CHECK: %[[OFFLOAD_PTR_ARR_1:.*]] = getelementptr inbounds [2 x ptr], ptr %.offload_ptrs, i32 0, i32 1
!CHECK: store ptr getelementptr (i8, ptr @var_common_, i64 4), ptr %[[OFFLOAD_PTR_ARR_1]], align 8
83 changes: 83 additions & 0 deletions flang/test/Lower/OpenMP/common-block-map.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s

!CHECK: fir.global common @var_common_(dense<0> : vector<8xi8>) {{.*}} : !fir.array<8xi8>
!CHECK: fir.global common @var_common_link_(dense<0> : vector<8xi8>) {{{.*}} omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (link)>} : !fir.array<8xi8>

!CHECK-LABEL: func.func @_QPmap_full_block
!CHECK: %[[CB_ADDR:.*]] = fir.address_of(@var_common_) : !fir.ref<!fir.array<8xi8>>
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[CB_ADDR]] : !fir.ref<!fir.array<8xi8>>, !fir.array<8xi8>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.array<8xi8>> {name = "var_common"}
!CHECK: omp.target map_entries(%[[MAP]] -> %[[MAP_ARG:.*]] : !fir.ref<!fir.array<8xi8>>) {
!CHECK: ^bb0(%[[MAP_ARG]]: !fir.ref<!fir.array<8xi8>>):
!CHECK: %[[CONV:.*]] = fir.convert %[[MAP_ARG]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[INDEX:.*]] = arith.constant 0 : index
!CHECK: %[[COORD:.*]] = fir.coordinate_of %[[CONV]], %[[INDEX]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[CONV2:.*]] = fir.convert %[[COORD]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[CB_MEMBER_1:.*]]:2 = hlfir.declare %[[CONV2]] {uniq_name = "_QFmap_full_blockEvar1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[CONV3:.*]] = fir.convert %[[MAP_ARG]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[INDEX2:.*]] = arith.constant 4 : index
!CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[CONV3]], %[[INDEX2]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[CONV4:.*]] = fir.convert %[[COORD2]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[CB_MEMBER_2:.*]]:2 = hlfir.declare %[[CONV4]] {uniq_name = "_QFmap_full_blockEvar2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
! Test input: the common block /var_common/ is mapped as a whole by naming the
! block in the map clause; both of its integer members are read and written
! inside the target region.
subroutine map_full_block
implicit none
common /var_common/ var1, var2
integer :: var1, var2
!$omp target map(tofrom: /var_common/)
var1 = var1 + 20
var2 = var2 + 30
!$omp end target
end

!CHECK-LABEL: @_QPmap_mix_of_members
!CHECK: %[[COMMON_BLOCK:.*]] = fir.address_of(@var_common_) : !fir.ref<!fir.array<8xi8>>
!CHECK: %[[CB_CONV:.*]] = fir.convert %[[COMMON_BLOCK]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[INDEX:.*]] = arith.constant 0 : index
!CHECK: %[[COORD:.*]] = fir.coordinate_of %[[CB_CONV]], %[[INDEX]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[CONV:.*]] = fir.convert %[[COORD]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[CB_MEMBER_1:.*]]:2 = hlfir.declare %[[CONV]] {uniq_name = "_QFmap_mix_of_membersEvar1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[CB_CONV:.*]] = fir.convert %[[COMMON_BLOCK]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[INDEX:.*]] = arith.constant 4 : index
!CHECK: %[[COORD:.*]] = fir.coordinate_of %[[CB_CONV]], %[[INDEX]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[CONV:.*]] = fir.convert %[[COORD]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[CB_MEMBER_2:.*]]:2 = hlfir.declare %[[CONV]] {uniq_name = "_QFmap_mix_of_membersEvar2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[MAP_EXP:.*]] = omp.map.info var_ptr(%[[CB_MEMBER_2]]#0 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "var2"}
!CHECK: %[[MAP_IMP:.*]] = omp.map.info var_ptr(%[[CB_MEMBER_1]]#1 : !fir.ref<i32>, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "var1"}
!CHECK: omp.target map_entries(%[[MAP_EXP]] -> %[[ARG_EXP:.*]], %[[MAP_IMP]] -> %[[ARG_IMP:.*]] : !fir.ref<i32>, !fir.ref<i32>) {
!CHECK: ^bb0(%[[ARG_EXP]]: !fir.ref<i32>, %[[ARG_IMP]]: !fir.ref<i32>):
!CHECK: %[[EXP_MEMBER:.*]]:2 = hlfir.declare %[[ARG_EXP]] {uniq_name = "_QFmap_mix_of_membersEvar2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[IMP_MEMBER:.*]]:2 = hlfir.declare %[[ARG_IMP]] {uniq_name = "_QFmap_mix_of_membersEvar1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
! Test input: only var2 of the common block /var_common/ appears in the map
! clause; var1 is read inside the target region without being listed in any
! map clause.
subroutine map_mix_of_members
implicit none
common /var_common/ var1, var2
integer :: var1, var2

!$omp target map(tofrom: var2)
var2 = var1
!$omp end target
end

!CHECK-LABEL: @_QQmain
!CHECK: %[[DECL_TAR_CB:.*]] = fir.address_of(@var_common_link_) : !fir.ref<!fir.array<8xi8>>
!CHECK: %[[MAP_DECL_TAR_CB:.*]] = omp.map.info var_ptr(%[[DECL_TAR_CB]] : !fir.ref<!fir.array<8xi8>>, !fir.array<8xi8>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.array<8xi8>> {name = "var_common_link"}
!CHECK: omp.target map_entries(%[[MAP_DECL_TAR_CB]] -> %[[MAP_DECL_TAR_ARG:.*]] : !fir.ref<!fir.array<8xi8>>) {
!CHECK: ^bb0(%[[MAP_DECL_TAR_ARG]]: !fir.ref<!fir.array<8xi8>>):
!CHECK: %[[CONV:.*]] = fir.convert %[[MAP_DECL_TAR_ARG]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[INDEX:.*]] = arith.constant 0 : index
!CHECK: %[[COORD:.*]] = fir.coordinate_of %[[CONV]], %[[INDEX]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[CONV:.*]] = fir.convert %[[COORD]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[MEMBER_ONE:.*]]:2 = hlfir.declare %[[CONV]] {uniq_name = "_QFElink1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[CONV:.*]] = fir.convert %[[MAP_DECL_TAR_ARG]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[INDEX:.*]] = arith.constant 4 : index
!CHECK: %[[COORD:.*]] = fir.coordinate_of %[[CONV]], %[[INDEX]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[CONV:.*]] = fir.convert %[[COORD]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[MEMBER_TWO:.*]]:2 = hlfir.declare %[[CONV]] {uniq_name = "_QFElink2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
! Test input: the common block /var_common_link/ is marked declare target link
! and is also mapped as a whole (tofrom) on the target construct, which reads
! link2 and writes link1.
program main
implicit none
common /var_common_link/ link1, link2
integer :: link1, link2
!$omp declare target link(/var_common_link/)

!$omp target map(tofrom: /var_common_link/)
link1 = link2 + 20
!$omp end target
end program
51 changes: 42 additions & 9 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5164,15 +5164,7 @@ static Function *createOutlinedFunction(
? make_range(Func->arg_begin() + 1, Func->arg_end())
: Func->args();

// Rewrite uses of input values to parameters.
for (auto InArg : zip(Inputs, ArgRange)) {
Value *Input = std::get<0>(InArg);
Argument &Arg = std::get<1>(InArg);
Value *InputCopy = nullptr;

Builder.restoreIP(
ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP()));

auto ReplaceValue = [](Value *Input, Value *InputCopy, Function *Func) {
// Things like GEP's can come in the form of Constants. Constants and
// ConstantExpr's do not have access to the knowledge of what they're
// contained in, so we must dig a little to find an instruction so we
Expand All @@ -5198,8 +5190,49 @@ static Function *createOutlinedFunction(
if (auto *Instr = dyn_cast<Instruction>(User))
if (Instr->getFunction() == Func)
Instr->replaceUsesOfWith(Input, InputCopy);
};

SmallVector<std::pair<Value *, Value *>> DeferredReplacement;

// Rewrite uses of input values to parameters.
for (auto InArg : zip(Inputs, ArgRange)) {
Value *Input = std::get<0>(InArg);
Argument &Arg = std::get<1>(InArg);
Value *InputCopy = nullptr;

Builder.restoreIP(
ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP()));

// In certain cases a Global may be set up for replacement, however, this
// Global may be used in multiple arguments to the kernel, just segmented
// apart, for example, if we have a global array, that is sectioned into
// multiple mappings (technically not legal in OpenMP, but there is a case
// in Fortran for Common Blocks where this is necessary), we will end up
// with GEP's into this array inside the kernel, that refer to the Global
// but are technically separate arguments to the kernel for all intents and
// purposes. If we have mapped a segment that requires a GEP into the 0-th
// index, it will fold into a reference to the Global, if we then encounter
// this folded GEP during replacement all of the references to the
// Global in the kernel will be replaced with the argument we have generated
// that corresponds to it, including any other GEP's that refer to the
// Global that may be other arguments. This will invalidate all of the other
// preceding mapped arguments that refer to the same global that may be
// separate segments. To prevent this, we defer global processing until all
// other processing has been performed.
if (llvm::isa<llvm::GlobalValue>(std::get<0>(InArg)) ||
llvm::isa<llvm::GlobalObject>(std::get<0>(InArg)) ||
llvm::isa<llvm::GlobalVariable>(std::get<0>(InArg))) {
DeferredReplacement.push_back(std::make_pair(Input, InputCopy));
continue;
}

ReplaceValue(Input, InputCopy, Func);
}

// Replace all of our deferred Input values, currently just Globals.
for (auto Deferred : DeferredReplacement)
ReplaceValue(std::get<0>(Deferred), std::get<1>(Deferred), Func);

// Restore insert point.
Builder.restoreIP(OldInsertPoint);

Expand Down
Loading
Loading