Skip to content

Commit 82a0c88

Browse files
committed
[MLIR][OpenMP] Normalize handling of entry block arguments
This patch introduces a new MLIR interface for the OpenMP dialect aimed at providing a uniform way of verifying and handling entry block arguments defined by OpenMP clauses. The approach consists in defining a set of overrideable methods that return the number of block arguments the operation holds regarding each of the clauses that may define them. These by default return 0, but they are overriden by the corresponding clause through the `extraClassDeclaration` mechanism. Another set of interface methods to get the actual lists of block arguments is defined, which is implemented based on the previously described methods. These implicitly define a standardized ordering between the list of block arguments associated to each clause, based on the alphabetical ordering of their names. They should be the preferred way of matching operation arguments and entry block arguments to that operation's first region. Some updates are made to the printing/parsing of `omp.parallel` to follow the expected order between `private` and `reduction` clauses, as well as the MLIR to LLVM IR translation pass to access block arguments using the new interface. Unit tests of operations impacted by additional verification checks and sorting of entry block arguments.
1 parent 07e0b8a commit 82a0c88

File tree

11 files changed

+195
-76
lines changed

11 files changed

+195
-76
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -472,17 +472,26 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
472472
/// \param [in] infoAccessor - for a private variable, this returns the
473473
/// data we want to merge: type or location.
474474
/// \param [out] allRegionArgsInfo - the merged list of region info.
475+
/// \param [in] addBeforePrivate - `true` if the passed information goes before
476+
/// private information.
475477
template <typename OMPOp, typename InfoTy>
476478
static void
477479
mergePrivateVarsInfo(OMPOp op, llvm::ArrayRef<InfoTy> currentList,
478480
llvm::function_ref<InfoTy(mlir::Value)> infoAccessor,
479-
llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo) {
481+
llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo,
482+
bool addBeforePrivate) {
480483
mlir::OperandRange privateVars = op.getPrivateVars();
481484

482-
llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
483-
[](InfoTy i) { return i; });
485+
if (addBeforePrivate)
486+
llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
487+
[](InfoTy i) { return i; });
488+
484489
llvm::transform(privateVars, std::back_inserter(allRegionArgsInfo),
485490
infoAccessor);
491+
492+
if (!addBeforePrivate)
493+
llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
494+
[](InfoTy i) { return i; });
486495
}
487496

488497
//===----------------------------------------------------------------------===//
@@ -868,12 +877,12 @@ static void genBodyOfTargetOp(
868877
mergePrivateVarsInfo(targetOp, mapSymTypes,
869878
llvm::function_ref<mlir::Type(mlir::Value)>{
870879
[](mlir::Value v) { return v.getType(); }},
871-
allRegionArgTypes);
880+
allRegionArgTypes, /*addBeforePrivate=*/true);
872881

873882
mergePrivateVarsInfo(targetOp, mapSymLocs,
874883
llvm::function_ref<mlir::Location(mlir::Value)>{
875884
[](mlir::Value v) { return v.getLoc(); }},
876-
allRegionArgLocs);
885+
allRegionArgLocs, /*addBeforePrivate=*/true);
877886

878887
mlir::Block *regionBlock = firOpBuilder.createBlock(
879888
&region, {}, allRegionArgTypes, allRegionArgLocs);
@@ -1478,21 +1487,21 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
14781487
mergePrivateVarsInfo(parallelOp, reductionTypes,
14791488
llvm::function_ref<mlir::Type(mlir::Value)>{
14801489
[](mlir::Value v) { return v.getType(); }},
1481-
allRegionArgTypes);
1490+
allRegionArgTypes, /*addBeforePrivate=*/false);
14821491

14831492
llvm::SmallVector<mlir::Location> allRegionArgLocs;
14841493
mergePrivateVarsInfo(parallelOp, llvm::ArrayRef(reductionLocs),
14851494
llvm::function_ref<mlir::Location(mlir::Value)>{
14861495
[](mlir::Value v) { return v.getLoc(); }},
1487-
allRegionArgLocs);
1496+
allRegionArgLocs, /*addBeforePrivate=*/false);
14881497

14891498
mlir::Region &region = parallelOp.getRegion();
14901499
firOpBuilder.createBlock(&region, /*insertPt=*/{}, allRegionArgTypes,
14911500
allRegionArgLocs);
14921501

