Skip to content

[Flang][OpenMP] NFC: Minor refactoring of Reduction lowering code #70790

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 22, 2023

Conversation

kiranchandramohan
Copy link
Contributor

Move reduction lowering code into a ReductionProcessor class. Create an enumeration for Intrinsic Procedure reductions.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:openmp labels Oct 31, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 31, 2023

@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-flang-fir-hlfir

Author: Kiran Chandramohan (kiranchandramohan)

Changes

Move reduction lowering code into a ReductionProcessor class. Create an enumeration for Intrinsic Procedure reductions.


Patch is 39.63 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70790.diff

1 Files Affected:

  • (modified) flang/lib/Lower/OpenMP.cpp (+431-358)
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 0faaae6c08e0476..a804a2ce1780e7d 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -667,276 +667,441 @@ static void checkMapType(mlir::Location location, mlir::Type type) {
       TODO(location, "OMPD_target_data MapOperand BoxType");
 }
 
-static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
-  return (llvm::Twine(name) +
-          (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
-          llvm::Twine(ty.getIntOrFloatBitWidth()))
-      .str();
-}
+class ReductionProcessor {
+public:
+  enum IntrinsicProc { MAX, MIN, IAND, IOR, IEOR };
+  static IntrinsicProc
+  getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
+    auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+                       getRealName(pd).ToString())
+                       .Case("max", IntrinsicProc::MAX)
+                       .Case("min", IntrinsicProc::MIN)
+                       .Case("iand", IntrinsicProc::IAND)
+                       .Case("ior", IntrinsicProc::IOR)
+                       .Case("ieor", IntrinsicProc::IEOR)
+                       .Default(std::nullopt);
+    assert(redType && "Invalid Reduction");
+    return *redType;
+  }
+
+  static bool supportedIntrinsicProcReduction(
+      const Fortran::parser::ProcedureDesignator &pd) {
+    const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
+    assert(name && "Invalid Reduction Intrinsic.");
+    auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+                       getRealName(name).ToString())
+                       .Case("max", IntrinsicProc::MAX)
+                       .Case("min", IntrinsicProc::MIN)
+                       .Case("iand", IntrinsicProc::IAND)
+                       .Case("ior", IntrinsicProc::IOR)
+                       .Case("ieor", IntrinsicProc::IEOR)
+                       .Default(std::nullopt);
+    if (redType)
+      return true;
+    return false;
+  }
+
+  static const Fortran::semantics::SourceName
+  getRealName(const Fortran::parser::Name *name) {
+    return name->symbol->GetUltimate().name();
+  }
 
-static std::string getReductionName(
-    Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-    mlir::Type ty) {
-  std::string reductionName;
+  static const Fortran::semantics::SourceName
+  getRealName(const Fortran::parser::ProcedureDesignator &pd) {
+    const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
+    assert(name && "Invalid Reduction Intrinsic.");
+    return getRealName(name);
+  }
 
-  switch (intrinsicOp) {
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-    reductionName = "add_reduction";
-    break;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-    reductionName = "multiply_reduction";
-    break;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
-    return "and_reduction";
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
-    return "eqv_reduction";
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
-    return "or_reduction";
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
-    return "neqv_reduction";
-  default:
-    reductionName = "other_reduction";
-    break;
+  static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
+    return (llvm::Twine(name) +
+            (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
+            llvm::Twine(ty.getIntOrFloatBitWidth()))
+        .str();
+  }
+
+  static std::string getReductionName(
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+      mlir::Type ty) {
+    std::string reductionName;
+
+    switch (intrinsicOp) {
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+      reductionName = "add_reduction";
+      break;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+      reductionName = "multiply_reduction";
+      break;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+      return "and_reduction";
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+      return "eqv_reduction";
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+      return "or_reduction";
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+      return "neqv_reduction";
+    default:
+      reductionName = "other_reduction";
+      break;
+    }
+
+    return getReductionName(reductionName, ty);
+  }
+
+  /// This function returns the identity value of the operator \p
+  /// reductionOpName. For example:
+  ///    0 + x = x,
+  ///    1 * x = x
+  static int getOperationIdentity(
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+      mlir::Location loc) {
+    switch (intrinsicOp) {
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+      return 0;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+      return 1;
+    default:
+      TODO(loc, "Reduction of some intrinsic operators is not supported");
+    }
   }
 
-  return getReductionName(reductionName, ty);
-}
-
-/// This function returns the identity value of the operator \p reductionOpName.
-/// For example:
-///    0 + x = x,
-///    1 * x = x
-static int getOperationIdentity(llvm::StringRef reductionOpName,
-                                mlir::Location loc) {
-  if (reductionOpName.contains("add") || reductionOpName.contains("or") ||
-      reductionOpName.contains("neqv"))
-    return 0;
-  if (reductionOpName.contains("multiply") || reductionOpName.contains("and") ||
-      reductionOpName.contains("eqv"))
-    return 1;
-  TODO(loc, "Reduction of some intrinsic operators is not supported");
-}
-
-static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
-                                         llvm::StringRef reductionOpName,
-                                         fir::FirOpBuilder &builder) {
-  assert((fir::isa_integer(type) || fir::isa_real(type) ||
-          type.isa<fir::LogicalType>()) &&
-         "only integer, logical and real types are currently supported");
-  if (reductionOpName.contains("max")) {
-    if (auto ty = type.dyn_cast<mlir::FloatType>()) {
-      const llvm::fltSemantics &sem = ty.getFloatSemantics();
-      return builder.createRealConstant(
-          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
+  static mlir::Value getIntrinsicProcInitValue(
+      mlir::Location loc, mlir::Type type,
+      const Fortran::parser::ProcedureDesignator &procDesignator,
+      fir::FirOpBuilder &builder) {
+    assert((fir::isa_integer(type) || fir::isa_real(type) ||
+            type.isa<fir::LogicalType>()) &&
+           "only integer, logical and real types are currently supported");
+    switch (getReductionType(procDesignator)) {
+    case IntrinsicProc::MAX: {
+      if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+        const llvm::fltSemantics &sem = ty.getFloatSemantics();
+        return builder.createRealConstant(
+            loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
+      }
+      unsigned bits = type.getIntOrFloatBitWidth();
+      int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, type, minInt);
+    }
+    case IntrinsicProc::MIN: {
+      if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+        const llvm::fltSemantics &sem = ty.getFloatSemantics();
+        return builder.createRealConstant(
+            loc, type, llvm::APFloat::getSmallest(sem, /*Negative=*/true));
+      }
+      unsigned bits = type.getIntOrFloatBitWidth();
+      int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, type, maxInt);
     }
-    unsigned bits = type.getIntOrFloatBitWidth();
-    int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
-    return builder.createIntegerConstant(loc, type, minInt);
-  } else if (reductionOpName.contains("min")) {
-    if (auto ty = type.dyn_cast<mlir::FloatType>()) {
-      const llvm::fltSemantics &sem = ty.getFloatSemantics();
-      return builder.createRealConstant(
-          loc, type, llvm::APFloat::getSmallest(sem, /*Negative=*/true));
+    case IntrinsicProc::IOR: {
+      unsigned bits = type.getIntOrFloatBitWidth();
+      int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, type, zeroInt);
     }
-    unsigned bits = type.getIntOrFloatBitWidth();
-    int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
-    return builder.createIntegerConstant(loc, type, maxInt);
-  } else if (reductionOpName.contains("ior")) {
-    unsigned bits = type.getIntOrFloatBitWidth();
-    int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
-    return builder.createIntegerConstant(loc, type, zeroInt);
-  } else if (reductionOpName.contains("ieor")) {
-    unsigned bits = type.getIntOrFloatBitWidth();
-    int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
-    return builder.createIntegerConstant(loc, type, zeroInt);
-  } else if (reductionOpName.contains("iand")) {
-    unsigned bits = type.getIntOrFloatBitWidth();
-    int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
-    return builder.createIntegerConstant(loc, type, allOnInt);
-  } else {
+    case IntrinsicProc::IEOR: {
+      unsigned bits = type.getIntOrFloatBitWidth();
+      int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, type, zeroInt);
+    }
+    case IntrinsicProc::IAND: {
+      unsigned bits = type.getIntOrFloatBitWidth();
+      int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, type, allOnInt);
+    }
+    }
+    llvm_unreachable("Unknown Reduction Intrinsic");
+  }
+
+  static mlir::Value getIntrinsicOpInitValue(
+      mlir::Location loc, mlir::Type type,
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+      fir::FirOpBuilder &builder) {
     if (type.isa<mlir::FloatType>())
       return builder.create<mlir::arith::ConstantOp>(
           loc, type,
-          builder.getFloatAttr(
-              type, (double)getOperationIdentity(reductionOpName, loc)));
+          builder.getFloatAttr(type,
+                               (double)getOperationIdentity(intrinsicOp, loc)));
 
     if (type.isa<fir::LogicalType>()) {
       mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
           loc, builder.getI1Type(),
           builder.getIntegerAttr(builder.getI1Type(),
-                                 getOperationIdentity(reductionOpName, loc)));
+                                 getOperationIdentity(intrinsicOp, loc)));
       return builder.createConvert(loc, type, intConst);
     }
 
     return builder.create<mlir::arith::ConstantOp>(
         loc, type,
-        builder.getIntegerAttr(type,
-                               getOperationIdentity(reductionOpName, loc)));
-  }
-}
-
-template <typename FloatOp, typename IntegerOp>
-static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
-                                         mlir::Type type, mlir::Location loc,
-                                         mlir::Value op1, mlir::Value op2) {
-  assert(type.isIntOrIndexOrFloat() &&
-         "only integer and float types are currently supported");
-  if (type.isIntOrIndex())
-    return builder.create<IntegerOp>(loc, op1, op2);
-  return builder.create<FloatOp>(loc, op1, op2);
-}
-
-static mlir::omp::ReductionDeclareOp
-createMinimalReductionDecl(fir::FirOpBuilder &builder,
-                           llvm::StringRef reductionOpName, mlir::Type type,
-                           mlir::Location loc) {
-  mlir::ModuleOp module = builder.getModule();
-  mlir::OpBuilder modBuilder(module.getBodyRegion());
-
-  mlir::omp::ReductionDeclareOp decl =
-      modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
-                                                       type);
-  builder.createBlock(&decl.getInitializerRegion(),
-                      decl.getInitializerRegion().end(), {type}, {loc});
-  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
-  mlir::Value init = getReductionInitValue(loc, type, reductionOpName, builder);
-  builder.create<mlir::omp::YieldOp>(loc, init);
-
-  builder.createBlock(&decl.getReductionRegion(),
-                      decl.getReductionRegion().end(), {type, type},
-                      {loc, loc});
-
-  return decl;
-}
-
-/// Creates an OpenMP reduction declaration and inserts it into the provided
-/// symbol table. The declaration has a constant initializer with the neutral
-/// value `initValue`, and the reduction combiner carried over from `reduce`.
-/// TODO: Generalize this for non-integer types, add atomic region.
-static mlir::omp::ReductionDeclareOp
-createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-                    const Fortran::parser::ProcedureDesignator &procDesignator,
-                    mlir::Type type, mlir::Location loc) {
-  mlir::OpBuilder::InsertionGuard guard(builder);
-  mlir::ModuleOp module = builder.getModule();
-
-  auto decl =
-      module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
-  if (decl)
-    return decl;
+        builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc)));
+  }
+
+  template <typename FloatOp, typename IntegerOp>
+  static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+                                           mlir::Type type, mlir::Location loc,
+                                           mlir::Value op1, mlir::Value op2) {
+    assert(type.isIntOrIndexOrFloat() &&
+           "only integer and float types are currently supported");
+    if (type.isIntOrIndex())
+      return builder.create<IntegerOp>(loc, op1, op2);
+    return builder.create<FloatOp>(loc, op1, op2);
+  }
+
+  /// Creates an OpenMP reduction declaration and inserts it into the provided
+  /// symbol table. The declaration has a constant initializer with the neutral
+  /// value `initValue`, and the reduction combiner carried over from `reduce`.
+  /// TODO: Generalize this for non-integer types, add atomic region.
+  static mlir::omp::ReductionDeclareOp createReductionDecl(
+      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+      const Fortran::parser::ProcedureDesignator &procDesignator,
+      mlir::Type type, mlir::Location loc) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    mlir::ModuleOp module = builder.getModule();
+
+    auto decl =
+        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+    if (decl)
+      return decl;
 
-  decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
-  builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
-  mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
-  mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+    mlir::OpBuilder modBuilder(module.getBodyRegion());
 
-  mlir::Value reductionOp;
-  if (const auto *name{
-          Fortran::parser::Unwrap<Fortran::parser::Name>(procDesignator)}) {
-    if (name->source == "max") {
+    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+        loc, reductionOpName, type);
+    builder.createBlock(&decl.getInitializerRegion(),
+                        decl.getInitializerRegion().end(), {type}, {loc});
+    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+    mlir::Value init =
+        getIntrinsicProcInitValue(loc, type, procDesignator, builder);
+    builder.create<mlir::omp::YieldOp>(loc, init);
+
+    builder.createBlock(&decl.getReductionRegion(),
+                        decl.getReductionRegion().end(), {type, type},
+                        {loc, loc});
+
+    builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+    mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+    mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+
+    mlir::Value reductionOp;
+    switch (getReductionType(procDesignator)) {
+    case IntrinsicProc::MAX:
       reductionOp =
           getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
               builder, type, loc, op1, op2);
-    } else if (name->source == "min") {
+      break;
+    case IntrinsicProc::MIN:
       reductionOp =
           getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
               builder, type, loc, op1, op2);
-    } else if (name->source == "ior") {
+      break;
+    case IntrinsicProc::IOR:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
-    } else if (name->source == "ieor") {
+      break;
+    case IntrinsicProc::IEOR:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
-    } else if (name->source == "iand") {
+      break;
+    case IntrinsicProc::IAND:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
-    } else {
-      TODO(loc, "Reduction of some intrinsic operators is not supported");
+      break;
+    default:
+      llvm_unreachable(
+          "Reduction of some intrinsic operators is not supported");
     }
