-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[Flang][OpenMP] NFC: Minor refactoring of Reduction lowering code #70790
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Flang][OpenMP] NFC: Minor refactoring of Reduction lowering code #70790
Conversation
@llvm/pr-subscribers-flang-openmp @llvm/pr-subscribers-flang-fir-hlfir Author: Kiran Chandramohan (kiranchandramohan) ChangesMove reduction lowering code into a ReductionProcessor class. Create an enumeration for Intrinsic Procedure reductions. Patch is 39.63 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70790.diff 1 Files Affected:
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 0faaae6c08e0476..a804a2ce1780e7d 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -667,276 +667,441 @@ static void checkMapType(mlir::Location location, mlir::Type type) {
TODO(location, "OMPD_target_data MapOperand BoxType");
}
-static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
- return (llvm::Twine(name) +
- (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
- llvm::Twine(ty.getIntOrFloatBitWidth()))
- .str();
-}
+class ReductionProcessor {
+public:
+ enum IntrinsicProc { MAX, MIN, IAND, IOR, IEOR };
+ static IntrinsicProc
+ getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
+ auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+ getRealName(pd).ToString())
+ .Case("max", IntrinsicProc::MAX)
+ .Case("min", IntrinsicProc::MIN)
+ .Case("iand", IntrinsicProc::IAND)
+ .Case("ior", IntrinsicProc::IOR)
+ .Case("ieor", IntrinsicProc::IEOR)
+ .Default(std::nullopt);
+ assert(redType && "Invalid Reduction");
+ return *redType;
+ }
+
+ static bool supportedIntrinsicProcReduction(
+ const Fortran::parser::ProcedureDesignator &pd) {
+ const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
+ assert(name && "Invalid Reduction Intrinsic.");
+ auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+ getRealName(name).ToString())
+ .Case("max", IntrinsicProc::MAX)
+ .Case("min", IntrinsicProc::MIN)
+ .Case("iand", IntrinsicProc::IAND)
+ .Case("ior", IntrinsicProc::IOR)
+ .Case("ieor", IntrinsicProc::IEOR)
+ .Default(std::nullopt);
+ if (redType)
+ return true;
+ return false;
+ }
+
+ static const Fortran::semantics::SourceName
+ getRealName(const Fortran::parser::Name *name) {
+ return name->symbol->GetUltimate().name();
+ }
-static std::string getReductionName(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type ty) {
- std::string reductionName;
+ static const Fortran::semantics::SourceName
+ getRealName(const Fortran::parser::ProcedureDesignator &pd) {
+ const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
+ assert(name && "Invalid Reduction Intrinsic.");
+ return getRealName(name);
+ }
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- reductionName = "add_reduction";
- break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- reductionName = "multiply_reduction";
- break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
- return "and_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
- return "eqv_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
- return "or_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
- return "neqv_reduction";
- default:
- reductionName = "other_reduction";
- break;
+ static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
+ return (llvm::Twine(name) +
+ (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
+ llvm::Twine(ty.getIntOrFloatBitWidth()))
+ .str();
+ }
+
+ static std::string getReductionName(
+ Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+ mlir::Type ty) {
+ std::string reductionName;
+
+ switch (intrinsicOp) {
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ reductionName = "add_reduction";
+ break;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ reductionName = "multiply_reduction";
+ break;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ return "and_reduction";
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ return "eqv_reduction";
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ return "or_reduction";
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ return "neqv_reduction";
+ default:
+ reductionName = "other_reduction";
+ break;
+ }
+
+ return getReductionName(reductionName, ty);
+ }
+
+ /// This function returns the identity value of the operator \p
+ /// reductionOpName. For example:
+ /// 0 + x = x,
+ /// 1 * x = x
+ static int getOperationIdentity(
+ Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+ mlir::Location loc) {
+ switch (intrinsicOp) {
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ return 0;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ return 1;
+ default:
+ TODO(loc, "Reduction of some intrinsic operators is not supported");
+ }
}
- return getReductionName(reductionName, ty);
-}
-
-/// This function returns the identity value of the operator \p reductionOpName.
-/// For example:
-/// 0 + x = x,
-/// 1 * x = x
-static int getOperationIdentity(llvm::StringRef reductionOpName,
- mlir::Location loc) {
- if (reductionOpName.contains("add") || reductionOpName.contains("or") ||
- reductionOpName.contains("neqv"))
- return 0;
- if (reductionOpName.contains("multiply") || reductionOpName.contains("and") ||
- reductionOpName.contains("eqv"))
- return 1;
- TODO(loc, "Reduction of some intrinsic operators is not supported");
-}
-
-static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
- llvm::StringRef reductionOpName,
- fir::FirOpBuilder &builder) {
- assert((fir::isa_integer(type) || fir::isa_real(type) ||
- type.isa<fir::LogicalType>()) &&
- "only integer, logical and real types are currently supported");
- if (reductionOpName.contains("max")) {
- if (auto ty = type.dyn_cast<mlir::FloatType>()) {
- const llvm::fltSemantics &sem = ty.getFloatSemantics();
- return builder.createRealConstant(
- loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
+ static mlir::Value getIntrinsicProcInitValue(
+ mlir::Location loc, mlir::Type type,
+ const Fortran::parser::ProcedureDesignator &procDesignator,
+ fir::FirOpBuilder &builder) {
+ assert((fir::isa_integer(type) || fir::isa_real(type) ||
+ type.isa<fir::LogicalType>()) &&
+ "only integer, logical and real types are currently supported");
+ switch (getReductionType(procDesignator)) {
+ case IntrinsicProc::MAX: {
+ if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+ const llvm::fltSemantics &sem = ty.getFloatSemantics();
+ return builder.createRealConstant(
+ loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
+ }
+ unsigned bits = type.getIntOrFloatBitWidth();
+ int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, type, minInt);
+ }
+ case IntrinsicProc::MIN: {
+ if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+ const llvm::fltSemantics &sem = ty.getFloatSemantics();
+ return builder.createRealConstant(
+ loc, type, llvm::APFloat::getSmallest(sem, /*Negative=*/true));
+ }
+ unsigned bits = type.getIntOrFloatBitWidth();
+ int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, type, maxInt);
}
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, minInt);
- } else if (reductionOpName.contains("min")) {
- if (auto ty = type.dyn_cast<mlir::FloatType>()) {
- const llvm::fltSemantics &sem = ty.getFloatSemantics();
- return builder.createRealConstant(
- loc, type, llvm::APFloat::getSmallest(sem, /*Negative=*/true));
+ case IntrinsicProc::IOR: {
+ unsigned bits = type.getIntOrFloatBitWidth();
+ int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, type, zeroInt);
}
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, maxInt);
- } else if (reductionOpName.contains("ior")) {
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, zeroInt);
- } else if (reductionOpName.contains("ieor")) {
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, zeroInt);
- } else if (reductionOpName.contains("iand")) {
- unsigned bits = type.getIntOrFloatBitWidth();
- int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
- return builder.createIntegerConstant(loc, type, allOnInt);
- } else {
+ case IntrinsicProc::IEOR: {
+ unsigned bits = type.getIntOrFloatBitWidth();
+ int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, type, zeroInt);
+ }
+ case IntrinsicProc::IAND: {
+ unsigned bits = type.getIntOrFloatBitWidth();
+ int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
+ return builder.createIntegerConstant(loc, type, allOnInt);
+ }
+ }
+ llvm_unreachable("Unknown Reduction Intrinsic");
+ }
+
+ static mlir::Value getIntrinsicOpInitValue(
+ mlir::Location loc, mlir::Type type,
+ Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+ fir::FirOpBuilder &builder) {
if (type.isa<mlir::FloatType>())
return builder.create<mlir::arith::ConstantOp>(
loc, type,
- builder.getFloatAttr(
- type, (double)getOperationIdentity(reductionOpName, loc)));
+ builder.getFloatAttr(type,
+ (double)getOperationIdentity(intrinsicOp, loc)));
if (type.isa<fir::LogicalType>()) {
mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
loc, builder.getI1Type(),
builder.getIntegerAttr(builder.getI1Type(),
- getOperationIdentity(reductionOpName, loc)));
+ getOperationIdentity(intrinsicOp, loc)));
return builder.createConvert(loc, type, intConst);
}
return builder.create<mlir::arith::ConstantOp>(
loc, type,
- builder.getIntegerAttr(type,
- getOperationIdentity(reductionOpName, loc)));
- }
-}
-
-template <typename FloatOp, typename IntegerOp>
-static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
- mlir::Type type, mlir::Location loc,
- mlir::Value op1, mlir::Value op2) {
- assert(type.isIntOrIndexOrFloat() &&
- "only integer and float types are currently supported");
- if (type.isIntOrIndex())
- return builder.create<IntegerOp>(loc, op1, op2);
- return builder.create<FloatOp>(loc, op1, op2);
-}
-
-static mlir::omp::ReductionDeclareOp
-createMinimalReductionDecl(fir::FirOpBuilder &builder,
- llvm::StringRef reductionOpName, mlir::Type type,
- mlir::Location loc) {
- mlir::ModuleOp module = builder.getModule();
- mlir::OpBuilder modBuilder(module.getBodyRegion());
-
- mlir::omp::ReductionDeclareOp decl =
- modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
- type);
- builder.createBlock(&decl.getInitializerRegion(),
- decl.getInitializerRegion().end(), {type}, {loc});
- builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
- mlir::Value init = getReductionInitValue(loc, type, reductionOpName, builder);
- builder.create<mlir::omp::YieldOp>(loc, init);
-
- builder.createBlock(&decl.getReductionRegion(),
- decl.getReductionRegion().end(), {type, type},
- {loc, loc});
-
- return decl;
-}
-
-/// Creates an OpenMP reduction declaration and inserts it into the provided
-/// symbol table. The declaration has a constant initializer with the neutral
-/// value `initValue`, and the reduction combiner carried over from `reduce`.
-/// TODO: Generalize this for non-integer types, add atomic region.
-static mlir::omp::ReductionDeclareOp
-createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
- const Fortran::parser::ProcedureDesignator &procDesignator,
- mlir::Type type, mlir::Location loc) {
- mlir::OpBuilder::InsertionGuard guard(builder);
- mlir::ModuleOp module = builder.getModule();
-
- auto decl =
- module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
- if (decl)
- return decl;
+ builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc)));
+ }
+
+ template <typename FloatOp, typename IntegerOp>
+ static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+ mlir::Type type, mlir::Location loc,
+ mlir::Value op1, mlir::Value op2) {
+ assert(type.isIntOrIndexOrFloat() &&
+ "only integer and float types are currently supported");
+ if (type.isIntOrIndex())
+ return builder.create<IntegerOp>(loc, op1, op2);
+ return builder.create<FloatOp>(loc, op1, op2);
+ }
+
+ /// Creates an OpenMP reduction declaration and inserts it into the provided
+ /// symbol table. The declaration has a constant initializer with the neutral
+ /// value `initValue`, and the reduction combiner carried over from `reduce`.
+ /// TODO: Generalize this for non-integer types, add atomic region.
+ static mlir::omp::ReductionDeclareOp createReductionDecl(
+ fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+ const Fortran::parser::ProcedureDesignator &procDesignator,
+ mlir::Type type, mlir::Location loc) {
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ mlir::ModuleOp module = builder.getModule();
+
+ auto decl =
+ module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+ if (decl)
+ return decl;
- decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
- builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
- mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
- mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+ mlir::OpBuilder modBuilder(module.getBodyRegion());
- mlir::Value reductionOp;
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(procDesignator)}) {
- if (name->source == "max") {
+ decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+ loc, reductionOpName, type);
+ builder.createBlock(&decl.getInitializerRegion(),
+ decl.getInitializerRegion().end(), {type}, {loc});
+ builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+ mlir::Value init =
+ getIntrinsicProcInitValue(loc, type, procDesignator, builder);
+ builder.create<mlir::omp::YieldOp>(loc, init);
+
+ builder.createBlock(&decl.getReductionRegion(),
+ decl.getReductionRegion().end(), {type, type},
+ {loc, loc});
+
+ builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+ mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+ mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+
+ mlir::Value reductionOp;
+ switch (getReductionType(procDesignator)) {
+ case IntrinsicProc::MAX:
reductionOp =
getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
builder, type, loc, op1, op2);
- } else if (name->source == "min") {
+ break;
+ case IntrinsicProc::MIN:
reductionOp =
getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
builder, type, loc, op1, op2);
- } else if (name->source == "ior") {
+ break;
+ case IntrinsicProc::IOR:
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
- } else if (name->source == "ieor") {
+ break;
+ case IntrinsicProc::IEOR:
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
- } else if (name->source == "iand") {
+ break;
+ case IntrinsicProc::IAND:
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
- } else {
- TODO(loc, "Reduction of some intrinsic operators is not supported");
+ break;
+ default:
+ llvm_unreachable(
+ "Reduction of some intrinsic operators is not supported");
}
+
+ builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+ return decl;
}
- builder.create<mlir::omp::YieldOp>(loc, reductionOp);
- return decl;
-}
+ /// Creates an OpenMP reduction declaration and inserts it into the provided
+ /// symbol table. The declaration has a constant initializer with the neutral
+ /// value `initValue`, and the reduction combiner carried over from `reduce`.
+ /// TODO: Generalize this for non-integer types, add atomic region.
+ static mlir::omp::ReductionDeclareOp createReductionDecl(
+ fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+ Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+ mlir::Type type, mlir::Location loc) {
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ mlir::ModuleOp module = builder.getModule();
-/// Creates an OpenMP reduction declaration and inserts it into the provided
-/// symbol table. The declaration has a constant initializer with the neutral
-/// value `initValue`, and the reduction combiner carried over from `reduce`.
-/// TODO: Generalize this for non-integer types, add atomic region.
-static mlir::omp::ReductionDeclareOp createReductionDecl(
- fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type type, mlir::Location loc) {
- mlir::OpBuilder::InsertionGuard guard(builder);
- mlir::ModuleOp module = builder.getModule();
+ auto decl =
+ module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+ if (decl)
+ return decl;
- auto decl =
- module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
- if (decl)
- return decl;
+ mlir::OpBuilder modBuilder(module.getBodyRegion());
- decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
- builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
- mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
- mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+ decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+ loc, reductionOpName, type);
+ builder.createBlock(&decl.getInitializerRegion(),
+ decl.getInitializerRegion().end(), {type}, {loc});
+ builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+ mlir::Value init = getIntrinsicOpInitValue(loc, type, intrinsicOp, builder);
+ builder.create<mlir::...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Move reduction lowering code into a ReductionProcessor class. Create an enumeration for Intrinsic Procedure reductions.
3911841
to
d0f6d3e
Compare
…code" (#73139) Reverts #70790 to fix CI failure (https://lab.llvm.org/buildbot/#/builders/268/builds/2884)
Move reduction lowering code into a ReductionProcessor class. Create an enumeration for Intrinsic Procedure reductions.