Skip to content

Commit beb8430

Browse files
committed
[Flang][MLIR][OpenMP] Improve use_device_* handling
This patch updates MLIR op verifiers for operations taking arguments that must always be defined by an `omp.map.info` operation to check this requirement. It also modifies Flang lowering for `use_device_{addr, ptr}`, as well as the custom MLIR printer and parser for these clauses, to support initializing it to `OMP_MAP_RETURN_PARAM` and represent this in the MLIR representation as `return_param`. This internal mapping flag is what eventually is used for variables passed via these clauses into the target region when translating to LLVM IR, so making it explicit in Flang and MLIR removes an inconsistency in the current representation.
1 parent 78d95cc commit beb8430

File tree

6 files changed

+55
-19
lines changed

6 files changed

+55
-19
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,8 +1298,7 @@ bool ClauseProcessor::processUseDeviceAddr(
12981298
const parser::CharBlock &source) {
12991299
mlir::Location location = converter.genLocation(source);
13001300
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1301-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1302-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1301+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
13031302
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
13041303
parentMemberIndices, result.useDeviceAddrVars,
13051304
useDeviceSyms);
@@ -1320,8 +1319,7 @@ bool ClauseProcessor::processUseDevicePtr(
13201319
const parser::CharBlock &source) {
13211320
mlir::Location location = converter.genLocation(source);
13221321
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1323-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1324-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1322+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
13251323
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
13261324
parentMemberIndices, result.useDevicePtrVars,
13271325
useDeviceSyms);

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,14 +398,16 @@ mlir::Value createParentSymAndGenIntermediateMaps(
398398
interimBounds, treatIndexAsSection);
399399
}
400400

401-
// Remove all map TO, FROM and TOFROM bits, from the intermediate
401+
// Remove all map TO, FROM and RETURN_PARAM bits, from the intermediate
402402
// allocatable maps, we simply wish to alloc or release them. It may be
403403
// safer to just pass OMP_MAP_NONE as the map type, but we may still
404404
// need some of the other map types the mapped member utilises, so for
405405
// now it's good to keep an eye on this.
406406
llvm::omp::OpenMPOffloadMappingFlags interimMapType = mapTypeBits;
407407
interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
408408
interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
409+
interimMapType &=
410+
~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
409411

410412
// Create a map for the intermediate member and insert it and it's
411413
// indices into the parentMemberIndices list to track it.

flang/test/Fir/convert-to-llvm-openmp-and-fir.fir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -423,14 +423,15 @@ func.func @_QPopenmp_target_data_region() {
423423

424424
func.func @_QPomp_target_data_empty() {
425425
%0 = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_data_emptyEa"}
426-
omp.target_data use_device_addr(%0 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
426+
%1 = omp.map.info var_ptr(%0 : !fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>) map_clauses(return_param) capture(ByRef) -> !fir.ref<!fir.array<1024xi32>> {name = ""}
427+
omp.target_data use_device_addr(%1 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
427428
omp.terminator
428429
}
429430
return
430431
}
431432

432433
// CHECK-LABEL: llvm.func @_QPomp_target_data_empty
433-
// CHECK: omp.target_data use_device_addr(%1 -> %{{.*}} : !llvm.ptr) {
434+
// CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}} : !llvm.ptr) {
434435
// CHECK: }
435436

436437
// -----

flang/test/Lower/OpenMP/target.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ subroutine omp_target_device_addr
544544
!CHECK: %[[VAL_0_DECL:.*]]:2 = hlfir.declare %[[VAL_0]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
545545
!CHECK: %[[MAP_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
546546
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(to) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
547-
!CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
547+
!CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) map_clauses(return_param) capture(ByRef) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
548548
!CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(to) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
549549
!CHECK: omp.target_data map_entries(%[[MAP]], %[[MAP_MEMBERS]] : {{.*}}) use_device_addr(%[[DEV_ADDR]] -> %[[ARG_0:.*]], %[[DEV_ADDR_MEMBERS]] -> %[[ARG_1:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.llvm_ptr<!fir.ref<i32>>) {
550550
!$omp target data map(tofrom: a) use_device_addr(a)

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1520,6 +1520,9 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
15201520
if (mapTypeMod == "delete")
15211521
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
15221522

1523+
if (mapTypeMod == "return_param")
1524+
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1525+
15231526
return success();
15241527
};
15251528

@@ -1582,6 +1585,12 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
15821585
emitAllocRelease = false;
15831586
mapTypeStrs.push_back("delete");
15841587
}
1588+
if (mapTypeToBitFlag(
1589+
mapTypeBits,
1590+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
1591+
emitAllocRelease = false;
1592+
mapTypeStrs.push_back("return_param");
1593+
}
15851594
if (emitAllocRelease)
15861595
mapTypeStrs.push_back("exit_release_or_enter_alloc");
15871596

@@ -1776,13 +1785,27 @@ static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
17761785
// MapInfoOp
17771786
//===----------------------------------------------------------------------===//
17781787

1788+
static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
1789+
StringRef clauseName,
1790+
OperandRange vars) {
1791+
for (Value var : vars)
1792+
if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
1793+
return op->emitOpError()
1794+
<< "'" << clauseName
1795+
<< "' arguments must be defined by 'omp.map.info' ops";
1796+
return success();
1797+
}
1798+
17791799
LogicalResult MapInfoOp::verify() {
17801800
if (getMapperId() &&
17811801
!SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
17821802
*this, getMapperIdAttr())) {
17831803
return emitError("invalid mapper id");
17841804
}
17851805

1806+
if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
1807+
return failure();
1808+
17861809
return success();
17871810
}
17881811

@@ -1804,6 +1827,15 @@ LogicalResult TargetDataOp::verify() {
18041827
"At least one of map, use_device_ptr_vars, or "
18051828
"use_device_addr_vars operand must be present");
18061829
}
1830+
1831+
if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
1832+
getUseDevicePtrVars())))
1833+
return failure();
1834+
1835+
if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
1836+
getUseDeviceAddrVars())))
1837+
return failure();
1838+
18071839
return verifyMapClause(*this, getMapVars());
18081840
}
18091841

@@ -1888,16 +1920,15 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
18881920
}
18891921

18901922
LogicalResult TargetOp::verify() {
1891-
LogicalResult verifyDependVars =
1892-
verifyDependVarList(*this, getDependKinds(), getDependVars());
1893-
1894-
if (failed(verifyDependVars))
1895-
return verifyDependVars;
1923+
if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
1924+
return failure();
18961925

1897-
LogicalResult verifyMapVars = verifyMapClause(*this, getMapVars());
1926+
if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
1927+
getHasDeviceAddrVars())))
1928+
return failure();
18981929

1899-
if (failed(verifyMapVars))
1900-
return verifyMapVars;
1930+
if (failed(verifyMapClause(*this, getMapVars())))
1931+
return failure();
19011932

19021933
return verifyPrivateVarsMapping(*this);
19031934
}

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -802,10 +802,14 @@ func.func @omp_target_data (%if_cond : i1, %device : si32, %device_ptr: memref<i
802802
%mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(always, from) capture(ByRef) -> memref<?xi32> {name = ""}
803803
omp.target_data if(%if_cond) device(%device : si32) map_entries(%mapv1 : memref<?xi32>){}
804804

805-
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
806-
// CHECK: omp.target_data map_entries(%[[MAP_A]] : memref<?xi32>) use_device_addr(%[[VAL_3:.*]] -> %{{.*}} : memref<?xi32>) use_device_ptr(%[[VAL_4:.*]] -> %{{.*}} : memref<i32>)
805+
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%{{.*}} : memref<?xi32>, tensor<?xi32>) map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
806+
// CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr(%{{.*}} : memref<?xi32>, tensor<?xi32>) map_clauses(return_param) capture(ByRef) -> memref<?xi32> {name = ""}
807+
// CHECK: %[[DEV_PTR:.*]] = omp.map.info var_ptr(%{{.*}} : memref<i32>, tensor<i32>) map_clauses(return_param) capture(ByRef) -> memref<i32> {name = ""}
808+
// CHECK: omp.target_data map_entries(%[[MAP_A]] : memref<?xi32>) use_device_addr(%[[DEV_ADDR]] -> %{{.*}} : memref<?xi32>) use_device_ptr(%[[DEV_PTR]] -> %{{.*}} : memref<i32>)
807809
%mapv2 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
808-
omp.target_data map_entries(%mapv2 : memref<?xi32>) use_device_addr(%device_addr -> %arg0 : memref<?xi32>) use_device_ptr(%device_ptr -> %arg1 : memref<i32>) {
810+
%device_addrv1 = omp.map.info var_ptr(%device_addr : memref<?xi32>, tensor<?xi32>) map_clauses(return_param) capture(ByRef) -> memref<?xi32> {name = ""}
811+
%device_ptrv1 = omp.map.info var_ptr(%device_ptr : memref<i32>, tensor<i32>) map_clauses(return_param) capture(ByRef) -> memref<i32> {name = ""}
812+
omp.target_data map_entries(%mapv2 : memref<?xi32>) use_device_addr(%device_addrv1 -> %arg0 : memref<?xi32>) use_device_ptr(%device_ptrv1 -> %arg1 : memref<i32>) {
809813
omp.terminator
810814
}
811815

0 commit comments

Comments
 (0)