Skip to content

Commit e71369f

Browse files
authored
[mlir][OpenMP] Add copyprivate support to omp.single (#80477)
This adds a new custom CopyPrivateVarList to the single operation. Each list item is formed by a reference to the variable to be updated, its type and the function to be used to perform the copy. It will be translated to LLVM IR using OpenMP builder, that will use the information in the copyprivate list to call __kmpc_copyprivate. This is patch 2 of 4, to add support for COPYPRIVATE in Flang. Original PR: #73128
1 parent 75f0d40 commit e71369f

File tree

5 files changed

+193
-3
lines changed

5 files changed

+193
-3
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2665,6 +2665,7 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
26652665
const Fortran::parser::OmpClauseList &beginClauseList,
26662666
const Fortran::parser::OmpClauseList &endClauseList) {
26672667
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
2668+
llvm::SmallVector<mlir::Value> copyPrivateVars;
26682669
mlir::UnitAttr nowaitAttr;
26692670

26702671
ClauseProcessor cp(converter, semaCtx, beginClauseList);
@@ -2678,7 +2679,8 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
26782679
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
26792680
.setGenNested(genNested)
26802681
.setClauses(&beginClauseList),
2681-
allocateOperands, allocatorOperands, nowaitAttr);
2682+
allocateOperands, allocatorOperands, copyPrivateVars,
2683+
/*copyPrivateFuncs=*/nullptr, nowaitAttr);
26822684
}
26832685

26842686
static mlir::omp::TaskOp

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,10 +466,16 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
466466
master thread), in the context of its implicit task. The other threads
467467
in the team, which do not execute the block, wait at an implicit barrier
468468
at the end of the single construct unless a nowait clause is specified.
469+
470+
If copyprivate variables and functions are specified, then each thread
471+
variable is updated with the variable value of the thread that executed
472+
the single region, using the specified copy functions.
469473
}];
470474

471475
let arguments = (ins Variadic<AnyType>:$allocate_vars,
472476
Variadic<AnyType>:$allocators_vars,
477+
Variadic<OpenMP_PointerLikeType>:$copyprivate_vars,
478+
OptionalAttr<SymbolRefArrayAttr>:$copyprivate_funcs,
473479
UnitAttr:$nowait);
474480

475481
let regions = (region AnyRegion:$region);
@@ -481,6 +487,10 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
481487
$allocators_vars, type($allocators_vars)
482488
) `)`
483489
|`nowait` $nowait
490+
|`copyprivate` `(`
491+
custom<CopyPrivateVarList>(
492+
$copyprivate_vars, type($copyprivate_vars), $copyprivate_funcs
493+
) `)`
484494
) $region attr-dict
485495
}];
486496
let hasVerifier = 1;

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,110 @@ static LogicalResult verifyReductionVarList(Operation *op,
573573
return success();
574574
}
575575

576+
//===----------------------------------------------------------------------===//
577+
// Parser, printer and verifier for CopyPrivateVarList
578+
//===----------------------------------------------------------------------===//
579+
580+
/// copyprivate-entry-list ::= copyprivate-entry
581+
/// | copyprivate-entry-list `,` copyprivate-entry
582+
/// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
583+
static ParseResult parseCopyPrivateVarList(
584+
OpAsmParser &parser,
585+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
586+
SmallVectorImpl<Type> &types, ArrayAttr &copyPrivateSymbols) {
587+
SmallVector<SymbolRefAttr> copyPrivateFuncsVec;
588+
if (failed(parser.parseCommaSeparatedList([&]() {
589+
if (parser.parseOperand(operands.emplace_back()) ||
590+
parser.parseArrow() ||
591+
parser.parseAttribute(copyPrivateFuncsVec.emplace_back()) ||
592+
parser.parseColonType(types.emplace_back()))
593+
return failure();
594+
return success();
595+
})))
596+
return failure();
597+
SmallVector<Attribute> copyPrivateFuncs(copyPrivateFuncsVec.begin(),
598+
copyPrivateFuncsVec.end());
599+
copyPrivateSymbols = ArrayAttr::get(parser.getContext(), copyPrivateFuncs);
600+
return success();
601+
}
602+
603+
/// Print CopyPrivate clause
604+
static void printCopyPrivateVarList(OpAsmPrinter &p, Operation *op,
605+
OperandRange copyPrivateVars,
606+
TypeRange copyPrivateTypes,
607+
std::optional<ArrayAttr> copyPrivateFuncs) {
608+
if (!copyPrivateFuncs.has_value())
609+
return;
610+
llvm::interleaveComma(
611+
llvm::zip(copyPrivateVars, *copyPrivateFuncs, copyPrivateTypes), p,
612+
[&](const auto &args) {
613+
p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
614+
<< std::get<2>(args);
615+
});
616+
}
617+
618+
/// Verifies CopyPrivate Clause
619+
static LogicalResult
620+
verifyCopyPrivateVarList(Operation *op, OperandRange copyPrivateVars,
621+
std::optional<ArrayAttr> copyPrivateFuncs) {
622+
size_t copyPrivateFuncsSize =
623+
copyPrivateFuncs.has_value() ? copyPrivateFuncs->size() : 0;
624+
if (copyPrivateFuncsSize != copyPrivateVars.size())
625+
return op->emitOpError() << "inconsistent number of copyPrivate vars (= "
626+
<< copyPrivateVars.size()
627+
<< ") and functions (= " << copyPrivateFuncsSize
628+
<< "), both must be equal";
629+
if (!copyPrivateFuncs.has_value())
630+
return success();
631+
632+
for (auto copyPrivateVarAndFunc :
633+
llvm::zip(copyPrivateVars, *copyPrivateFuncs)) {
634+
auto symbolRef =
635+
llvm::cast<SymbolRefAttr>(std::get<1>(copyPrivateVarAndFunc));
636+
std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
637+
funcOp;
638+
if (mlir::func::FuncOp mlirFuncOp =
639+
SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
640+
symbolRef))
641+
funcOp = mlirFuncOp;
642+
else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
643+
SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
644+
op, symbolRef))
645+
funcOp = llvmFuncOp;
646+
647+
auto getNumArguments = [&] {
648+
return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
649+
};
650+
651+
auto getArgumentType = [&](unsigned i) {
652+
return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
653+
*funcOp);
654+
};
655+
656+
if (!funcOp)
657+
return op->emitOpError() << "expected symbol reference " << symbolRef
658+
<< " to point to a copy function";
659+
660+
if (getNumArguments() != 2)
661+
return op->emitOpError()
662+
<< "expected copy function " << symbolRef << " to have 2 operands";
663+
664+
Type argTy = getArgumentType(0);
665+
if (argTy != getArgumentType(1))
666+
return op->emitOpError() << "expected copy function " << symbolRef
667+
<< " arguments to have the same type";
668+
669+
Type varType = std::get<0>(copyPrivateVarAndFunc).getType();
670+
if (argTy != varType)
671+
return op->emitOpError()
672+
<< "expected copy function arguments' type (" << argTy
673+
<< ") to be the same as copyprivate variable's type (" << varType
674+
<< ")";
675+
}
676+
677+
return success();
678+
}
679+
576680
//===----------------------------------------------------------------------===//
577681
// Parser, printer and verifier for DependVarList
578682
//===----------------------------------------------------------------------===//
@@ -1152,7 +1256,8 @@ LogicalResult SingleOp::verify() {
11521256
return emitError(
11531257
"expected equal sizes for allocate and allocator variables");
11541258

1155-
return success();
1259+
return verifyCopyPrivateVarList(*this, getCopyprivateVars(),
1260+
getCopyprivateFuncs());
11561261
}
11571262

11581263
//===----------------------------------------------------------------------===//

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1255,7 +1255,63 @@ func.func @omp_single(%data_var : memref<i32>) -> () {
12551255
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
12561256
"omp.single" (%data_var) ({
12571257
omp.barrier
1258-
}) {operandSegmentSizes = array<i32: 1,0>} : (memref<i32>) -> ()
1258+
}) {operandSegmentSizes = array<i32: 1,0,0>} : (memref<i32>) -> ()
1259+
return
1260+
}
1261+
1262+
// -----
1263+
1264+
func.func @omp_single_copyprivate(%data_var : memref<i32>) -> () {
1265+
// expected-error @below {{inconsistent number of copyPrivate vars (= 1) and functions (= 0), both must be equal}}
1266+
"omp.single" (%data_var) ({
1267+
omp.barrier
1268+
}) {operandSegmentSizes = array<i32: 0,0,1>} : (memref<i32>) -> ()
1269+
return
1270+
}
1271+
1272+
// -----
1273+
1274+
func.func @omp_single_copyprivate(%data_var : memref<i32>) -> () {
1275+
// expected-error @below {{expected symbol reference @copy_func to point to a copy function}}
1276+
omp.single copyprivate(%data_var -> @copy_func : memref<i32>) {
1277+
omp.barrier
1278+
}
1279+
return
1280+
}
1281+
1282+
// -----
1283+
1284+
func.func private @copy_func(memref<i32>)
1285+
1286+
func.func @omp_single_copyprivate(%data_var : memref<i32>) -> () {
1287+
// expected-error @below {{expected copy function @copy_func to have 2 operands}}
1288+
omp.single copyprivate(%data_var -> @copy_func : memref<i32>) {
1289+
omp.barrier
1290+
}
1291+
return
1292+
}
1293+
1294+
// -----
1295+
1296+
func.func private @copy_func(memref<i32>, memref<f32>)
1297+
1298+
func.func @omp_single_copyprivate(%data_var : memref<i32>) -> () {
1299+
// expected-error @below {{expected copy function @copy_func arguments to have the same type}}
1300+
omp.single copyprivate(%data_var -> @copy_func : memref<i32>) {
1301+
omp.barrier
1302+
}
1303+
return
1304+
}
1305+
1306+
// -----
1307+
1308+
func.func private @copy_func(memref<f32>, memref<f32>)
1309+
1310+
func.func @omp_single_copyprivate(%data_var : memref<i32>) -> () {
1311+
// expected-error @below {{expected copy function arguments' type ('memref<f32>') to be the same as copyprivate variable's type ('memref<i32>')}}
1312+
omp.single copyprivate(%data_var -> @copy_func : memref<i32>) {
1313+
omp.barrier
1314+
}
12591315
return
12601316
}
12611317

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1622,6 +1622,23 @@ func.func @omp_single_multiple_blocks() {
16221622
return
16231623
}
16241624

1625+
func.func private @copy_i32(memref<i32>, memref<i32>)
1626+
1627+
// CHECK-LABEL: func @omp_single_copyprivate
1628+
func.func @omp_single_copyprivate(%data_var: memref<i32>) {
1629+
omp.parallel {
1630+
// CHECK: omp.single copyprivate(%{{.*}} -> @copy_i32 : memref<i32>) {
1631+
omp.single copyprivate(%data_var -> @copy_i32 : memref<i32>) {
1632+
"test.payload"() : () -> ()
1633+
// CHECK: omp.terminator
1634+
omp.terminator
1635+
}
1636+
// CHECK: omp.terminator
1637+
omp.terminator
1638+
}
1639+
return
1640+
}
1641+
16251642
// CHECK-LABEL: @omp_task
16261643
// CHECK-SAME: (%[[bool_var:.*]]: i1, %[[i64_var:.*]]: i64, %[[i32_var:.*]]: i32, %[[data_var:.*]]: memref<i32>)
16271644
func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memref<i32>) {

0 commit comments

Comments
 (0)