Skip to content

[flang] optimize array function calls using hlfir.eval_in_mem #118070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 3, 2024

Conversation

jeanPerier
Copy link
Contributor

This patch encapsulate array function call lowering into hlfir.eval_in_mem and allows directly evaluating the call into the LHS when possible.

The conditions are: LHS is contiguous, not accessed inside the function, it is not a whole allocatable, and the function results needs not to be finalized. All these conditions are tested in the previous hlfir.eval_in_mem optimization patch that is leveraging the extension of getModRef to handle function calls).

This yields a 25% speed-up on polyhedron channel2 benchmark (from 1min to 45s measured on an X86-64 Zen 2).

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Nov 29, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 29, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (jeanPerier)

Changes

This patch encapsulate array function call lowering into hlfir.eval_in_mem and allows directly evaluating the call into the LHS when possible.

The conditions are: LHS is contiguous, not accessed inside the function, it is not a whole allocatable, and the function results needs not to be finalized. All these conditions are tested in the previous hlfir.eval_in_mem optimization patch that is leveraging the extension of getModRef to handle function calls).

This yields a 25% speed-up on polyhedron channel2 benchmark (from 1min to 45s measured on an X86-64 Zen 2).


Patch is 34.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118070.diff

12 Files Affected:

  • (modified) flang/include/flang/Lower/ConvertCall.h (+5-1)
  • (modified) flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h (+4)
  • (modified) flang/lib/Lower/ConvertCall.cpp (+69-33)
  • (modified) flang/lib/Lower/ConvertExpr.cpp (+8-5)
  • (modified) flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp (+15)
  • (modified) flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp (+1-10)
  • (modified) flang/test/HLFIR/order_assignments/where-scheduling.f90 (+1-1)
  • (added) flang/test/Lower/HLFIR/calls-array-results.f90 (+131)
  • (modified) flang/test/Lower/HLFIR/where-nonelemental.f90 (+21-17)
  • (modified) flang/test/Lower/explicit-interface-results-2.f90 (-2)
  • (modified) flang/test/Lower/explicit-interface-results.f90 (+4-4)
  • (modified) flang/test/Lower/forall/array-constructor.f90 (+1-1)
diff --git a/flang/include/flang/Lower/ConvertCall.h b/flang/include/flang/Lower/ConvertCall.h
index bc082907e61760..2c51a887010c87 100644
--- a/flang/include/flang/Lower/ConvertCall.h
+++ b/flang/include/flang/Lower/ConvertCall.h
@@ -24,6 +24,10 @@
 
 namespace Fortran::lower {
 
+struct LoweredResult {
+  std::variant<fir::ExtendedValue, hlfir::EntityWithAttributes> result;
+};
+
 /// Given a call site for which the arguments were already lowered, generate
 /// the call and return the result. This function deals with explicit result
 /// allocation and lowering if needed. It also deals with passing the host
@@ -32,7 +36,7 @@ namespace Fortran::lower {
 /// It is only used for HLFIR.
 /// The returned boolean indicates if finalization has been emitted in
 /// \p stmtCtx for the result.
-std::pair<fir::ExtendedValue, bool> genCallOpAndResult(
+std::pair<LoweredResult, bool> genCallOpAndResult(
     mlir::Location loc, Fortran::lower::AbstractConverter &converter,
     Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
     Fortran::lower::CallerInterface &caller, mlir::FunctionType callSiteType,
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
index 3830237f96f3cc..447d5fbab89998 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -61,6 +61,10 @@ inline mlir::Type getFortranElementOrSequenceType(mlir::Type type) {
   return type;
 }
 
+/// Build the hlfir.expr type for the value held in a variable of type \p
+/// variableType.
+mlir::Type getExprType(mlir::Type variableType);
+
 /// Is this a fir.box or fir.class address type?
 inline bool isBoxAddressType(mlir::Type type) {
   type = fir::dyn_cast_ptrEleTy(type);
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index e84e7afbe82e09..088d8f96caa410 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -284,7 +284,8 @@ static void remapActualToDummyDescriptors(
   }
 }
 
-std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
+std::pair<Fortran::lower::LoweredResult, bool>
+Fortran::lower::genCallOpAndResult(
     mlir::Location loc, Fortran::lower::AbstractConverter &converter,
     Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
     Fortran::lower::CallerInterface &caller, mlir::FunctionType callSiteType,
@@ -326,6 +327,11 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
     }
   }
 
+  const bool isExprCall =
+      converter.getLoweringOptions().getLowerToHighLevelFIR() &&
+      callSiteType.getNumResults() == 1 &&
+      llvm::isa<fir::SequenceType>(callSiteType.getResult(0));
+
   mlir::IndexType idxTy = builder.getIndexType();
   auto lowerSpecExpr = [&](const auto &expr) -> mlir::Value {
     mlir::Value convertExpr = builder.createConvert(
@@ -333,6 +339,8 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
     return fir::factory::genMaxWithZero(builder, loc, convertExpr);
   };
   llvm::SmallVector<mlir::Value> resultLengths;
+  mlir::Value arrayResultShape;
+  hlfir::EvaluateInMemoryOp evaluateInMemory;
   auto allocatedResult = [&]() -> std::optional<fir::ExtendedValue> {
     llvm::SmallVector<mlir::Value> extents;
     llvm::SmallVector<mlir::Value> lengths;
@@ -366,6 +374,18 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
       resultLengths = lengths;
     }
 
+    if (!extents.empty())
+      arrayResultShape = builder.genShape(loc, extents);
+
+    if (isExprCall) {
+      mlir::Type exprType = hlfir::getExprType(type);
+      evaluateInMemory = builder.create<hlfir::EvaluateInMemoryOp>(
+          loc, exprType, arrayResultShape, resultLengths);
+      builder.setInsertionPointToStart(&evaluateInMemory.getBody().front());
+      return toExtendedValue(loc, evaluateInMemory.getMemory(), extents,
+                             lengths);
+    }
+
     if ((!extents.empty() || !lengths.empty()) && !isElemental) {
       // Note: in the elemental context, the alloca ownership inside the
       // elemental region is implicit, and later pass in lowering (stack
@@ -384,8 +404,7 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
   if (mustPopSymMap)
     symMap.popScope();
 
-  // Place allocated result or prepare the fir.save_result arguments.
-  mlir::Value arrayResultShape;
+  // Place allocated result
   if (allocatedResult) {
     if (std::optional<Fortran::lower::CallInterface<
             Fortran::lower::CallerInterface>::PassedEntity>
@@ -399,16 +418,6 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
       else
         fir::emitFatalError(
             loc, "only expect character scalar result to be passed by ref");
-    } else {
-      assert(caller.mustSaveResult());
-      arrayResultShape = allocatedResult->match(
-          [&](const fir::CharArrayBoxValue &) {
-            return builder.createShape(loc, *allocatedResult);
-          },
-          [&](const fir::ArrayBoxValue &) {
-            return builder.createShape(loc, *allocatedResult);
-          },
-          [&](const auto &) { return mlir::Value{}; });
     }
   }
 
@@ -642,6 +651,19 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
       callResult = call.getResult(0);
   }
 
+  std::optional<Fortran::evaluate::DynamicType> retTy =
+      caller.getCallDescription().proc().GetType();
+  // With HLFIR lowering, isElemental must be set to true
+  // if we are producing an elemental call. In this case,
+  // the elemental results must not be destroyed, instead,
+  // the resulting array result will be finalized/destroyed
+  // as needed by hlfir.destroy.
+  const bool mustFinalizeResult =
+      !isElemental && callSiteType.getNumResults() > 0 &&
+      !fir::isPointerType(callSiteType.getResult(0)) && retTy.has_value() &&
+      (retTy->category() == Fortran::common::TypeCategory::Derived ||
+       retTy->IsPolymorphic() || retTy->IsUnlimitedPolymorphic());
+
   if (caller.mustSaveResult()) {
     assert(allocatedResult.has_value());
     builder.create<fir::SaveResultOp>(loc, callResult,
@@ -649,6 +671,19 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
                                       arrayResultShape, resultLengths);
   }
 
+  if (evaluateInMemory) {
+    builder.setInsertionPointAfter(evaluateInMemory);
+    mlir::Value expr = evaluateInMemory.getResult();
+    fir::FirOpBuilder *bldr = &converter.getFirOpBuilder();
+    if (!isElemental)
+      stmtCtx.attachCleanup([bldr, loc, expr, mustFinalizeResult]() {
+        bldr->create<hlfir::DestroyOp>(loc, expr,
+                                       /*finalize=*/mustFinalizeResult);
+      });
+    return {LoweredResult{hlfir::EntityWithAttributes{expr}},
+            mustFinalizeResult};
+  }
+
   if (allocatedResult) {
     // The result must be optionally destroyed (if it is of a derived type
     // that may need finalization or deallocation of the components).
@@ -679,17 +714,7 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
     // derived-type.
     // For polymorphic and unlimited polymorphic enities call the runtime
     // in any cases.
-    std::optional<Fortran::evaluate::DynamicType> retTy =
-        caller.getCallDescription().proc().GetType();
-    // With HLFIR lowering, isElemental must be set to true
-    // if we are producing an elemental call. In this case,
-    // the elemental results must not be destroyed, instead,
-    // the resulting array result will be finalized/destroyed
-    // as needed by hlfir.destroy.
-    if (!isElemental && !fir::isPointerType(funcType.getResults()[0]) &&
-        retTy &&
-        (retTy->category() == Fortran::common::TypeCategory::Derived ||
-         retTy->IsPolymorphic() || retTy->IsUnlimitedPolymorphic())) {
+    if (mustFinalizeResult) {
       if (retTy->IsPolymorphic() || retTy->IsUnlimitedPolymorphic()) {
         auto *bldr = &converter.getFirOpBuilder();
         stmtCtx.attachCleanup([bldr, loc, allocatedResult]() {
@@ -715,12 +740,13 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
         }
       }
     }
-    return {*allocatedResult, resultIsFinalized};
+    return {LoweredResult{*allocatedResult}, resultIsFinalized};
   }
 
   // subroutine call
   if (!resultType)
-    return {fir::ExtendedValue{mlir::Value{}}, /*resultIsFinalized=*/false};
+    return {LoweredResult{fir::ExtendedValue{mlir::Value{}}},
+            /*resultIsFinalized=*/false};
 
   // For now, Fortran return values are implemented with a single MLIR
   // function return value.
@@ -734,10 +760,13 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
         mlir::dyn_cast<fir::CharacterType>(funcType.getResults()[0]);
     mlir::Value len = builder.createIntegerConstant(
         loc, builder.getCharacterLengthType(), charTy.getLen());
-    return {fir::CharBoxValue{callResult, len}, /*resultIsFinalized=*/false};
+    return {
+        LoweredResult{fir::ExtendedValue{fir::CharBoxValue{callResult, len}}},
+        /*resultIsFinalized=*/false};
   }
 
-  return {callResult, /*resultIsFinalized=*/false};
+  return {LoweredResult{fir::ExtendedValue{callResult}},
+          /*resultIsFinalized=*/false};
 }
 
 static hlfir::EntityWithAttributes genStmtFunctionRef(
@@ -1661,19 +1690,26 @@ genUserCall(Fortran::lower::PreparedActualArguments &loweredActuals,
   // Prepare lowered arguments according to the interface
   // and map the lowered values to the dummy
   // arguments.
-  auto [result, resultIsFinalized] = Fortran::lower::genCallOpAndResult(
+  auto [loweredResult, resultIsFinalized] = Fortran::lower::genCallOpAndResult(
       loc, callContext.converter, callContext.symMap, callContext.stmtCtx,
       caller, callSiteType, callContext.resultType,
       callContext.isElementalProcWithArrayArgs());
-  // For procedure pointer function result, just return the call.
-  if (callContext.resultType &&
-      mlir::isa<fir::BoxProcType>(*callContext.resultType))
-    return hlfir::EntityWithAttributes(fir::getBase(result));
 
   /// Clean-up associations and copy-in.
   for (auto cleanUp : callCleanUps)
     cleanUp.genCleanUp(loc, builder);
 
+  if (auto *entity =
+          std::get_if<hlfir::EntityWithAttributes>(&loweredResult.result))
+    return *entity;
+
+  auto &result = std::get<fir::ExtendedValue>(loweredResult.result);
+
+  // For procedure pointer function result, just return the call.
+  if (callContext.resultType &&
+      mlir::isa<fir::BoxProcType>(*callContext.resultType))
+    return hlfir::EntityWithAttributes(fir::getBase(result));
+
   if (!fir::getBase(result))
     return std::nullopt; // subroutine call.
 
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index 46168b81dd3a03..926e45b807e0a0 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -2852,10 +2852,11 @@ class ScalarExprLowering {
       }
     }
 
-    ExtValue result =
+    auto loweredResult =
         Fortran::lower::genCallOpAndResult(loc, converter, symMap, stmtCtx,
                                            caller, callSiteType, resultType)
             .first;
+    auto &result = std::get<ExtValue>(loweredResult.result);
 
     // Sync pointers and allocatables that may have been modified during the
     // call.
@@ -4881,10 +4882,12 @@ class ArrayExprLowering {
             [&](const auto &) { return fir::getBase(exv); });
         caller.placeInput(argIface, arg);
       }
-      return Fortran::lower::genCallOpAndResult(loc, converter, symMap,
-                                                getElementCtx(), caller,
-                                                callSiteType, retTy)
-          .first;
+      Fortran::lower::LoweredResult res =
+          Fortran::lower::genCallOpAndResult(loc, converter, symMap,
+                                             getElementCtx(), caller,
+                                             callSiteType, retTy)
+              .first;
+      return std::get<ExtValue>(res.result);
     };
   }
 
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index 0b61c0edce622b..c66ba75f912fb7 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -215,3 +215,18 @@ bool hlfir::mayHaveAllocatableComponent(mlir::Type ty) {
   return fir::isPolymorphicType(ty) || fir::isUnlimitedPolymorphicType(ty) ||
          fir::isRecordWithAllocatableMember(hlfir::getFortranElementType(ty));
 }
+
+mlir::Type hlfir::getExprType(mlir::Type variableType) {
+  hlfir::ExprType::Shape typeShape;
+  bool isPolymorphic = fir::isPolymorphicType(variableType);
+  mlir::Type type = getFortranElementOrSequenceType(variableType);
+  assert(!fir::isa_trivial(type) &&
+         "numerical and logical scalar should not be wrapped in hlfir.expr");
+  if (auto seqType = mlir::dyn_cast<fir::SequenceType>(type)) {
+    assert(!seqType.hasUnknownShape() && "assumed-rank cannot be expressions");
+    typeShape.append(seqType.getShape().begin(), seqType.getShape().end());
+    type = seqType.getEleTy();
+  }
+  return hlfir::ExprType::get(variableType.getContext(), typeShape, type,
+                              isPolymorphic);
+}
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 87519882446485..3a172d1b8b5400 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -1427,16 +1427,7 @@ llvm::LogicalResult hlfir::EndAssociateOp::verify() {
 void hlfir::AsExprOp::build(mlir::OpBuilder &builder,
                             mlir::OperationState &result, mlir::Value var,
                             mlir::Value mustFree) {
-  hlfir::ExprType::Shape typeShape;
-  bool isPolymorphic = fir::isPolymorphicType(var.getType());
-  mlir::Type type = getFortranElementOrSequenceType(var.getType());
-  if (auto seqType = mlir::dyn_cast<fir::SequenceType>(type)) {
-    typeShape.append(seqType.getShape().begin(), seqType.getShape().end());
-    type = seqType.getEleTy();
-  }
-
-  auto resultType = hlfir::ExprType::get(builder.getContext(), typeShape, type,
-                                         isPolymorphic);
+  mlir::Type resultType = hlfir::getExprType(var.getType());
   return build(builder, result, resultType, var, mustFree);
 }
 
diff --git a/flang/test/HLFIR/order_assignments/where-scheduling.f90 b/flang/test/HLFIR/order_assignments/where-scheduling.f90
index 3010476d4a188f..6feaba0d3389a0 100644
--- a/flang/test/HLFIR/order_assignments/where-scheduling.f90
+++ b/flang/test/HLFIR/order_assignments/where-scheduling.f90
@@ -134,7 +134,7 @@ end function f
 !CHECK-NEXT: run 1 save    : where/mask
 !CHECK-NEXT: run 2 evaluate: where/region_assign1
 !CHECK-LABEL: ------------ scheduling where in _QPonly_once ------------
-!CHECK-NEXT: unknown effect: %{{[0-9]+}} = llvm.intr.stacksave : !llvm.ptr
+!CHECK-NEXT: unknown effect: %11 = fir.call @_QPcall_me_only_once() fastmath<contract> : () -> !fir.array<10x!fir.logical<4>>
 !CHECK-NEXT: saving eval because write effect prevents re-evaluation
 !CHECK-NEXT: run 1 save  (w): where/mask
 !CHECK-NEXT: run 2 evaluate: where/region_assign1
diff --git a/flang/test/Lower/HLFIR/calls-array-results.f90 b/flang/test/Lower/HLFIR/calls-array-results.f90
new file mode 100644
index 00000000000000..d91844cc2e6f87
--- /dev/null
+++ b/flang/test/Lower/HLFIR/calls-array-results.f90
@@ -0,0 +1,131 @@
+! RUN: bbc -emit-hlfir -o - %s -I nowhere | FileCheck %s
+
+subroutine simple_test()
+  implicit none
+  interface
+    function array_func()
+      real :: array_func(10)
+    end function
+  end interface
+  real :: x(10)
+  x = array_func()
+end subroutine
+
+subroutine arg_test(n)
+  implicit none
+  interface
+    function array_func_2(n)
+      integer(8) :: n
+      real :: array_func_2(n)
+    end function
+  end interface
+  integer(8) :: n
+  real :: x(n)
+  x = array_func_2(n)
+end subroutine
+
+module type_defs
+  interface
+    function array_func()
+      real :: array_func(10)
+    end function
+  end interface
+  type t
+    contains
+    procedure, nopass :: array_func => array_func
+  end type
+end module
+
+subroutine dispatch_test(x, a)
+  use type_defs, only : t
+  implicit none
+  real :: x(10)
+  class(t) :: a
+  x = a%array_func()
+end subroutine
+
+! CHECK-LABEL:   func.func @_QPsimple_test() {
+! CHECK:           %[[VAL_0:.*]] = arith.constant 10 : index
+! CHECK:           %[[VAL_1:.*]] = fir.alloca !fir.array<10xf32> {bindc_name = "x", uniq_name = "_QFsimple_testEx"}
+! CHECK:           %[[VAL_2:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_2]]) {uniq_name = "_QFsimple_testEx"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
+! CHECK:           %[[VAL_4:.*]] = arith.constant 10 : i64
+! CHECK:           %[[VAL_5:.*]] = arith.constant 1 : i64
+! CHECK:           %[[VAL_6:.*]] = arith.subi %[[VAL_4]], %[[VAL_5]] : i64
+! CHECK:           %[[VAL_7:.*]] = arith.constant 1 : i64
+! CHECK:           %[[VAL_8:.*]] = arith.addi %[[VAL_6]], %[[VAL_7]] : i64
+! CHECK:           %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (i64) -> index
+! CHECK:           %[[VAL_10:.*]] = arith.constant 0 : index
+! CHECK:           %[[VAL_11:.*]] = arith.cmpi sgt, %[[VAL_9]], %[[VAL_10]] : index
+! CHECK:           %[[VAL_12:.*]] = arith.select %[[VAL_11]], %[[VAL_9]], %[[VAL_10]] : index
+! CHECK:           %[[VAL_13:.*]] = fir.shape %[[VAL_12]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_14:.*]] = hlfir.eval_in_mem shape %[[VAL_13]] : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+! CHECK:           ^bb0(%[[VAL_15:.*]]: !fir.ref<!fir.array<10xf32>>):
+! CHECK:             %[[VAL_16:.*]] = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+! CHECK:             fir.save_result %[[VAL_16]] to %[[VAL_15]](%[[VAL_13]]) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+! CHECK:           }
+! CHECK:           hlfir.assign %[[VAL_14]] to %[[VAL_3]]#0 : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+! CHECK:           hlfir.destroy %[[VAL_14]] : !hlfir.expr<10xf32>
+! CHECK:           return
+! CHECK:         }
+
+! CHECK-LABEL:   func.func @_QParg_test(
+! CHECK-SAME:                           %[[VAL_0:.*]]: !fir.ref<i64> {fir.bindc_name = "n"}) {
+! CHECK:           %[[VAL_1:.*]] = fir.dummy_scope : !fir.dscope
+! CHECK:           %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_1]] {uniq_name = "_QFarg_testEn"} : (!fir.ref<i64>, !fir.dscope) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK:           %[[VAL_3:.*]] = fir.load %[[VAL_2]]#0 : !fir.ref<i64>
+! CHECK:           %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (i64) -> index
+! CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
+! CHECK:           %[[VAL_6:.*]] = arith.cmpi sgt, %[[VAL_4]], %[[VAL_5]] : index
+! CHECK:           %[[VAL_7:.*]] = arith.select %[[VAL_6]], %[[VAL_4]], %[[VAL_5]] : index
+! CHECK:           %[[VAL_8:.*]] = fir.alloca !fir.array<?xf32>, %[[VAL_7]] {bindc_name = "x", uniq_name = "_QFarg_testEx"}
+! CHECK:           %[[VAL_9:.*]] = fir.shape %[[VAL_7]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_8]](%[[VAL_9]]) {uniq_name = "_QFarg_testEx"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+! CHECK:           %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_2]]#1 {uniq_name = "_QFarg_testFarray_func_2En"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK:           %[[VAL_12:.*]] = fir.load %[[VAL_11]]#0 : !fir.ref<i64>
+! CHECK:           %[[VAL_13:.*]] = arith.constant 1 : i64
+! CHECK:           %[[VAL_14:.*]] = arith.subi %[[VAL_12]], %[[VAL_13]] : i64
+! CHECK:           %[[VAL_15:.*]] = arith.constant 1 : i64
+! CHECK:           %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : i64
+! CHECK:           %[[VAL_17:.*]] = fir.convert %[[VAL_16]] : (i64) -> index
+! CHECK:           %[[VAL_18:.*]] = arith.constant 0 : index
+! CHECK:           %[[VAL_19:.*]] = arith.cmpi sgt, %...
[truncated]

@jeanPerier
Copy link
Contributor Author

Note that this is a "classic" Fortran compiler optimization, you can verify that at least all of gfortran, ifx, and nvfortran are doing it on something very easy like:

subroutine test
interface
 function foo()
    real :: foo(100)
 end function
end interface
   real x(100)
   x = foo()
   call bar(x)
end subroutine

jeanPerier added a commit that referenced this pull request Dec 2, 2024
See HLFIROps.td change for the description of the operation.

The goal is to ease temporary storage elision for expression evaluation
(typically evaluating the RHS directly inside the LHS) for expressions
that do not have abtsractions in HLFIR and for which it is not clear
adding one would bring much. The case that is implemented in [the
following lowering
patch](#118070) is the array
call case, where adding a new hlfir.call would add complexity (needs to
deal with dispatch, inlining ....).
Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with the one nit

I added an assert that prevents ExprType to be "trivial types" (i32, f32, ...)
but that actually breaks existing lowering code that generate as_expr + associate
patterns to create temporaries:

https://github.com/llvm/llvm-project/blob/03b5f8f0f0d10c412842ed04b90e2217cf071218/flang/lib/Lower/ConvertCall.cpp#L1271

```
subroutine trivial_as_expr(i)
  integer, optional :: i
  interface
    subroutine foo(j)
      integer, optional, value :: j
    end subroutine
  end interface
  call foo(i)
end subroutine
```

This is unrelated to this patch, so leave it like this (it is not even a
functional problem).
@jeanPerier jeanPerier requested a review from vzakhari December 2, 2024 17:39
Base automatically changed from users/jperier/opt_eval_in_mem to main December 3, 2024 08:59
@jeanPerier
Copy link
Contributor Author

Thanks for all the reviews Tom!

@jeanPerier jeanPerier merged commit cd7e653 into main Dec 3, 2024
8 checks passed
@jeanPerier jeanPerier deleted the users/jperier/array_call_opt branch December 3, 2024 09:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants