Skip to content

[flang] Add reduction semantics to fir.do_loop #93934

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 6, 2024

Conversation

khaki3
Copy link
Contributor

@khaki3 khaki3 commented May 31, 2024

Derived from #92480. This PR introduces reduction semantics into loops for DO CONCURRENT REDUCE. The fir.do_loop operation now invisibly has the operandSegmentsizes attribute and takes variable-length reduction operands with their operations given as fir.reduce_attr. For the sake of compatibility, fir.do_loop's builder has additional arguments at the end. The iter_args operand should be placed in front of the declaration of result types, so the new operand for reduction variables (reduce) is put in the middle of arguments.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels May 31, 2024
@llvmbot
Copy link
Member

llvmbot commented May 31, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (khaki3)

Changes

Derived from #92480. This PR introduces fir.reduce into fir.do_loop. The operation fir.reduce conveys reduction semantics in a similar way to acc.reduction; it marks the reference to reduction variables while keeping their original names. The fir.do_loop operation now invisibly has the operandSegmentsizes attribute and takes variable-length reduction operands with their operations given as fir.reduce_attr. For the sake of compatibility, fir.do_loop's builder has additional arguments at the end. The iter_args operand should be next to a return-type declaration, so the new operand for fir.reduce is put in the middle of arguments.


Full diff: https://github.com/llvm/llvm-project/pull/93934.diff

4 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/FIRAttr.td (+30)
  • (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+56-8)
  • (modified) flang/lib/Optimizer/Dialect/FIRAttr.cpp (+2-2)
  • (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+66-7)
diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.td b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
index 0c34b640a5c9c..aedb6769186e9 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRAttr.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
@@ -67,6 +67,36 @@ def fir_BoxFieldAttr : I32EnumAttr<
   let cppNamespace = "fir";
 }
 
+def fir_ReduceOperationEnum : I32BitEnumAttr<"ReduceOperationEnum",
+    "intrinsic operations and functions supported by DO CONCURRENT REDUCE",
+    [
+      I32BitEnumAttrCaseBit<"Add", 0, "add">,
+      I32BitEnumAttrCaseBit<"Multiply", 1, "multiply">,
+      I32BitEnumAttrCaseBit<"AND", 2, "and">,
+      I32BitEnumAttrCaseBit<"OR", 3, "or">,
+      I32BitEnumAttrCaseBit<"EQV", 4, "eqv">,
+      I32BitEnumAttrCaseBit<"NEQV", 5, "neqv">,
+      I32BitEnumAttrCaseBit<"MAX", 6, "max">,
+      I32BitEnumAttrCaseBit<"MIN", 7, "min">,
+      I32BitEnumAttrCaseBit<"IAND", 8, "iand">,
+      I32BitEnumAttrCaseBit<"IOR", 9, "ior">,
+      I32BitEnumAttrCaseBit<"EIOR", 10, "eior">
+    ]> {
+  let separator = ", ";
+  let cppNamespace = "::fir";
+  let printBitEnumPrimaryGroups = 1;
+}
+
+def fir_ReduceAttr : fir_Attr<"Reduce"> {
+  let mnemonic = "reduce_attr";
+
+  let parameters = (ins
+    "ReduceOperationEnum":$reduce_operation
+  );
+
+  let assemblyFormat = "`<` $reduce_operation `>`";
+}
+
 // mlir::SideEffects::Resource for modelling operations which add debugging information
 def DebuggingResource : Resource<"::fir::DebuggingResource">;
 
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 3afc97475db11..d79f2da916d05 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2107,8 +2107,37 @@ class region_Op<string mnemonic, list<Trait> traits = []> :
   let hasVerifier = 1;
 }
 
-def fir_DoLoopOp : region_Op<"do_loop",
-    [DeclareOpInterfaceMethods<LoopLikeOpInterface,
+def fir_ReduceOp : fir_SimpleOp<"reduce", [NoMemoryEffect]> {
+  let summary = "Represent reduction semantics for the reduce clause";
+
+  let description = [{
+    Given the address of a variable, creates reduction information for the
+    reduce clause.
+
+    ```
+      %17 = fir.reduce %8 {name = "sum"} : (!fir.ref<f32>) -> !fir.ref<f32>
+      fir.do_loop ... unordered reduce(#fir.reduce_attr<add> -> %17 : !fir.ref<f32>) ...
+    ```
+
+    This operation is typically used for DO CONCURRENT REDUCE clause. The memref
+    operand may have a unique name while the `name` attribute preserves the
+    original name of a reduction variable.
+  }];
+
+  let arguments = (ins
+     AnyRefOrBoxLike:$memref,
+     Builtin_StringAttr:$name
+  );
+
+  let results = (outs AnyRefOrBox);
+
+  let assemblyFormat = [{
+    operands attr-dict `:` functional-type(operands, results)
+  }];
+}
+
+def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<LoopLikeOpInterface,
         ["getYieldedValuesMutable"]>]> {
   let summary = "generalized loop operation";
   let description = [{
@@ -2138,9 +2167,11 @@ def fir_DoLoopOp : region_Op<"do_loop",
     Index:$lowerBound,
     Index:$upperBound,
     Index:$step,
+    Variadic<AnyType>:$reduceOperands,
     Variadic<AnyType>:$initArgs,
     OptionalAttr<UnitAttr>:$unordered,
-    OptionalAttr<UnitAttr>:$finalValue
+    OptionalAttr<UnitAttr>:$finalValue,
+    OptionalAttr<ArrayAttr>:$reduceAttrs
   );
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
@@ -2151,6 +2182,8 @@ def fir_DoLoopOp : region_Op<"do_loop",
       "mlir::Value":$step, CArg<"bool", "false">:$unordered,
       CArg<"bool", "false">:$finalCountValue,
       CArg<"mlir::ValueRange", "std::nullopt">:$iterArgs,
+      CArg<"mlir::ValueRange", "std::nullopt">:$reduceOperands,
+      CArg<"llvm::ArrayRef<mlir::Attribute>", "{}">:$reduceAttrs,
       CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>
   ];
 
@@ -2163,11 +2196,12 @@ def fir_DoLoopOp : region_Op<"do_loop",
       return getBody()->getArguments().drop_front();
     }
     mlir::Operation::operand_range getIterOperands() {
-      return getOperands().drop_front(getNumControlOperands());
+      return getOperands()
+          .drop_front(getNumControlOperands() + getNumReduceOperands());
     }
     llvm::MutableArrayRef<mlir::OpOperand> getInitsMutable() {
-      return
-          getOperation()->getOpOperands().drop_front(getNumControlOperands());
+      return getOperation()->getOpOperands()
+          .drop_front(getNumControlOperands() + getNumReduceOperands());
     }
 
     void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
@@ -2182,11 +2216,25 @@ def fir_DoLoopOp : region_Op<"do_loop",
     unsigned getNumControlOperands() { return 3; }
     /// Does the operation hold operands for loop-carried values
     bool hasIterOperands() {
-      return (*this)->getNumOperands() > getNumControlOperands();
+      return getNumIterOperands() > 0;
+    }
+    /// Does the operation hold operands for reduction variables
+    bool hasReduceOperands() {
+      return getNumReduceOperands() > 0;
+    }
+    /// Get Number of variadic operands
+    unsigned getNumOperands(unsigned idx) {
+      auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(
+        getOperandSegmentSizeAttr());
+      return static_cast<unsigned>(segments[idx]);
+    }
+    // Get Number of reduction operands
+    unsigned getNumReduceOperands() {
+      return getNumOperands(3);
     }
     /// Get Number of loop-carried values
     unsigned getNumIterOperands() {
-      return (*this)->getNumOperands() - getNumControlOperands();
+      return getNumOperands(4);
     }
 
     /// Get the body of the loop
diff --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
index 2faba63dfba07..a0202a0159228 100644
--- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
@@ -297,6 +297,6 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
 
 void FIROpsDialect::registerAttributes() {
   addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
-                LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr,
-                UpperBoundAttr>();
+                LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr,
+                SubclassAttr, UpperBoundAttr>();
 }
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index b541b7cdc7a5b..807459c8ec3c7 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -2079,9 +2079,16 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
                           mlir::OperationState &result, mlir::Value lb,
                           mlir::Value ub, mlir::Value step, bool unordered,
                           bool finalCountValue, mlir::ValueRange iterArgs,
+                          mlir::ValueRange reduceOperands,
+                          llvm::ArrayRef<mlir::Attribute> reduceAttrs,
                           llvm::ArrayRef<mlir::NamedAttribute> attributes) {
   result.addOperands({lb, ub, step});
+  result.addOperands(reduceOperands);
   result.addOperands(iterArgs);
+  result.addAttribute(getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(
+                          {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+                           static_cast<int32_t>(iterArgs.size())}));
   if (finalCountValue) {
     result.addTypes(builder.getIndexType());
     result.addAttribute(getFinalValueAttrName(result.name),
@@ -2100,6 +2107,9 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
   if (unordered)
     result.addAttribute(getUnorderedAttrName(result.name),
                         builder.getUnitAttr());
+  if (!reduceAttrs.empty())
+    result.addAttribute(getReduceAttrsAttrName(result.name),
+                        builder.getArrayAttr(reduceAttrs));
   result.addAttributes(attributes);
 }
 
@@ -2125,24 +2135,51 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
   if (mlir::succeeded(parser.parseOptionalKeyword("unordered")))
     result.addAttribute("unordered", builder.getUnitAttr());
 
+  // Parse the reduction arguments.
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
+  llvm::SmallVector<mlir::Type> reduceArgTypes;
+  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
+    // Parse reduction attributes and variables.
+    llvm::SmallVector<ReduceAttr> attributes;
+    if (failed(parser.parseCommaSeparatedList(
+            mlir::AsmParser::Delimiter::Paren, [&]() {
+              if (parser.parseAttribute(attributes.emplace_back()) ||
+                  parser.parseArrow() ||
+                  parser.parseOperand(reduceOperands.emplace_back()) ||
+                  parser.parseColonType(reduceArgTypes.emplace_back()))
+                return mlir::failure();
+              return mlir::success();
+            })))
+      return mlir::failure();
+    // Resolve input operands.
+    for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes))
+      if (parser.resolveOperand(std::get<0>(operand_type),
+                                std::get<1>(operand_type), result.operands))
+        return mlir::failure();
+    llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+                                                 attributes.end());
+    result.addAttribute(getReduceAttrsAttrName(result.name),
+                        builder.getArrayAttr(arrayAttr));
+  }
+
   // Parse the optional initial iteration arguments.
   llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
-  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
   llvm::SmallVector<mlir::Type> argTypes;
   bool prependCount = false;
   regionArgs.push_back(inductionVariable);
 
   if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
     // Parse assignment list and results type list.
-    if (parser.parseAssignmentList(regionArgs, operands) ||
+    if (parser.parseAssignmentList(regionArgs, iterOperands) ||
         parser.parseArrowTypeList(result.types))
       return mlir::failure();
-    if (result.types.size() == operands.size() + 1)
+    if (result.types.size() == iterOperands.size() + 1)
       prependCount = true;
     // Resolve input operands.
     llvm::ArrayRef<mlir::Type> resTypes = result.types;
-    for (auto operand_type :
-         llvm::zip(operands, prependCount ? resTypes.drop_front() : resTypes))
+    for (auto operand_type : llvm::zip(
+             iterOperands, prependCount ? resTypes.drop_front() : resTypes))
       if (parser.resolveOperand(std::get<0>(operand_type),
                                 std::get<1>(operand_type), result.operands))
         return mlir::failure();
@@ -2153,6 +2190,12 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
     prependCount = true;
   }
 
+  // Set the operandSegmentSizes attribute
+  result.addAttribute(getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(
+                          {1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
+                           static_cast<int32_t>(iterOperands.size())}));
+
   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
     return mlir::failure();
 
@@ -2229,6 +2272,10 @@ mlir::LogicalResult fir::DoLoopOp::verify() {
 
     i++;
   }
+  auto reduceAttrs = getReduceAttrsAttr();
+  if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0))
+    return emitOpError(
+        "mismatch in number of reduction variables and reduction attributes");
   return mlir::success();
 }
 
@@ -2238,6 +2285,17 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
     << getUpperBound() << " step " << getStep();
   if (getUnordered())
     p << " unordered";
+  if (hasReduceOperands()) {
+    p << " reduce(";
+    auto attrs = getReduceAttrsAttr();
+    auto operands = getReduceOperands();
+    llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) {
+      p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
+        << std::get<1>(it).getType();
+    });
+    p << ')';
+    printBlockTerminators = true;
+  }
   if (hasIterOperands()) {
     p << " iter_args(";
     auto regionArgs = getRegionIterArgs();
@@ -2251,8 +2309,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
     p << " -> " << getResultTypes();
     printBlockTerminators = true;
   }
-  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
-                                     {"unordered", "finalValue"});
+  p.printOptionalAttrDictWithKeyword(
+      (*this)->getAttrs(),
+      {"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"});
   p << ' ';
   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                 printBlockTerminators);

@khaki3
Copy link
Contributor Author

khaki3 commented Jun 3, 2024

I cannot find any case in which we need fir.reduce. The name of reduction symbols have to be declared, so we can always track their original names. I will directly put reduction variables instead.

@clementval
Copy link
Contributor

I cannot find any case in which we need fir.reduce. The name of reduction symbols have to be declared, so we can always track their original names. I will directly put reduction variables instead.

Yeah I think it makes sense. In OpenACC dialect with have this extra op because the dialect is meant to be language agnostic. Here we know that we deal with Fortran so it should be fine.

@clementval clementval self-requested a review June 3, 2024 16:06
@khaki3
Copy link
Contributor Author

khaki3 commented Jun 3, 2024

Thanks. I removed fir.reduce.

Comment on lines +2141 to 2142
Variadic<AnyType>:$reduceOperands,
Variadic<AnyType>:$initArgs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Variadic<AnyType>:$reduceOperands,
Variadic<AnyType>:$initArgs,
Variadic<AnyType>:$initArgs,
Variadic<AnyType>:$reduceOperands,

initArgs is more related to the control operands so maybe it's better to keep them together.

Comment on lines +2288 to +2298
if (hasReduceOperands()) {
p << " reduce(";
auto attrs = getReduceAttrsAttr();
auto operands = getReduceOperands();
llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) {
p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
<< std::get<1>(it).getType();
});
p << ')';
printBlockTerminators = true;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switch position with iter_args

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iter_args affects the result types (e.g. iter_args(%arg1 = %7) -> (index, i32)). While we do not expect reduce to be coupled with iter_args, that is still acceptable. Which one should we choose?

  • iter_args(%arg1 = %7) reduce(...) -> (index, i32)
  • reduce(...) iter_args(%arg1 = %7) -> (index, i32)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's a good point. Let's keep it like this.

Copy link
Contributor

@clementval clementval left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks mostly ok to me. Just a comment about the position of the reduce list in the op.

@khaki3
Copy link
Contributor Author

khaki3 commented Jun 5, 2024

@klausler Could you check this PR? I appreciate it!

@clementval clementval requested a review from klausler June 5, 2024 05:51
@klausler
Copy link
Contributor

klausler commented Jun 5, 2024

No, I have no knowledge of fir.

@klausler klausler removed their request for review June 5, 2024 14:46
@clementval clementval merged commit 88cdd99 into llvm:main Jun 6, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants