Skip to content

Commit d4e9ba5

Browse files
authored
[mlir][OpenMP] Standardise representation of reduction clause (#96215)
Now all operations with a reduction clause have an array of bools controlling whether each reduction variable should be passed by reference or value. This was already supported for Wsloop and Parallel. The new operations modified here currently have no flang lowering or translation to LLVMIR and so further changes are not needed. It isn't possible to check the verifier in mlir/test/Dialect/OpenMP/invalid.mlir because there is no way of parsing an operation to have an incorrect number of byref attributes. The verifier exists to pick up buggy operation builders or in-place operation modification.
1 parent 2a948d1 commit d4e9ba5

File tree

5 files changed

+137
-36
lines changed

5 files changed

+137
-36
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1027,7 +1027,8 @@ bool ClauseProcessor::processReduction(
10271027

10281028
// Copy local lists into the output.
10291029
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
1030-
llvm::copy(reduceVarByRef, std::back_inserter(result.reduceVarByRef));
1030+
llvm::copy(reduceVarByRef,
1031+
std::back_inserter(result.reductionVarsByRef));
10311032
llvm::copy(reductionDeclSymbols,
10321033
std::back_inserter(result.reductionDeclSymbols));
10331034

mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ struct IfClauseOps {
9494

9595
struct InReductionClauseOps {
9696
llvm::SmallVector<Value> inReductionVars;
97+
llvm::SmallVector<bool> inReductionVarsByRef;
9798
llvm::SmallVector<Attribute> inReductionDeclSymbols;
9899
};
99100

@@ -178,7 +179,7 @@ struct ProcBindClauseOps {
178179

179180
struct ReductionClauseOps {
180181
llvm::SmallVector<Value> reductionVars;
181-
llvm::SmallVector<bool> reduceVarByRef;
182+
llvm::SmallVector<bool> reductionVarsByRef;
182183
llvm::SmallVector<Attribute> reductionDeclSymbols;
183184
};
184185

@@ -199,6 +200,7 @@ struct SimdlenClauseOps {
199200

200201
struct TaskReductionClauseOps {
201202
llvm::SmallVector<Value> taskReductionVars;
203+
llvm::SmallVector<bool> taskReductionVarsByRef;
202204
llvm::SmallVector<Attribute> taskReductionDeclSymbols;
203205
};
204206

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ def TeamsOp : OpenMP_Op<"teams", [
250250
Variadic<AnyType>:$allocate_vars,
251251
Variadic<AnyType>:$allocators_vars,
252252
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
253+
OptionalAttr<DenseBoolArrayAttr>:$reduction_vars_byref,
253254
OptionalAttr<SymbolRefArrayAttr>:$reductions);
254255

255256
let regions = (region AnyRegion:$region);
@@ -266,8 +267,8 @@ def TeamsOp : OpenMP_Op<"teams", [
266267
| `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
267268
| `reduction` `(`
268269
custom<ReductionVarList>(
269-
$reduction_vars, type($reduction_vars), $reductions
270-
) `)`
270+
$reduction_vars, type($reduction_vars), $reduction_vars_byref,
271+
$reductions ) `)`
271272
| `allocate` `(`
272273
custom<AllocateAndAllocator>(
273274
$allocate_vars, type($allocate_vars),
@@ -310,7 +311,9 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
310311
by the accumulator it uses and accumulators must not be repeated in the same
311312
reduction. The reduction declaration specifies how to combine the values
312313
from each section into the final value, which is available in the
313-
accumulator after all the sections complete.
314+
accumulator after all the sections complete. True values in
315+
reduction_vars_byref indicate that the reduction variable should be passed
316+
by reference.
314317

315318
The $allocators_vars and $allocate_vars parameters are a variadic list of values
316319
that specify the memory allocator to be used to obtain storage for private values.
@@ -319,6 +322,7 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
319322
implicit barrier at the end of the construct.
320323
}];
321324
let arguments = (ins Variadic<OpenMP_PointerLikeType>:$reduction_vars,
325+
OptionalAttr<DenseBoolArrayAttr>:$reduction_vars_byref,
322326
OptionalAttr<SymbolRefArrayAttr>:$reductions,
323327
Variadic<AnyType>:$allocate_vars,
324328
Variadic<AnyType>:$allocators_vars,
@@ -333,7 +337,8 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
333337
let assemblyFormat = [{
334338
oilist( `reduction` `(`
335339
custom<ReductionVarList>(
336-
$reduction_vars, type($reduction_vars), $reductions
340+
$reduction_vars, type($reduction_vars), $reduction_vars_byref,
341+
$reductions
337342
) `)`
338343
| `allocate` `(`
339344
custom<AllocateAndAllocator>(
@@ -793,6 +798,8 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
793798

794799
The `in_reduction` clause specifies that this particular task (among all the
795800
tasks in current taskgroup, if any) participates in a reduction.
801+
`in_reduction_vars_byref` indicates whether each reduction variable should
802+
be passed by value or by reference.
796803

797804
The `priority` clause is a hint for the priority of the generated task.
798805
The `priority` is a non-negative integer expression that provides a hint for
@@ -818,6 +825,7 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
818825
UnitAttr:$untied,
819826
UnitAttr:$mergeable,
820827
Variadic<OpenMP_PointerLikeType>:$in_reduction_vars,
828+
OptionalAttr<DenseBoolArrayAttr>:$in_reduction_vars_byref,
821829
OptionalAttr<SymbolRefArrayAttr>:$in_reductions,
822830
Optional<I32>:$priority,
823831
OptionalAttr<TaskDependArrayAttr>:$depends,
@@ -835,7 +843,8 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
835843
|`mergeable` $mergeable
836844
|`in_reduction` `(`
837845
custom<ReductionVarList>(
838-
$in_reduction_vars, type($in_reduction_vars), $in_reductions
846+
$in_reduction_vars, type($in_reduction_vars),
847+
$in_reduction_vars_byref, $in_reductions
839848
) `)`
840849
|`priority` `(` $priority `)`
841850
|`allocate` `(`
@@ -962,8 +971,10 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
962971
UnitAttr:$untied,
963972
UnitAttr:$mergeable,
964973
Variadic<OpenMP_PointerLikeType>:$in_reduction_vars,
974+
OptionalAttr<DenseBoolArrayAttr>:$in_reduction_vars_byref,
965975
OptionalAttr<SymbolRefArrayAttr>:$in_reductions,
966976
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
977+
OptionalAttr<DenseBoolArrayAttr>:$reduction_vars_byref,
967978
OptionalAttr<SymbolRefArrayAttr>:$reductions,
968979
Optional<IntLikeType>:$priority,
969980
Variadic<AnyType>:$allocate_vars,
@@ -985,11 +996,13 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
985996
|`mergeable` $mergeable
986997
|`in_reduction` `(`
987998
custom<ReductionVarList>(
988-
$in_reduction_vars, type($in_reduction_vars), $in_reductions
999+
$in_reduction_vars, type($in_reduction_vars),
1000+
$in_reduction_vars_byref, $in_reductions
9891001
) `)`
9901002
|`reduction` `(`
9911003
custom<ReductionVarList>(
992-
$reduction_vars, type($reduction_vars), $reductions
1004+
$reduction_vars, type($reduction_vars), $reduction_vars_byref,
1005+
$reductions
9931006
) `)`
9941007
|`priority` `(` $priority `:` type($priority) `)`
9951008
|`allocate` `(`
@@ -1040,6 +1053,7 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", [AttrSizedOperandSegments,
10401053
}];
10411054

10421055
let arguments = (ins Variadic<OpenMP_PointerLikeType>:$task_reduction_vars,
1056+
OptionalAttr<DenseBoolArrayAttr>:$task_reduction_vars_byref,
10431057
OptionalAttr<SymbolRefArrayAttr>:$task_reductions,
10441058
Variadic<AnyType>:$allocate_vars,
10451059
Variadic<AnyType>:$allocators_vars);
@@ -1053,7 +1067,8 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", [AttrSizedOperandSegments,
10531067
let assemblyFormat = [{
10541068
oilist(`task_reduction` `(`
10551069
custom<ReductionVarList>(
1056-
$task_reduction_vars, type($task_reduction_vars), $task_reductions
1070+
$task_reduction_vars, type($task_reduction_vars),
1071+
$task_reduction_vars_byref, $task_reductions
10571072
) `)`
10581073
|`allocate` `(`
10591074
custom<AllocateAndAllocator>(

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ static ArrayAttr makeArrayAttr(MLIRContext *context,
4848
return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
4949
}
5050

51+
static DenseBoolArrayAttr
52+
makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) {
53+
return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
54+
}
55+
5156
namespace {
5257
struct MemRefPointerLikeModel
5358
: public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
@@ -499,7 +504,7 @@ static ParseResult parseClauseWithRegionArgs(
499504
return success();
500505
})))
501506
return failure();
502-
isByRef = DenseBoolArrayAttr::get(parser.getContext(), isByRefVec);
507+
isByRef = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
503508

504509
auto *argsBegin = regionPrivateArgs.begin();
505510
MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
@@ -591,7 +596,7 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
591596
mlir::SmallVector<bool> isByRefVec;
592597
isByRefVec.resize(privateVarTypes.size(), false);
593598
DenseBoolArrayAttr isByRef =
594-
DenseBoolArrayAttr::get(op->getContext(), isByRefVec);
599+
makeDenseBoolArrayAttr(op->getContext(), isByRefVec);
595600

596601
printClauseWithRegionArgs(p, op, argsSubrange, "private",
597602
privateVarOperands, privateVarTypes, isByRef,
@@ -607,18 +612,22 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
607612
static ParseResult
608613
parseReductionVarList(OpAsmParser &parser,
609614
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
610-
SmallVectorImpl<Type> &types,
615+
SmallVectorImpl<Type> &types, DenseBoolArrayAttr &isByRef,
611616
ArrayAttr &redcuctionSymbols) {
612617
SmallVector<SymbolRefAttr> reductionVec;
618+
SmallVector<bool> isByRefVec;
613619
if (failed(parser.parseCommaSeparatedList([&]() {
620+
ParseResult optionalByref = parser.parseOptionalKeyword("byref");
614621
if (parser.parseAttribute(reductionVec.emplace_back()) ||
615622
parser.parseArrow() ||
616623
parser.parseOperand(operands.emplace_back()) ||
617624
parser.parseColonType(types.emplace_back()))
618625
return failure();
626+
isByRefVec.push_back(optionalByref.succeeded());
619627
return success();
620628
})))
621629
return failure();
630+
isByRef = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
622631
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
623632
redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
624633
return success();
@@ -628,11 +637,21 @@ parseReductionVarList(OpAsmParser &parser,
628637
static void printReductionVarList(OpAsmPrinter &p, Operation *op,
629638
OperandRange reductionVars,
630639
TypeRange reductionTypes,
640+
std::optional<DenseBoolArrayAttr> isByRef,
631641
std::optional<ArrayAttr> reductions) {
632-
for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
642+
auto getByRef = [&](unsigned i) -> const char * {
643+
if (!isByRef || !*isByRef)
644+
return "";
645+
assert(isByRef->empty() || i < isByRef->size());
646+
if (!isByRef->empty() && (*isByRef)[i])
647+
return "byref ";
648+
return "";
649+
};
650+
651+
for (unsigned i = 0, e = reductionVars.size(); i < e; ++i) {
633652
if (i != 0)
634653
p << ", ";
635-
p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
654+
p << getByRef(i) << (*reductions)[i] << " -> " << reductionVars[i] << " : "
636655
<< reductionVars[i].getType();
637656
}
638657
}
@@ -641,16 +660,12 @@ static void printReductionVarList(OpAsmPrinter &p, Operation *op,
641660
static LogicalResult
642661
verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductions,
643662
OperandRange reductionVars,
644-
std::optional<ArrayRef<bool>> byRef = std::nullopt) {
663+
std::optional<ArrayRef<bool>> byRef) {
645664
if (!reductionVars.empty()) {
646665
if (!reductions || reductions->size() != reductionVars.size())
647666
return op->emitOpError()
648667
<< "expected as many reduction symbol references "
649668
"as reduction variables";
650-
if (mlir::isa<omp::WsloopOp, omp::ParallelOp>(op))
651-
assert(byRef);
652-
else
653-
assert(!byRef); // TODO: support byref reductions on other operations
654669
if (byRef && byRef->size() != reductionVars.size())
655670
return op->emitError() << "expected as many reduction variable by "
656671
"reference attributes as reduction variables";
@@ -1492,7 +1507,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
14921507
ParallelOp::build(builder, state, clauses.ifVar, clauses.numThreadsVar,
14931508
clauses.allocateVars, clauses.allocatorVars,
14941509
clauses.reductionVars,
1495-
DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef),
1510+
makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
14961511
makeArrayAttr(ctx, clauses.reductionDeclSymbols),
14971512
clauses.procBindKindAttr, clauses.privateVars,
14981513
makeArrayAttr(ctx, clauses.privatizers));
@@ -1590,6 +1605,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
15901605
clauses.numTeamsUpperVar, clauses.ifVar,
15911606
clauses.threadLimitVar, clauses.allocateVars,
15921607
clauses.allocatorVars, clauses.reductionVars,
1608+
makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
15931609
makeArrayAttr(ctx, clauses.reductionDeclSymbols));
15941610
}
15951611

@@ -1621,7 +1637,8 @@ LogicalResult TeamsOp::verify() {
16211637
return emitError(
16221638
"expected equal sizes for allocate and allocator variables");
16231639

1624-
return verifyReductionVarList(*this, getReductions(), getReductionVars());
1640+
return verifyReductionVarList(*this, getReductions(), getReductionVars(),
1641+
getReductionVarsByref());
16251642
}
16261643

16271644
//===----------------------------------------------------------------------===//
@@ -1633,6 +1650,7 @@ void SectionsOp::build(OpBuilder &builder, OperationState &state,
16331650
MLIRContext *ctx = builder.getContext();
16341651
// TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
16351652
SectionsOp::build(builder, state, clauses.reductionVars,
1653+
makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
16361654
makeArrayAttr(ctx, clauses.reductionDeclSymbols),
16371655
clauses.allocateVars, clauses.allocatorVars,
16381656
clauses.nowaitAttr);
@@ -1643,7 +1661,8 @@ LogicalResult SectionsOp::verify() {
16431661
return emitError(
16441662
"expected equal sizes for allocate and allocator variables");
16451663

1646-
return verifyReductionVarList(*this, getReductions(), getReductionVars());
1664+
return verifyReductionVarList(*this, getReductions(), getReductionVars(),
1665+
getReductionVarsByref());
16471666
}
16481667

16491668
LogicalResult SectionsOp::verifyRegions() {
@@ -1733,7 +1752,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
17331752
// privatizers.
17341753
WsloopOp::build(builder, state, clauses.linearVars, clauses.linearStepVars,
17351754
clauses.reductionVars,
1736-
DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef),
1755+
makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
17371756
makeArrayAttr(ctx, clauses.reductionDeclSymbols),
17381757
clauses.scheduleValAttr, clauses.scheduleChunkVar,
17391758
clauses.scheduleModAttr, clauses.scheduleSimdAttr,
@@ -1934,6 +1953,7 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
19341953
TaskOp::build(
19351954
builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
19361955
clauses.mergeableAttr, clauses.inReductionVars,
1956+
makeDenseBoolArrayAttr(ctx, clauses.inReductionVarsByRef),
19371957
makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.priorityVar,
19381958
makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
19391959
clauses.allocateVars, clauses.allocatorVars);
@@ -1945,7 +1965,8 @@ LogicalResult TaskOp::verify() {
19451965
return failed(verifyDependVars)
19461966
? verifyDependVars
19471967
: verifyReductionVarList(*this, getInReductions(),
1948-
getInReductionVars());
1968+
getInReductionVars(),
1969+
getInReductionVarsByref());
19491970
}
19501971

19511972
//===----------------------------------------------------------------------===//
@@ -1955,14 +1976,17 @@ LogicalResult TaskOp::verify() {
19551976
void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
19561977
const TaskgroupClauseOps &clauses) {
19571978
MLIRContext *ctx = builder.getContext();
1958-
TaskgroupOp::build(builder, state, clauses.taskReductionVars,
1959-
makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
1960-
clauses.allocateVars, clauses.allocatorVars);
1979+
TaskgroupOp::build(
1980+
builder, state, clauses.taskReductionVars,
1981+
makeDenseBoolArrayAttr(ctx, clauses.taskReductionVarsByRef),
1982+
makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
1983+
clauses.allocateVars, clauses.allocatorVars);
19611984
}
19621985

19631986
LogicalResult TaskgroupOp::verify() {
19641987
return verifyReductionVarList(*this, getTaskReductions(),
1965-
getTaskReductionVars());
1988+
getTaskReductionVars(),
1989+
getTaskReductionVarsByref());
19661990
}
19671991

19681992
//===----------------------------------------------------------------------===//
@@ -1976,7 +2000,9 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
19762000
TaskloopOp::build(
19772001
builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
19782002
clauses.mergeableAttr, clauses.inReductionVars,
2003+
makeDenseBoolArrayAttr(ctx, clauses.inReductionVarsByRef),
19792004
makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars,
2005+
makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
19802006
makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar,
19812007
clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar,
19822008
clauses.numTasksVar, clauses.nogroupAttr);
@@ -1994,10 +2020,11 @@ LogicalResult TaskloopOp::verify() {
19942020
if (getAllocateVars().size() != getAllocatorsVars().size())
19952021
return emitError(
19962022
"expected equal sizes for allocate and allocator variables");
1997-
if (failed(
1998-
verifyReductionVarList(*this, getReductions(), getReductionVars())) ||
2023+
if (failed(verifyReductionVarList(*this, getReductions(), getReductionVars(),
2024+
getReductionVarsByref())) ||
19992025
failed(verifyReductionVarList(*this, getInReductions(),
2000-
getInReductionVars())))
2026+
getInReductionVars(),
2027+
getInReductionVarsByref())))
20012028
return failure();
20022029

20032030
if (!getReductionVars().empty() && getNogroup())

0 commit comments

Comments
 (0)