Skip to content

[mlir][OpenMP] Standardise representation of reduction clause #96215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,8 @@ bool ClauseProcessor::processReduction(

// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reduceVarByRef, std::back_inserter(result.reduceVarByRef));
llvm::copy(reduceVarByRef,
std::back_inserter(result.reductionVarsByRef));
llvm::copy(reductionDeclSymbols,
std::back_inserter(result.reductionDeclSymbols));

Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ struct IfClauseOps {

struct InReductionClauseOps {
llvm::SmallVector<Value> inReductionVars;
llvm::SmallVector<bool> inReductionVarsByRef;
llvm::SmallVector<Attribute> inReductionDeclSymbols;
};

Expand Down Expand Up @@ -177,7 +178,7 @@ struct ProcBindClauseOps {

struct ReductionClauseOps {
llvm::SmallVector<Value> reductionVars;
llvm::SmallVector<bool> reduceVarByRef;
llvm::SmallVector<bool> reductionVarsByRef;
llvm::SmallVector<Attribute> reductionDeclSymbols;
};

Expand All @@ -198,6 +199,7 @@ struct SimdlenClauseOps {

struct TaskReductionClauseOps {
llvm::SmallVector<Value> taskReductionVars;
llvm::SmallVector<bool> taskReductionVarsByRef;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally prefer this name "template" to what's done for in-reduction and reduction. Would you agree about making the change and renaming InReductionClauseOps::inReductionVarsByRef and ReductionClauseOps::reductionVarsByRef?

llvm::SmallVector<Attribute> taskReductionDeclSymbols;
};

Expand Down
31 changes: 23 additions & 8 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def TeamsOp : OpenMP_Op<"teams", [
Variadic<AnyType>:$allocate_vars,
Variadic<AnyType>:$allocators_vars,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$reduction_vars_byref,
OptionalAttr<SymbolRefArrayAttr>:$reductions);

let regions = (region AnyRegion:$region);
Expand All @@ -262,8 +263,8 @@ def TeamsOp : OpenMP_Op<"teams", [
| `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
| `reduction` `(`
custom<ReductionVarList>(
$reduction_vars, type($reduction_vars), $reductions
) `)`
$reduction_vars, type($reduction_vars), $reduction_vars_byref,
$reductions ) `)`
| `allocate` `(`
custom<AllocateAndAllocator>(
$allocate_vars, type($allocate_vars),
Expand Down Expand Up @@ -306,7 +307,9 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
by the accumulator it uses and accumulators must not be repeated in the same
reduction. The reduction declaration specifies how to combine the values
from each section into the final value, which is available in the
accumulator after all the sections complete.
accumulator after all the sections complete. True values in
reduction_vars_byref indicate that the reduction variable should be passed
by reference.

The $allocators_vars and $allocate_vars parameters are a variadic list of values
that specify the memory allocator to be used to obtain storage for private values.
Expand All @@ -315,6 +318,7 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
implicit barrier at the end of the construct.
}];
let arguments = (ins Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$reduction_vars_byref,
OptionalAttr<SymbolRefArrayAttr>:$reductions,
Variadic<AnyType>:$allocate_vars,
Variadic<AnyType>:$allocators_vars,
Expand All @@ -329,7 +333,8 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
let assemblyFormat = [{
oilist( `reduction` `(`
custom<ReductionVarList>(
$reduction_vars, type($reduction_vars), $reductions
$reduction_vars, type($reduction_vars), $reduction_vars_byref,
$reductions
) `)`
| `allocate` `(`
custom<AllocateAndAllocator>(
Expand Down Expand Up @@ -786,6 +791,8 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,

The `in_reduction` clause specifies that this particular task (among all the
tasks in current taskgroup, if any) participates in a reduction.
`in_reduction_vars_byref` indicates whether each reduction variable should
be passed by value or by reference.

The `priority` clause is a hint for the priority of the generated task.
The `priority` is a non-negative integer expression that provides a hint for
Expand All @@ -811,6 +818,7 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
UnitAttr:$untied,
UnitAttr:$mergeable,
Variadic<OpenMP_PointerLikeType>:$in_reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$in_reduction_vars_byref,
OptionalAttr<SymbolRefArrayAttr>:$in_reductions,
Optional<I32>:$priority,
OptionalAttr<TaskDependArrayAttr>:$depends,
Expand All @@ -828,7 +836,8 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
|`mergeable` $mergeable
|`in_reduction` `(`
custom<ReductionVarList>(
$in_reduction_vars, type($in_reduction_vars), $in_reductions
$in_reduction_vars, type($in_reduction_vars),
$in_reduction_vars_byref, $in_reductions
) `)`
|`priority` `(` $priority `)`
|`allocate` `(`
Expand Down Expand Up @@ -955,8 +964,10 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
UnitAttr:$untied,
UnitAttr:$mergeable,
Variadic<OpenMP_PointerLikeType>:$in_reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$in_reduction_vars_byref,
OptionalAttr<SymbolRefArrayAttr>:$in_reductions,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$reduction_vars_byref,
OptionalAttr<SymbolRefArrayAttr>:$reductions,
Optional<IntLikeType>:$priority,
Variadic<AnyType>:$allocate_vars,
Expand All @@ -978,11 +989,13 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
|`mergeable` $mergeable
|`in_reduction` `(`
custom<ReductionVarList>(
$in_reduction_vars, type($in_reduction_vars), $in_reductions
$in_reduction_vars, type($in_reduction_vars),
$in_reduction_vars_byref, $in_reductions
) `)`
|`reduction` `(`
custom<ReductionVarList>(
$reduction_vars, type($reduction_vars), $reductions
$reduction_vars, type($reduction_vars), $reduction_vars_byref,
$reductions
) `)`
|`priority` `(` $priority `:` type($priority) `)`
|`allocate` `(`
Expand Down Expand Up @@ -1033,6 +1046,7 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", [AttrSizedOperandSegments,
}];

let arguments = (ins Variadic<OpenMP_PointerLikeType>:$task_reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$task_reduction_vars_byref,
OptionalAttr<SymbolRefArrayAttr>:$task_reductions,
Variadic<AnyType>:$allocate_vars,
Variadic<AnyType>:$allocators_vars);
Expand All @@ -1046,7 +1060,8 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", [AttrSizedOperandSegments,
let assemblyFormat = [{
oilist(`task_reduction` `(`
custom<ReductionVarList>(
$task_reduction_vars, type($task_reduction_vars), $task_reductions
$task_reduction_vars, type($task_reduction_vars),
$task_reduction_vars_byref, $task_reductions
) `)`
|`allocate` `(`
custom<AllocateAndAllocator>(
Expand Down
71 changes: 49 additions & 22 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ static ArrayAttr makeArrayAttr(MLIRContext *context,
return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
}

static DenseBoolArrayAttr
makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) {
return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
}

namespace {
struct MemRefPointerLikeModel
: public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
Expand Down Expand Up @@ -460,7 +465,7 @@ static ParseResult parseClauseWithRegionArgs(
return success();
})))
return failure();
isByRef = DenseBoolArrayAttr::get(parser.getContext(), isByRefVec);
isByRef = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);

auto *argsBegin = regionPrivateArgs.begin();
MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
Expand Down Expand Up @@ -552,7 +557,7 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
mlir::SmallVector<bool> isByRefVec;
isByRefVec.resize(privateVarTypes.size(), false);
DenseBoolArrayAttr isByRef =
DenseBoolArrayAttr::get(op->getContext(), isByRefVec);
makeDenseBoolArrayAttr(op->getContext(), isByRefVec);

printClauseWithRegionArgs(p, op, argsSubrange, "private",
privateVarOperands, privateVarTypes, isByRef,
Expand All @@ -568,18 +573,22 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
static ParseResult
parseReductionVarList(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types,
SmallVectorImpl<Type> &types, DenseBoolArrayAttr &isByRef,
ArrayAttr &redcuctionSymbols) {
SmallVector<SymbolRefAttr> reductionVec;
SmallVector<bool> isByRefVec;
if (failed(parser.parseCommaSeparatedList([&]() {
ParseResult optionalByref = parser.parseOptionalKeyword("byref");
if (parser.parseAttribute(reductionVec.emplace_back()) ||
parser.parseArrow() ||
parser.parseOperand(operands.emplace_back()) ||
parser.parseColonType(types.emplace_back()))
return failure();
isByRefVec.push_back(optionalByref.succeeded());
return success();
})))
return failure();
isByRef = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
return success();
Expand All @@ -589,11 +598,21 @@ parseReductionVarList(OpAsmParser &parser,
static void printReductionVarList(OpAsmPrinter &p, Operation *op,
OperandRange reductionVars,
TypeRange reductionTypes,
std::optional<DenseBoolArrayAttr> isByRef,
std::optional<ArrayAttr> reductions) {
for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
auto getByRef = [&](unsigned i) -> const char * {
if (!isByRef || !*isByRef)
return "";
assert(isByRef->empty() || i < isByRef->size());
if (!isByRef->empty() && (*isByRef)[i])
return "byref ";
return "";
};

for (unsigned i = 0, e = reductionVars.size(); i < e; ++i) {
if (i != 0)
p << ", ";
p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
p << getByRef(i) << (*reductions)[i] << " -> " << reductionVars[i] << " : "
<< reductionVars[i].getType();
}
}
Expand All @@ -602,16 +621,12 @@ static void printReductionVarList(OpAsmPrinter &p, Operation *op,
static LogicalResult
verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductions,
OperandRange reductionVars,
std::optional<ArrayRef<bool>> byRef = std::nullopt) {
std::optional<ArrayRef<bool>> byRef) {
if (!reductionVars.empty()) {
if (!reductions || reductions->size() != reductionVars.size())
return op->emitOpError()
<< "expected as many reduction symbol references "
"as reduction variables";
if (mlir::isa<omp::WsloopOp, omp::ParallelOp>(op))
assert(byRef);
else
assert(!byRef); // TODO: support byref reductions on other operations
if (byRef && byRef->size() != reductionVars.size())
return op->emitError() << "expected as many reduction variable by "
"reference attributes as reduction variables";
Expand Down Expand Up @@ -1453,7 +1468,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ParallelOp::build(builder, state, clauses.ifVar, clauses.numThreadsVar,
clauses.allocateVars, clauses.allocatorVars,
clauses.reductionVars,
DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef),
makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols),
clauses.procBindKindAttr, clauses.privateVars,
makeArrayAttr(ctx, clauses.privatizers));
Expand Down Expand Up @@ -1551,6 +1566,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
clauses.numTeamsUpperVar, clauses.ifVar,
clauses.threadLimitVar, clauses.allocateVars,
clauses.allocatorVars, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols));
}

Expand Down Expand Up @@ -1582,7 +1598,8 @@ LogicalResult TeamsOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");

return verifyReductionVarList(*this, getReductions(), getReductionVars());
return verifyReductionVarList(*this, getReductions(), getReductionVars(),
getReductionVarsByref());
}

//===----------------------------------------------------------------------===//
Expand All @@ -1594,6 +1611,7 @@ void SectionsOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
SectionsOp::build(builder, state, clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols),
clauses.allocateVars, clauses.allocatorVars,
clauses.nowaitAttr);
Expand All @@ -1604,7 +1622,8 @@ LogicalResult SectionsOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");

return verifyReductionVarList(*this, getReductions(), getReductionVars());
return verifyReductionVarList(*this, getReductions(), getReductionVars(),
getReductionVarsByref());
}

LogicalResult SectionsOp::verifyRegions() {
Expand Down Expand Up @@ -1693,7 +1712,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
// privatizers.
WsloopOp::build(builder, state, clauses.linearVars, clauses.linearStepVars,
clauses.reductionVars,
DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef),
makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols),
clauses.scheduleValAttr, clauses.scheduleChunkVar,
clauses.scheduleModAttr, clauses.scheduleSimdAttr,
Expand Down Expand Up @@ -1892,6 +1911,7 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
TaskOp::build(
builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
clauses.mergeableAttr, clauses.inReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.inReductionVarsByRef),
makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.priorityVar,
makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
clauses.allocateVars, clauses.allocatorVars);
Expand All @@ -1903,7 +1923,8 @@ LogicalResult TaskOp::verify() {
return failed(verifyDependVars)
? verifyDependVars
: verifyReductionVarList(*this, getInReductions(),
getInReductionVars());
getInReductionVars(),
getInReductionVarsByref());
}

//===----------------------------------------------------------------------===//
Expand All @@ -1913,14 +1934,17 @@ LogicalResult TaskOp::verify() {
void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
const TaskgroupClauseOps &clauses) {
MLIRContext *ctx = builder.getContext();
TaskgroupOp::build(builder, state, clauses.taskReductionVars,
makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
clauses.allocateVars, clauses.allocatorVars);
TaskgroupOp::build(
builder, state, clauses.taskReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.taskReductionVarsByRef),
makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
clauses.allocateVars, clauses.allocatorVars);
}

LogicalResult TaskgroupOp::verify() {
return verifyReductionVarList(*this, getTaskReductions(),
getTaskReductionVars());
getTaskReductionVars(),
getTaskReductionVarsByref());
}

//===----------------------------------------------------------------------===//
Expand All @@ -1934,7 +1958,9 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
TaskloopOp::build(
builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
clauses.mergeableAttr, clauses.inReductionVars,
makeDenseBoolArrayAttr(ctx, clauses.inReductionVarsByRef),
makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar,
clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar,
clauses.numTasksVar, clauses.nogroupAttr);
Expand All @@ -1952,10 +1978,11 @@ LogicalResult TaskloopOp::verify() {
if (getAllocateVars().size() != getAllocatorsVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
if (failed(
verifyReductionVarList(*this, getReductions(), getReductionVars())) ||
if (failed(verifyReductionVarList(*this, getReductions(), getReductionVars(),
getReductionVarsByref())) ||
failed(verifyReductionVarList(*this, getInReductions(),
getInReductionVars())))
getInReductionVars(),
getInReductionVarsByref())))
return failure();

if (!getReductionVars().empty() && getNogroup())
Expand Down
Loading
Loading