1493-
llvm::SmallVector<const semantics::Symbol *> allSymbols(reductionSyms);
1494-
allSymbols.append(dsp->getDelayedPrivSymbols().begin(),
1495-
dsp->getDelayedPrivSymbols().end());
1502+
llvm::SmallVector<const semantics::Symbol *> allSymbols(
1503+
dsp->getDelayedPrivSymbols());
1504+
allSymbols.append(reductionSyms.begin(), reductionSyms.end());
14961505

14971506
unsigned argIdx = 0;
14981507
for (const semantics::Symbol *arg : allSymbols) {

flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,5 @@ subroutine red_and_delayed_private
2626

2727
! CHECK-LABEL: _QPred_and_delayed_private
2828
! CHECK: omp.parallel
29-
! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
30-
! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
29+
! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
30+
! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {

flang/test/Lower/OpenMP/delayed-privatization-reduction.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@ subroutine red_and_delayed_private
2929

3030
! CHECK-LABEL: _QPred_and_delayed_private
3131
! CHECK: omp.parallel
32-
! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
33-
! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
32+
! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
33+
! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ class OpenMP_InReductionClauseSkip<
451451
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
452452
extraClassDeclaration> {
453453
let traits = [
454-
ReductionClauseInterface
454+
BlockArgOpenMPOpInterface, ReductionClauseInterface
455455
];
456456

457457
let arguments = (ins
@@ -472,6 +472,8 @@ class OpenMP_InReductionClauseSkip<
472472
return SmallVector<Value>(getInReductionVars().begin(),
473473
getInReductionVars().end());
474474
}
475+
476+
unsigned numInReductionBlockArgs() { return getInReductionVars().size(); }
475477
}];
476478

477479
// Description varies depending on the operation.
@@ -575,6 +577,8 @@ class OpenMP_MapClauseSkip<
575577
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
576578
extraClassDeclaration> {
577579
let traits = [
580+
// Not adding the BlockArgOpenMPOpInterface here because omp.target is the
581+
// only operation defining block arguments for `map` clauses.
578582
MapClauseOwningOpInterface
579583
];
580584

@@ -923,6 +927,10 @@ class OpenMP_PrivateClauseSkip<
923927
bit description = false, bit extraClassDeclaration = false
924928
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
925929
extraClassDeclaration> {
930+
let traits = [
931+
BlockArgOpenMPOpInterface
932+
];
933+
926934
let arguments = (ins
927935
Variadic<AnyType>:$private_vars,
928936
OptionalAttr<SymbolRefArrayAttr>:$private_syms
@@ -933,6 +941,10 @@ class OpenMP_PrivateClauseSkip<
933941
custom<PrivateList>($private_vars, type($private_vars), $private_syms) `)`
934942
}];
935943

944+
let extraClassDeclaration = [{
945+
unsigned numPrivateBlockArgs() { return getPrivateVars().size(); }
946+
}];
947+
936948
// TODO: Add description.
937949
}
938950

@@ -973,7 +985,7 @@ class OpenMP_ReductionClauseSkip<
973985
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
974986
extraClassDeclaration> {
975987
let traits = [
976-
ReductionClauseInterface
988+
BlockArgOpenMPOpInterface, ReductionClauseInterface
977989
];
978990

979991
let arguments = (ins
@@ -991,6 +1003,7 @@ class OpenMP_ReductionClauseSkip<
9911003
let extraClassDeclaration = [{
9921004
/// Returns the number of reduction variables.
9931005
unsigned getNumReductionVars() { return getReductionVars().size(); }
1006+
unsigned numReductionBlockArgs() { return getReductionVars().size(); }
9941007
}];
9951008

9961009
// Description varies depending on the operation.
@@ -1104,7 +1117,7 @@ class OpenMP_TaskReductionClauseSkip<
11041117
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
11051118
extraClassDeclaration> {
11061119
let traits = [
1107-
ReductionClauseInterface
1120+
BlockArgOpenMPOpInterface, ReductionClauseInterface
11081121
];
11091122

11101123
let arguments = (ins
@@ -1119,6 +1132,18 @@ class OpenMP_TaskReductionClauseSkip<
11191132
$task_reduction_byref, $task_reduction_syms) `)`
11201133
}];
11211134

1135+
let extraClassDeclaration = [{
1136+
/// Returns the reduction variables.
1137+
SmallVector<Value> getReductionVars() {
1138+
return SmallVector<Value>(getTaskReductionVars().begin(),
1139+
getTaskReductionVars().end());
1140+
}
1141+
1142+
unsigned numTaskReductionBlockArgs() {
1143+
return getTaskReductionVars().size();
1144+
}
1145+
}];
1146+
11221147
let description = [{
11231148
The `task_reduction` clause specifies a reduction among tasks. For each list
11241149
item, the number of copies is unspecified. Any copies associated with the
@@ -1130,14 +1155,6 @@ class OpenMP_TaskReductionClauseSkip<
11301155
attribute, and whether the reduction variable should be passed into the
11311156
reduction region by value or by reference in `task_reduction_byref`.
11321157
}];
1133-
1134-
let extraClassDeclaration = [{
1135-
/// Returns the reduction variables.
1136-
SmallVector<Value> getReductionVars() {
1137-
return SmallVector<Value>(getTaskReductionVars().begin(),
1138-
getTaskReductionVars().end());
1139-
}
1140-
}];
11411158
}
11421159

11431160
def OpenMP_TaskReductionClause : OpenMP_TaskReductionClauseSkip<>;

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1043,7 +1043,8 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [
10431043
//===----------------------------------------------------------------------===//
10441044

10451045
def TargetOp : OpenMP_Op<"target", traits = [
1046-
AttrSizedOperandSegments, IsolatedFromAbove, OutlineableOpenMPOpInterface
1046+
AttrSizedOperandSegments, BlockArgOpenMPOpInterface, IsolatedFromAbove,
1047+
OutlineableOpenMPOpInterface
10471048
], clauses = [
10481049
// TODO: Complete clause list (defaultmap, uses_allocators).
10491050
OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause,
@@ -1065,6 +1066,10 @@ def TargetOp : OpenMP_Op<"target", traits = [
10651066
OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
10661067
];
10671068

1069+
let extraClassDeclaration = [{
1070+
unsigned numMapBlockArgs() { return getMapVars().size(); }
1071+
}] # clausesExtraClassDeclaration;
1072+
10681073
let hasVerifier = 1;
10691074
}
10701075

mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,86 @@
1515

1616
include "mlir/IR/OpBase.td"
1717

18+
def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
19+
let description = [{
20+
OpenMP operations that define entry block arguments as part of the
21+
representation of its clauses.
22+
}];
23+
24+
let cppNamespace = "::mlir::omp";
25+
26+
let methods = [
27+
// Default-implemented methods to be overriden by the corresponding clauses.
28+
InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
29+
"unsigned", "numInReductionBlockArgs", (ins), [{}], [{
30+
return 0;
31+
}]>,
32+
InterfaceMethod<"Get number of block arguments defined by `map`.",
33+
"unsigned", "numMapBlockArgs", (ins), [{}], [{
34+
return 0;
35+
}]>,
36+
InterfaceMethod<"Get number of block arguments defined by `private`.",
37+
"unsigned", "numPrivateBlockArgs", (ins), [{}], [{
38+
return 0;
39+
}]>,
40+
InterfaceMethod<"Get number of block arguments defined by `reduction`.",
41+
"unsigned", "numReductionBlockArgs", (ins), [{}], [{
42+
return 0;
43+
}]>,
44+
InterfaceMethod<"Get number of block arguments defined by `task_reduction`.",
45+
"unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{
46+
return 0;
47+
}]>,
48+
49+
// Unified access methods for clause-associated entry block arguments.
50+
InterfaceMethod<"Get block arguments defined by `in_reduction`.",
51+
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
52+
"getInReductionBlockArgs", (ins), [{
53+
return $_op->getRegion(0).getArguments().take_front(
54+
$_op.numInReductionBlockArgs());
55+
}]>,
56+
InterfaceMethod<"Get block arguments defined by `map`.",
57+
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
58+
"getMapBlockArgs", (ins), [{
59+
return $_op->getRegion(0).getArguments().slice(
60+
$_op.numInReductionBlockArgs(), $_op.numMapBlockArgs());
61+
}]>,
62+
InterfaceMethod<"Get block arguments defined by `private`.",
63+
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
64+
"getPrivateBlockArgs", (ins), [{
65+
return $_op->getRegion(0).getArguments().slice(
66+
$_op.numInReductionBlockArgs() + $_op.numMapBlockArgs(),
67+
$_op.numPrivateBlockArgs());
68+
}]>,
69+
InterfaceMethod<"Get block arguments defined by `reduction`.",
70+
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
71+
"getReductionBlockArgs", (ins), [{
72+
return $_op->getRegion(0).getArguments().slice(
73+
$_op.numInReductionBlockArgs() + $_op.numMapBlockArgs() +
74+
$_op.numPrivateBlockArgs(), $_op.numReductionBlockArgs());
75+
}]>,
76+
InterfaceMethod<"Get block arguments defined by `task_reduction`.",
77+
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
78+
"getTaskReductionBlockArgs", (ins), [{
79+
return $_op->getRegion(0).getArguments().slice(
80+
$_op.numInReductionBlockArgs() + $_op.numMapBlockArgs() +
81+
$_op.numPrivateBlockArgs() + $_op.numReductionBlockArgs(),
82+
$_op.numTaskReductionBlockArgs());
83+
}]>,
84+
];
85+
86+
let verify = [{
87+
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
88+
unsigned expectedArgs = iface.numInReductionBlockArgs() +
89+
iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
90+
iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs();
91+
if ($_op->getRegion(0).getNumArguments() < expectedArgs)
92+
return $_op->emitOpError() << "expected at least " << expectedArgs
93+
<< " entry block argument(s)";
94+
return ::mlir::success();
95+
}];
96+
}
97+
1898
def OutlineableOpenMPOpInterface : OpInterface<"OutlineableOpenMPOpInterface"> {
1999
let description = [{
20100
OpenMP operations whose region will be outlined will implement this

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -536,13 +536,6 @@ static ParseResult parseParallelRegion(
536536
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
537537
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
538538

539-
if (succeeded(parser.parseOptionalKeyword("reduction"))) {
540-
if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
541-
reductionTypes, reductionByref,
542-
reductionSyms, regionPrivateArgs)))
543-
return failure();
544-
}
545-
546539
if (succeeded(parser.parseOptionalKeyword("private"))) {
547540
auto privateByref = DenseBoolArrayAttr::get(parser.getContext(), {});
548541
if (failed(parseClauseWithRegionArgs(parser, region, privateVars,
@@ -557,6 +550,13 @@ static ParseResult parseParallelRegion(
557550
}
558551
}
559552

553+
if (succeeded(parser.parseOptionalKeyword("reduction"))) {
554+
if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
555+
reductionTypes, reductionByref,
556+
reductionSyms, regionPrivateArgs)))
557+
return failure();
558+
}
559+
560560
return parser.parseRegion(region, regionPrivateArgs);
561561
}
562562

@@ -566,18 +566,9 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
566566
DenseBoolArrayAttr reductionByref,
567567
ArrayAttr reductionSyms, ValueRange privateVars,
568568
TypeRange privateTypes, ArrayAttr privateSyms) {
569-
if (reductionSyms) {
570-
auto *argsBegin = region.front().getArguments().begin();
571-
MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size());
572-
printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars,
573-
reductionTypes, reductionByref, reductionSyms);
574-
}
575-
576569
if (privateSyms) {
577570
auto *argsBegin = region.front().getArguments().begin();
578-
MutableArrayRef argsSubrange(argsBegin + reductionVars.size(),
579-
argsBegin + reductionVars.size() +
580-
privateTypes.size());
571+
MutableArrayRef argsSubrange(argsBegin, argsBegin + privateTypes.size());
581572
mlir::SmallVector<bool> isByRefVec;
582573
isByRefVec.resize(privateTypes.size(), false);
583574
DenseBoolArrayAttr isByRef =
@@ -587,6 +578,15 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
587578
privateTypes, isByRef, privateSyms);
588579
}
589580

581+
if (reductionSyms) {
582+
auto *argsBegin = region.front().getArguments().begin();
583+
MutableArrayRef argsSubrange(argsBegin + privateVars.size(),
584+
argsBegin + privateVars.size() +
585+
reductionTypes.size());
586+
printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars,
587+
reductionTypes, reductionByref, reductionSyms);
588+
}
589+
590590
p.printRegion(region, /*printEntryBlockArgs=*/false);
591591
}
592592

0 commit comments

Comments
 (0)