-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[flang] Add reduction semantics to fir.do_loop #93934
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-flang-fir-hlfir Author: None (khaki3) ChangesDerived from #92480. This PR introduces Full diff: https://github.com/llvm/llvm-project/pull/93934.diff 4 Files Affected:
diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.td b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
index 0c34b640a5c9c..aedb6769186e9 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRAttr.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
@@ -67,6 +67,36 @@ def fir_BoxFieldAttr : I32EnumAttr<
let cppNamespace = "fir";
}
+def fir_ReduceOperationEnum : I32BitEnumAttr<"ReduceOperationEnum",
+ "intrinsic operations and functions supported by DO CONCURRENT REDUCE",
+ [
+ I32BitEnumAttrCaseBit<"Add", 0, "add">,
+ I32BitEnumAttrCaseBit<"Multiply", 1, "multiply">,
+ I32BitEnumAttrCaseBit<"AND", 2, "and">,
+ I32BitEnumAttrCaseBit<"OR", 3, "or">,
+ I32BitEnumAttrCaseBit<"EQV", 4, "eqv">,
+ I32BitEnumAttrCaseBit<"NEQV", 5, "neqv">,
+ I32BitEnumAttrCaseBit<"MAX", 6, "max">,
+ I32BitEnumAttrCaseBit<"MIN", 7, "min">,
+ I32BitEnumAttrCaseBit<"IAND", 8, "iand">,
+ I32BitEnumAttrCaseBit<"IOR", 9, "ior">,
+ I32BitEnumAttrCaseBit<"EIOR", 10, "eior">
+ ]> {
+ let separator = ", ";
+ let cppNamespace = "::fir";
+ let printBitEnumPrimaryGroups = 1;
+}
+
+def fir_ReduceAttr : fir_Attr<"Reduce"> {
+ let mnemonic = "reduce_attr";
+
+ let parameters = (ins
+ "ReduceOperationEnum":$reduce_operation
+ );
+
+ let assemblyFormat = "`<` $reduce_operation `>`";
+}
+
// mlir::SideEffects::Resource for modelling operations which add debugging information
def DebuggingResource : Resource<"::fir::DebuggingResource">;
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 3afc97475db11..d79f2da916d05 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2107,8 +2107,37 @@ class region_Op<string mnemonic, list<Trait> traits = []> :
let hasVerifier = 1;
}
-def fir_DoLoopOp : region_Op<"do_loop",
- [DeclareOpInterfaceMethods<LoopLikeOpInterface,
+def fir_ReduceOp : fir_SimpleOp<"reduce", [NoMemoryEffect]> {
+ let summary = "Represent reduction semantics for the reduce clause";
+
+ let description = [{
+ Given the address of a variable, creates reduction information for the
+ reduce clause.
+
+ ```
+ %17 = fir.reduce %8 {name = "sum"} : (!fir.ref<f32>) -> !fir.ref<f32>
+ fir.do_loop ... unordered reduce(#fir.reduce_attr<add> -> %17 : !fir.ref<f32>) ...
+ ```
+
+ This operation is typically used for DO CONCURRENT REDUCE clause. The memref
+ operand may have a unique name while the `name` attribute preserves the
+ original name of a reduction variable.
+ }];
+
+ let arguments = (ins
+ AnyRefOrBoxLike:$memref,
+ Builtin_StringAttr:$name
+ );
+
+ let results = (outs AnyRefOrBox);
+
+ let assemblyFormat = [{
+ operands attr-dict `:` functional-type(operands, results)
+ }];
+}
+
+def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getYieldedValuesMutable"]>]> {
let summary = "generalized loop operation";
let description = [{
@@ -2138,9 +2167,11 @@ def fir_DoLoopOp : region_Op<"do_loop",
Index:$lowerBound,
Index:$upperBound,
Index:$step,
+ Variadic<AnyType>:$reduceOperands,
Variadic<AnyType>:$initArgs,
OptionalAttr<UnitAttr>:$unordered,
- OptionalAttr<UnitAttr>:$finalValue
+ OptionalAttr<UnitAttr>:$finalValue,
+ OptionalAttr<ArrayAttr>:$reduceAttrs
);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
@@ -2151,6 +2182,8 @@ def fir_DoLoopOp : region_Op<"do_loop",
"mlir::Value":$step, CArg<"bool", "false">:$unordered,
CArg<"bool", "false">:$finalCountValue,
CArg<"mlir::ValueRange", "std::nullopt">:$iterArgs,
+ CArg<"mlir::ValueRange", "std::nullopt">:$reduceOperands,
+ CArg<"llvm::ArrayRef<mlir::Attribute>", "{}">:$reduceAttrs,
CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>
];
@@ -2163,11 +2196,12 @@ def fir_DoLoopOp : region_Op<"do_loop",
return getBody()->getArguments().drop_front();
}
mlir::Operation::operand_range getIterOperands() {
- return getOperands().drop_front(getNumControlOperands());
+ return getOperands()
+ .drop_front(getNumControlOperands() + getNumReduceOperands());
}
llvm::MutableArrayRef<mlir::OpOperand> getInitsMutable() {
- return
- getOperation()->getOpOperands().drop_front(getNumControlOperands());
+ return getOperation()->getOpOperands()
+ .drop_front(getNumControlOperands() + getNumReduceOperands());
}
void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
@@ -2182,11 +2216,25 @@ def fir_DoLoopOp : region_Op<"do_loop",
unsigned getNumControlOperands() { return 3; }
/// Does the operation hold operands for loop-carried values
bool hasIterOperands() {
- return (*this)->getNumOperands() > getNumControlOperands();
+ return getNumIterOperands() > 0;
+ }
+ /// Does the operation hold operands for reduction variables
+ bool hasReduceOperands() {
+ return getNumReduceOperands() > 0;
+ }
+ /// Get Number of variadic operands
+ unsigned getNumOperands(unsigned idx) {
+ auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(
+ getOperandSegmentSizeAttr());
+ return static_cast<unsigned>(segments[idx]);
+ }
+ // Get Number of reduction operands
+ unsigned getNumReduceOperands() {
+ return getNumOperands(3);
}
/// Get Number of loop-carried values
unsigned getNumIterOperands() {
- return (*this)->getNumOperands() - getNumControlOperands();
+ return getNumOperands(4);
}
/// Get the body of the loop
diff --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
index 2faba63dfba07..a0202a0159228 100644
--- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
@@ -297,6 +297,6 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
void FIROpsDialect::registerAttributes() {
addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
- LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr,
- UpperBoundAttr>();
+ LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr,
+ SubclassAttr, UpperBoundAttr>();
}
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index b541b7cdc7a5b..807459c8ec3c7 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -2079,9 +2079,16 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result, mlir::Value lb,
mlir::Value ub, mlir::Value step, bool unordered,
bool finalCountValue, mlir::ValueRange iterArgs,
+ mlir::ValueRange reduceOperands,
+ llvm::ArrayRef<mlir::Attribute> reduceAttrs,
llvm::ArrayRef<mlir::NamedAttribute> attributes) {
result.addOperands({lb, ub, step});
+ result.addOperands(reduceOperands);
result.addOperands(iterArgs);
+ result.addAttribute(getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr(
+ {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+ static_cast<int32_t>(iterArgs.size())}));
if (finalCountValue) {
result.addTypes(builder.getIndexType());
result.addAttribute(getFinalValueAttrName(result.name),
@@ -2100,6 +2107,9 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
if (unordered)
result.addAttribute(getUnorderedAttrName(result.name),
builder.getUnitAttr());
+ if (!reduceAttrs.empty())
+ result.addAttribute(getReduceAttrsAttrName(result.name),
+ builder.getArrayAttr(reduceAttrs));
result.addAttributes(attributes);
}
@@ -2125,24 +2135,51 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
if (mlir::succeeded(parser.parseOptionalKeyword("unordered")))
result.addAttribute("unordered", builder.getUnitAttr());
+ // Parse the reduction arguments.
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
+ llvm::SmallVector<mlir::Type> reduceArgTypes;
+ if (succeeded(parser.parseOptionalKeyword("reduce"))) {
+ // Parse reduction attributes and variables.
+ llvm::SmallVector<ReduceAttr> attributes;
+ if (failed(parser.parseCommaSeparatedList(
+ mlir::AsmParser::Delimiter::Paren, [&]() {
+ if (parser.parseAttribute(attributes.emplace_back()) ||
+ parser.parseArrow() ||
+ parser.parseOperand(reduceOperands.emplace_back()) ||
+ parser.parseColonType(reduceArgTypes.emplace_back()))
+ return mlir::failure();
+ return mlir::success();
+ })))
+ return mlir::failure();
+ // Resolve input operands.
+ for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes))
+ if (parser.resolveOperand(std::get<0>(operand_type),
+ std::get<1>(operand_type), result.operands))
+ return mlir::failure();
+ llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+ attributes.end());
+ result.addAttribute(getReduceAttrsAttrName(result.name),
+ builder.getArrayAttr(arrayAttr));
+ }
+
// Parse the optional initial iteration arguments.
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
llvm::SmallVector<mlir::Type> argTypes;
bool prependCount = false;
regionArgs.push_back(inductionVariable);
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
// Parse assignment list and results type list.
- if (parser.parseAssignmentList(regionArgs, operands) ||
+ if (parser.parseAssignmentList(regionArgs, iterOperands) ||
parser.parseArrowTypeList(result.types))
return mlir::failure();
- if (result.types.size() == operands.size() + 1)
+ if (result.types.size() == iterOperands.size() + 1)
prependCount = true;
// Resolve input operands.
llvm::ArrayRef<mlir::Type> resTypes = result.types;
- for (auto operand_type :
- llvm::zip(operands, prependCount ? resTypes.drop_front() : resTypes))
+ for (auto operand_type : llvm::zip(
+ iterOperands, prependCount ? resTypes.drop_front() : resTypes))
if (parser.resolveOperand(std::get<0>(operand_type),
std::get<1>(operand_type), result.operands))
return mlir::failure();
@@ -2153,6 +2190,12 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
prependCount = true;
}
+ // Set the operandSegmentSizes attribute
+ result.addAttribute(getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr(
+ {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+ static_cast<int32_t>(iterOperands.size())}));
+
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return mlir::failure();
@@ -2229,6 +2272,10 @@ mlir::LogicalResult fir::DoLoopOp::verify() {
i++;
}
+ auto reduceAttrs = getReduceAttrsAttr();
+ if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0))
+ return emitOpError(
+ "mismatch in number of reduction variables and reduction attributes");
return mlir::success();
}
@@ -2238,6 +2285,17 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
<< getUpperBound() << " step " << getStep();
if (getUnordered())
p << " unordered";
+ if (hasReduceOperands()) {
+ p << " reduce(";
+ auto attrs = getReduceAttrsAttr();
+ auto operands = getReduceOperands();
+ llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) {
+ p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
+ << std::get<1>(it).getType();
+ });
+ p << ')';
+ printBlockTerminators = true;
+ }
if (hasIterOperands()) {
p << " iter_args(";
auto regionArgs = getRegionIterArgs();
@@ -2251,8 +2309,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
p << " -> " << getResultTypes();
printBlockTerminators = true;
}
- p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
- {"unordered", "finalValue"});
+ p.printOptionalAttrDictWithKeyword(
+ (*this)->getAttrs(),
+ {"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"});
p << ' ';
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
printBlockTerminators);
|
I cannot find any case in which we need |
Yeah I think it makes sense. In OpenACC dialect with have this extra op because the dialect is meant to be language agnostic. Here we know that we deal with Fortran so it should be fine. |
Thanks. I removed |
Variadic<AnyType>:$reduceOperands, | ||
Variadic<AnyType>:$initArgs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Variadic<AnyType>:$reduceOperands, | |
Variadic<AnyType>:$initArgs, | |
Variadic<AnyType>:$initArgs, | |
Variadic<AnyType>:$reduceOperands, |
initArgs is more related to the control operands so maybe it's better to keep them together.
if (hasReduceOperands()) { | ||
p << " reduce("; | ||
auto attrs = getReduceAttrsAttr(); | ||
auto operands = getReduceOperands(); | ||
llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) { | ||
p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " | ||
<< std::get<1>(it).getType(); | ||
}); | ||
p << ')'; | ||
printBlockTerminators = true; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Switch position with iter_args
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
iter_args
affects the result types (e.g. iter_args(%arg1 = %7) -> (index, i32)
). While we do not expect reduce
to be coupled with iter_args
, that is still acceptable. Which one should we choose?
iter_args(%arg1 = %7) reduce(...) -> (index, i32)
reduce(...) iter_args(%arg1 = %7) -> (index, i32)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that's a good point. Let's keep it like this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks mostly ok to me. Just a comment about the position of the reduce list in the op.
@klausler Could you check this PR? I appreciate it! |
No, I have no knowledge of fir. |
Derived from #92480. This PR introduces reduction semantics into loops for DO CONCURRENT REDUCE. The
fir.do_loop
operation now invisibly has theoperandSegmentsizes
attribute and takes variable-length reduction operands with their operations given asfir.reduce_attr
. For the sake of compatibility,fir.do_loop
's builder has additional arguments at the end. Theiter_args
operand should be placed in front of the declaration of result types, so the new operand for reduction variables (reduce
) is put in the middle of arguments.