+
+    builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+    return decl;
   }
 
-  builder.create<mlir::omp::YieldOp>(loc, reductionOp);
-  return decl;
-}
+  /// Creates an OpenMP reduction declaration and inserts it into the provided
+  /// symbol table. The declaration has a constant initializer with the neutral
+  /// value `initValue`, and the reduction combiner carried over from `reduce`.
+  /// TODO: Generalize this for non-integer types, add atomic region.
+  static mlir::omp::ReductionDeclareOp createReductionDecl(
+      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+      mlir::Type type, mlir::Location loc) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    mlir::ModuleOp module = builder.getModule();
 
-/// Creates an OpenMP reduction declaration and inserts it into the provided
-/// symbol table. The declaration has a constant initializer with the neutral
-/// value `initValue`, and the reduction combiner carried over from `reduce`.
-/// TODO: Generalize this for non-integer types, add atomic region.
-static mlir::omp::ReductionDeclareOp createReductionDecl(
-    fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-    Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-    mlir::Type type, mlir::Location loc) {
-  mlir::OpBuilder::InsertionGuard guard(builder);
-  mlir::ModuleOp module = builder.getModule();
+    auto decl =
+        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+    if (decl)
+      return decl;
 
-  auto decl =
-      module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
-  if (decl)
-    return decl;
+    mlir::OpBuilder modBuilder(module.getBodyRegion());
 
-  decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
-  builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
-  mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
-  mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+        loc, reductionOpName, type);
+    builder.createBlock(&decl.getInitializerRegion(),
+                        decl.getInitializerRegion().end(), {type}, {loc});
+    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+    mlir::Value init = getIntrinsicOpInitValue(loc, type, intrinsicOp, builder);
+    builder.create<mlir::...
[truncated]

Copy link
Contributor

@NimishMishra NimishMishra left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Copy link
Member

@DavidTruby DavidTruby left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Move reduction lowering code into a ReductionProcessor class.
Create an enumeration for Intrinsic Procedure reductions.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants