Skip to content

[MLIR][OpenMP] Add Lowering support for OpenMP Declare Mapper directive #117046

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 18, 2025
Merged
2 changes: 2 additions & 0 deletions flang/include/flang/Lower/AbstractConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ class AbstractConverter {
mangleName(const Fortran::semantics::DerivedTypeSpec &) = 0;
/// Unique a compiler generated name (add a containing scope specific prefix)
virtual std::string mangleName(std::string &) = 0;
/// Unique a compiler generated name (add a provided scope specific prefix)
virtual std::string mangleName(std::string &, const semantics::Scope &) = 0;
/// Return the field name for a derived type component inside a fir.record
/// type.
virtual std::string
Expand Down
5 changes: 5 additions & 0 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
return Fortran::lower::mangle::mangleName(name, getCurrentScope(),
scopeBlockIdMap);
}
std::string
mangleName(std::string &name,
const Fortran::semantics::Scope &myScope) override final {
return Fortran::lower::mangle::mangleName(name, myScope, scopeBlockIdMap);
}
std::string getRecordTypeFieldName(
const Fortran::semantics::Symbol &component) override final {
return Fortran::lower::mangle::getRecordTypeFieldName(component,
Expand Down
46 changes: 45 additions & 1 deletion flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3119,7 +3119,51 @@ static void
genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
const parser::OpenMPDeclareMapperConstruct &declareMapperConstruct) {
TODO(converter.getCurrentLocation(), "OpenMPDeclareMapperConstruct");
mlir::Location loc = converter.genLocation(declareMapperConstruct.source);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
lower::StatementContext stmtCtx;
const auto &spec =
std::get<parser::OmpMapperSpecifier>(declareMapperConstruct.t);
const auto &mapperName{std::get<std::optional<parser::Name>>(spec.t)};
const auto &varType{std::get<parser::TypeSpec>(spec.t)};
const auto &varName{std::get<parser::Name>(spec.t)};
assert(varType.declTypeSpec->category() ==
semantics::DeclTypeSpec::Category::TypeDerived &&
"Expected derived type");

std::string mapperNameStr;
if (mapperName.has_value()) {
mapperNameStr = mapperName->ToString();
mapperNameStr =
converter.mangleName(mapperNameStr, mapperName->symbol->owner());
} else {
mapperNameStr =
varType.declTypeSpec->derivedTypeSpec().name().ToString() + ".default";
mapperNameStr = converter.mangleName(
mapperNameStr, *varType.declTypeSpec->derivedTypeSpec().GetScope());
}

// Save current insertion point before moving to the module scope to create
// the DeclareMapperOp
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);

firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody());
auto mlirType = converter.genType(varType.declTypeSpec->derivedTypeSpec());
auto declMapperOp = firOpBuilder.create<mlir::omp::DeclareMapperOp>(
loc, mapperNameStr, mlirType);
auto &region = declMapperOp.getRegion();
firOpBuilder.createBlock(&region);
auto varVal = region.addArgument(firOpBuilder.getRefType(mlirType), loc);
converter.bindSymbol(*varName.symbol, varVal);

// Populate the declareMapper region with the map information.
mlir::omp::DeclareMapperInfoOperands clauseOps;
const auto *clauseList{
parser::Unwrap<parser::OmpClauseList>(declareMapperConstruct.t)};
List<Clause> clauses = makeClauses(*clauseList, semaCtx);
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processMap(loc, stmtCtx, clauseOps);
firOpBuilder.create<mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars);
}

static void
Expand Down
7 changes: 5 additions & 2 deletions flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,8 @@ class MapInfoFinalizationPass
for (auto *user : mapOp->getUsers()) {
if (llvm::isa<mlir::omp::TargetOp, mlir::omp::TargetDataOp,
mlir::omp::TargetUpdateOp, mlir::omp::TargetExitDataOp,
mlir::omp::TargetEnterDataOp>(user))
mlir::omp::TargetEnterDataOp,
mlir::omp::DeclareMapperInfoOp>(user))
return user;

if (auto mapUser = llvm::dyn_cast<mlir::omp::MapInfoOp>(user))
Expand Down Expand Up @@ -497,7 +498,9 @@ class MapInfoFinalizationPass
// ourselves to the possibility of race conditions while this pass
// undergoes frequent re-iteration for the near future. So we loop
// over function in the module and then map.info inside of those.
getOperation()->walk([&](mlir::func::FuncOp func) {
getOperation()->walk([&](mlir::Operation *func) {
if (!mlir::isa<mlir::func::FuncOp, mlir::omp::DeclareMapperOp>(func))
return;
// clear all local allocations we made for any boxes in any prior
// iterations from previous function scopes.
localBoxAllocas.clear();
Expand Down
47 changes: 0 additions & 47 deletions flang/test/Lower/OpenMP/Todo/omp-declare-mapper.f90

This file was deleted.

85 changes: 85 additions & 0 deletions flang/test/Lower/OpenMP/declare-mapper.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
! This test checks lowering of OpenMP declare mapper Directive.

! RUN: split-file %s %t
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-1.f90 -o - | FileCheck %t/omp-declare-mapper-1.f90
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-2.f90 -o - | FileCheck %t/omp-declare-mapper-2.f90

!--- omp-declare-mapper-1.f90
subroutine declare_mapper_1
integer, parameter :: nvals = 250
type my_type
integer :: num_vals
integer, allocatable :: values(:)
end type

type my_type2
type(my_type) :: my_type_var
type(my_type) :: temp
real, dimension(nvals) :: unmapped
real, dimension(nvals) :: arr
end type
type(my_type2) :: t
real :: x, y(nvals)
!CHECK:omp.declare_mapper @[[MY_TYPE_MAPPER:_QQFdeclare_mapper_1my_type\.default]] : [[MY_TYPE:!fir\.type<_QFdeclare_mapper_1Tmy_type\{num_vals:i32,values:!fir\.box<!fir\.heap<!fir\.array<\?xi32>>>\}>]] {
!CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<[[MY_TYPE]]>):
!CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFdeclare_mapper_1Evar"} : (!fir.ref<[[MY_TYPE]]>) -> (!fir.ref<[[MY_TYPE]]>, !fir.ref<[[MY_TYPE]]>)
!CHECK: %[[VAL_2:.*]] = hlfir.designate %[[VAL_1]]#0{"values"} {fortran_attrs = #fir.var_attrs<allocatable>} : (!fir.ref<[[MY_TYPE]]>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
!CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
!CHECK: %[[VAL_4:.*]] = fir.box_addr %[[VAL_3]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
!CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
!CHECK: %[[VAL_6:.*]]:3 = fir.box_dims %[[VAL_3]], %[[VAL_5]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
!CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
!CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
!CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
!CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_9]], %[[VAL_6]]#0 : index
!CHECK: %[[VAL_11:.*]] = hlfir.designate %[[VAL_1]]#0{"num_vals"} : (!fir.ref<[[MY_TYPE]]>) -> !fir.ref<i32>
!CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_11]] : !fir.ref<i32>
!CHECK: %[[VAL_13:.*]] = fir.convert %[[VAL_12]] : (i32) -> i64
!CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (i64) -> index
!CHECK: %[[VAL_15:.*]] = arith.subi %[[VAL_14]], %[[VAL_6]]#0 : index
!CHECK: %[[VAL_16:.*]] = omp.map.bounds lower_bound(%[[VAL_10]] : index) upper_bound(%[[VAL_15]] : index) extent(%[[VAL_6]]#1 : index) stride(%[[VAL_8]] : index) start_idx(%[[VAL_6]]#0 : index)
!CHECK: %[[VAL_17:.*]] = arith.constant 1 : index
!CHECK: %[[VAL_18:.*]] = fir.coordinate_of %[[VAL_1]]#0, %[[VAL_17]] : (!fir.ref<[[MY_TYPE]]>, index) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
!CHECK: %[[VAL_19:.*]] = fir.box_offset %[[VAL_18]] base_addr : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
!CHECK: %[[VAL_20:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, i32) var_ptr_ptr(%[[VAL_19]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[VAL_16]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
!CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.box<!fir.heap<!fir.array<?xi32>>>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {name = "var%[[VAL_22:.*]](1:var%[[VAL_23:.*]])"}
!CHECK: %[[VAL_24:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_21]], %[[VAL_20]] : [1], [1, 0] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<[[MY_TYPE]]> {name = "var"}
!CHECK: omp.declare_mapper.info map_entries(%[[VAL_24]], %[[VAL_21]], %[[VAL_20]] : !fir.ref<[[MY_TYPE]]>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>)
!CHECK: }
!$omp declare mapper (my_type :: var) map (var, var%values (1:var%num_vals))
end subroutine declare_mapper_1

!--- omp-declare-mapper-2.f90
subroutine declare_mapper_2
integer, parameter :: nvals = 250
type my_type
integer :: num_vals
integer, allocatable :: values(:)
end type

type my_type2
type(my_type) :: my_type_var
type(my_type) :: temp
real, dimension(nvals) :: unmapped
real, dimension(nvals) :: arr
end type
type(my_type2) :: t
real :: x, y(nvals)
!CHECK:omp.declare_mapper @[[MY_TYPE_MAPPER:_QQFdeclare_mapper_2my_mapper]] : [[MY_TYPE:!fir\.type<_QFdeclare_mapper_2Tmy_type2\{my_type_var:!fir\.type<_QFdeclare_mapper_2Tmy_type\{num_vals:i32,values:!fir\.box<!fir\.heap<!fir\.array<\?xi32>>>\}>,temp:!fir\.type<_QFdeclare_mapper_2Tmy_type\{num_vals:i32,values:!fir\.box<!fir\.heap<!fir\.array<\?xi32>>>\}>,unmapped:!fir\.array<250xf32>,arr:!fir\.array<250xf32>\}>]] {
!CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<[[MY_TYPE]]>):
!CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFdeclare_mapper_2Ev"} : (!fir.ref<[[MY_TYPE]]>) -> (!fir.ref<[[MY_TYPE]]>, !fir.ref<[[MY_TYPE]]>)
!CHECK: %[[VAL_2:.*]] = arith.constant 250 : index
!CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
!CHECK: %[[VAL_4:.*]] = hlfir.designate %[[VAL_1]]#0{"arr"} shape %[[VAL_3]] : (!fir.ref<[[MY_TYPE]]>, !fir.shape<1>) -> !fir.ref<!fir.array<250xf32>>
!CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
!CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
!CHECK: %[[VAL_7:.*]] = arith.subi %[[VAL_2]], %[[VAL_5]] : index
!CHECK: %[[VAL_8:.*]] = omp.map.bounds lower_bound(%[[VAL_6]] : index) upper_bound(%[[VAL_7]] : index) extent(%[[VAL_2]] : index) stride(%[[VAL_5]] : index) start_idx(%[[VAL_5]] : index)
!CHECK: %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_4]] : !fir.ref<!fir.array<250xf32>>, !fir.array<250xf32>) map_clauses(tofrom) capture(ByRef) bounds(%[[VAL_8]]) -> !fir.ref<!fir.array<250xf32>> {name = "v%[[VAL_10:.*]]"}
!CHECK: %[[VAL_11:.*]] = hlfir.designate %[[VAL_1]]#0{"temp"} : (!fir.ref<[[MY_TYPE]]>) -> !fir.ref<!fir.type<_QFdeclare_mapper_2Tmy_type{num_vals:i32,values:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>
!CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<!fir.type<_QFdeclare_mapper_2Tmy_type{num_vals:i32,values:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>, !fir.type<_QFdeclare_mapper_2Tmy_type{num_vals:i32,values:!fir.box<!fir.heap<!fir.array<?xi32>>>}>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<!fir.type<_QFdeclare_mapper_2Tmy_type{num_vals:i32,values:!fir.box<!fir.heap<!fir.array<?xi32>>>}>> {name = "v%[[VAL_13:.*]]"}
!CHECK: %[[VAL_14:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_9]], %[[VAL_12]] : [3], [1] : !fir.ref<!fir.array<250xf32>>, !fir.ref<!fir.type<_QFdeclare_mapper_2Tmy_type{num_vals:i32,values:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<[[MY_TYPE]]> {name = "v", partial_map = true}
!CHECK: omp.declare_mapper.info map_entries(%[[VAL_14]], %[[VAL_9]], %[[VAL_12]] : !fir.ref<[[MY_TYPE]]>, !fir.ref<!fir.array<250xf32>>, !fir.ref<!fir.type<_QFdeclare_mapper_2Tmy_type{num_vals:i32,values:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>)
!CHECK: }
!$omp declare mapper (my_mapper : my_type2 :: v) map (v%arr) map (alloc : v%temp)
end subroutine declare_mapper_2
1 change: 1 addition & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
#include "llvm/Support/Casting.h"
#include <cstddef>
#include <iterator>
#include <optional>
Expand Down
Loading