Skip to content

[flang] optimize array function calls using hlfir.eval_in_mem #118070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion flang/include/flang/Lower/ConvertCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@

namespace Fortran::lower {

/// Data structure packaging the SSA value(s) produced for the result of lowered
/// function calls.
using LoweredResult =
std::variant<fir::ExtendedValue, hlfir::EntityWithAttributes>;

/// Given a call site for which the arguments were already lowered, generate
/// the call and return the result. This function deals with explicit result
/// allocation and lowering if needed. It also deals with passing the host
Expand All @@ -32,7 +37,7 @@ namespace Fortran::lower {
/// It is only used for HLFIR.
/// The returned boolean indicates if finalization has been emitted in
/// \p stmtCtx for the result.
std::pair<fir::ExtendedValue, bool> genCallOpAndResult(
std::pair<LoweredResult, bool> genCallOpAndResult(
mlir::Location loc, Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
Fortran::lower::CallerInterface &caller, mlir::FunctionType callSiteType,
Expand Down
6 changes: 6 additions & 0 deletions flang/include/flang/Optimizer/Analysis/AliasAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ struct AliasAnalysis {
/// Return the modify-reference behavior of `op` on `location`.
mlir::ModRefResult getModRef(mlir::Operation *op, mlir::Value location);

/// Return the modify-reference behavior of operations inside `region` on
/// `location`. Contrary to getModRef(operation, location), this will visit
/// nested regions recursively according to the HasRecursiveMemoryEffects
/// trait.
mlir::ModRefResult getModRef(mlir::Region &region, mlir::Value location);

/// Return the memory source of a value.
/// If getLastInstantiationPoint is true, the search for the source
/// will stop at [hl]fir.declare if it represents a dummy
Expand Down
4 changes: 4 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ inline mlir::Type getFortranElementOrSequenceType(mlir::Type type) {
return type;
}

/// Build the hlfir.expr type for the value held in a variable of type \p
/// variableType.
mlir::Type getExprType(mlir::Type variableType);

/// Is this a fir.box or fir.class address type?
inline bool isBoxAddressType(mlir::Type type) {
type = fir::dyn_cast_ptrEleTy(type);
Expand Down
101 changes: 68 additions & 33 deletions flang/lib/Lower/ConvertCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,8 @@ static void remapActualToDummyDescriptors(
}
}

std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
std::pair<Fortran::lower::LoweredResult, bool>
Fortran::lower::genCallOpAndResult(
mlir::Location loc, Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
Fortran::lower::CallerInterface &caller, mlir::FunctionType callSiteType,
Expand Down Expand Up @@ -326,13 +327,20 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
}
}

const bool isExprCall =
converter.getLoweringOptions().getLowerToHighLevelFIR() &&
callSiteType.getNumResults() == 1 &&
llvm::isa<fir::SequenceType>(callSiteType.getResult(0));

mlir::IndexType idxTy = builder.getIndexType();
auto lowerSpecExpr = [&](const auto &expr) -> mlir::Value {
mlir::Value convertExpr = builder.createConvert(
loc, idxTy, fir::getBase(converter.genExprValue(expr, stmtCtx)));
return fir::factory::genMaxWithZero(builder, loc, convertExpr);
};
llvm::SmallVector<mlir::Value> resultLengths;
mlir::Value arrayResultShape;
hlfir::EvaluateInMemoryOp evaluateInMemory;
auto allocatedResult = [&]() -> std::optional<fir::ExtendedValue> {
llvm::SmallVector<mlir::Value> extents;
llvm::SmallVector<mlir::Value> lengths;
Expand Down Expand Up @@ -366,6 +374,18 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
resultLengths = lengths;
}

if (!extents.empty())
arrayResultShape = builder.genShape(loc, extents);

if (isExprCall) {
mlir::Type exprType = hlfir::getExprType(type);
evaluateInMemory = builder.create<hlfir::EvaluateInMemoryOp>(
loc, exprType, arrayResultShape, resultLengths);
builder.setInsertionPointToStart(&evaluateInMemory.getBody().front());
return toExtendedValue(loc, evaluateInMemory.getMemory(), extents,
lengths);
}

if ((!extents.empty() || !lengths.empty()) && !isElemental) {
// Note: in the elemental context, the alloca ownership inside the
// elemental region is implicit, and later pass in lowering (stack
Expand All @@ -384,8 +404,7 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
if (mustPopSymMap)
symMap.popScope();

// Place allocated result or prepare the fir.save_result arguments.
mlir::Value arrayResultShape;
// Place allocated result
if (allocatedResult) {
if (std::optional<Fortran::lower::CallInterface<
Fortran::lower::CallerInterface>::PassedEntity>
Expand All @@ -399,16 +418,6 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
else
fir::emitFatalError(
loc, "only expect character scalar result to be passed by ref");
} else {
assert(caller.mustSaveResult());
arrayResultShape = allocatedResult->match(
[&](const fir::CharArrayBoxValue &) {
return builder.createShape(loc, *allocatedResult);
},
[&](const fir::ArrayBoxValue &) {
return builder.createShape(loc, *allocatedResult);
},
[&](const auto &) { return mlir::Value{}; });
}
}

Expand Down Expand Up @@ -642,13 +651,39 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
callResult = call.getResult(0);
}

std::optional<Fortran::evaluate::DynamicType> retTy =
caller.getCallDescription().proc().GetType();
// With HLFIR lowering, isElemental must be set to true
// if we are producing an elemental call. In this case,
// the elemental results must not be destroyed, instead,
// the resulting array result will be finalized/destroyed
// as needed by hlfir.destroy.
const bool mustFinalizeResult =
!isElemental && callSiteType.getNumResults() > 0 &&
!fir::isPointerType(callSiteType.getResult(0)) && retTy.has_value() &&
(retTy->category() == Fortran::common::TypeCategory::Derived ||
retTy->IsPolymorphic() || retTy->IsUnlimitedPolymorphic());

if (caller.mustSaveResult()) {
assert(allocatedResult.has_value());
builder.create<fir::SaveResultOp>(loc, callResult,
fir::getBase(*allocatedResult),
arrayResultShape, resultLengths);
}

if (evaluateInMemory) {
builder.setInsertionPointAfter(evaluateInMemory);
mlir::Value expr = evaluateInMemory.getResult();
fir::FirOpBuilder *bldr = &converter.getFirOpBuilder();
if (!isElemental)
stmtCtx.attachCleanup([bldr, loc, expr, mustFinalizeResult]() {
bldr->create<hlfir::DestroyOp>(loc, expr,
/*finalize=*/mustFinalizeResult);
});
return {LoweredResult{hlfir::EntityWithAttributes{expr}},
mustFinalizeResult};
}

if (allocatedResult) {
// The result must be optionally destroyed (if it is of a derived type
// that may need finalization or deallocation of the components).
Expand Down Expand Up @@ -679,17 +714,7 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
// derived-type.
// For polymorphic and unlimited polymorphic enities call the runtime
// in any cases.
std::optional<Fortran::evaluate::DynamicType> retTy =
caller.getCallDescription().proc().GetType();
// With HLFIR lowering, isElemental must be set to true
// if we are producing an elemental call. In this case,
// the elemental results must not be destroyed, instead,
// the resulting array result will be finalized/destroyed
// as needed by hlfir.destroy.
if (!isElemental && !fir::isPointerType(funcType.getResults()[0]) &&
retTy &&
(retTy->category() == Fortran::common::TypeCategory::Derived ||
retTy->IsPolymorphic() || retTy->IsUnlimitedPolymorphic())) {
if (mustFinalizeResult) {
if (retTy->IsPolymorphic() || retTy->IsUnlimitedPolymorphic()) {
auto *bldr = &converter.getFirOpBuilder();
stmtCtx.attachCleanup([bldr, loc, allocatedResult]() {
Expand All @@ -715,12 +740,13 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
}
}
}
return {*allocatedResult, resultIsFinalized};
return {LoweredResult{*allocatedResult}, resultIsFinalized};
}

// subroutine call
if (!resultType)
return {fir::ExtendedValue{mlir::Value{}}, /*resultIsFinalized=*/false};
return {LoweredResult{fir::ExtendedValue{mlir::Value{}}},
/*resultIsFinalized=*/false};

// For now, Fortran return values are implemented with a single MLIR
// function return value.
Expand All @@ -734,10 +760,13 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
mlir::dyn_cast<fir::CharacterType>(funcType.getResults()[0]);
mlir::Value len = builder.createIntegerConstant(
loc, builder.getCharacterLengthType(), charTy.getLen());
return {fir::CharBoxValue{callResult, len}, /*resultIsFinalized=*/false};
return {
LoweredResult{fir::ExtendedValue{fir::CharBoxValue{callResult, len}}},
/*resultIsFinalized=*/false};
}

return {callResult, /*resultIsFinalized=*/false};
return {LoweredResult{fir::ExtendedValue{callResult}},
/*resultIsFinalized=*/false};
}

static hlfir::EntityWithAttributes genStmtFunctionRef(
Expand Down Expand Up @@ -1661,19 +1690,25 @@ genUserCall(Fortran::lower::PreparedActualArguments &loweredActuals,
// Prepare lowered arguments according to the interface
// and map the lowered values to the dummy
// arguments.
auto [result, resultIsFinalized] = Fortran::lower::genCallOpAndResult(
auto [loweredResult, resultIsFinalized] = Fortran::lower::genCallOpAndResult(
loc, callContext.converter, callContext.symMap, callContext.stmtCtx,
caller, callSiteType, callContext.resultType,
callContext.isElementalProcWithArrayArgs());
// For procedure pointer function result, just return the call.
if (callContext.resultType &&
mlir::isa<fir::BoxProcType>(*callContext.resultType))
return hlfir::EntityWithAttributes(fir::getBase(result));

/// Clean-up associations and copy-in.
for (auto cleanUp : callCleanUps)
cleanUp.genCleanUp(loc, builder);

if (auto *entity = std::get_if<hlfir::EntityWithAttributes>(&loweredResult))
return *entity;

auto &result = std::get<fir::ExtendedValue>(loweredResult);

// For procedure pointer function result, just return the call.
if (callContext.resultType &&
mlir::isa<fir::BoxProcType>(*callContext.resultType))
return hlfir::EntityWithAttributes(fir::getBase(result));

if (!fir::getBase(result))
return std::nullopt; // subroutine call.

Expand Down
13 changes: 8 additions & 5 deletions flang/lib/Lower/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2852,10 +2852,11 @@ class ScalarExprLowering {
}
}

ExtValue result =
auto loweredResult =
Fortran::lower::genCallOpAndResult(loc, converter, symMap, stmtCtx,
caller, callSiteType, resultType)
.first;
auto &result = std::get<ExtValue>(loweredResult);

// Sync pointers and allocatables that may have been modified during the
// call.
Expand Down Expand Up @@ -4881,10 +4882,12 @@ class ArrayExprLowering {
[&](const auto &) { return fir::getBase(exv); });
caller.placeInput(argIface, arg);
}
return Fortran::lower::genCallOpAndResult(loc, converter, symMap,
getElementCtx(), caller,
callSiteType, retTy)
.first;
Fortran::lower::LoweredResult res =
Fortran::lower::genCallOpAndResult(loc, converter, symMap,
getElementCtx(), caller,
callSiteType, retTy)
.first;
return std::get<ExtValue>(res);
};
}

Expand Down
41 changes: 40 additions & 1 deletion flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,13 @@ bool AliasAnalysis::Source::isDummyArgument() const {
return false;
}

static bool isEvaluateInMemoryBlockArg(mlir::Value v) {
if (auto evalInMem = llvm::dyn_cast_or_null<hlfir::EvaluateInMemoryOp>(
v.getParentRegion()->getParentOp()))
return evalInMem.getMemory() == v;
return false;
}

bool AliasAnalysis::Source::isData() const { return origin.isData; }
bool AliasAnalysis::Source::isBoxData() const {
return mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(valueType)) &&
Expand Down Expand Up @@ -457,6 +464,33 @@ ModRefResult AliasAnalysis::getModRef(Operation *op, Value location) {
return result;
}

ModRefResult AliasAnalysis::getModRef(mlir::Region &region,
mlir::Value location) {
ModRefResult result = ModRefResult::getNoModRef();
for (mlir::Operation &op : region.getOps()) {
if (op.hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) {
for (mlir::Region &subRegion : op.getRegions()) {
result = result.merge(getModRef(subRegion, location));
// Fast return is already mod and ref.
if (result.isModAndRef())
return result;
}
// In MLIR, RecursiveMemoryEffects can be combined with
// MemoryEffectOpInterface to describe extra effects on top of the
// effects of the nested operations. However, the presence of
// RecursiveMemoryEffects and the absence of MemoryEffectOpInterface
// implies the operation has no other memory effects than the one of its
// nested operations.
if (!mlir::isa<mlir::MemoryEffectOpInterface>(op))
continue;
}
result = result.merge(getModRef(&op, location));
if (result.isModAndRef())
return result;
}
return result;
}

AliasAnalysis::Source::Attributes
getAttrsFromVariable(fir::FortranVariableOpInterface var) {
AliasAnalysis::Source::Attributes attrs;
Expand Down Expand Up @@ -698,7 +732,7 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,
breakFromLoop = true;
});
}
if (!defOp && type == SourceKind::Unknown)
if (!defOp && type == SourceKind::Unknown) {
// Check if the memory source is coming through a dummy argument.
if (isDummyArgument(v)) {
type = SourceKind::Argument;
Expand All @@ -708,7 +742,12 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,

if (isPointerReference(ty))
attributes.set(Attribute::Pointer);
} else if (isEvaluateInMemoryBlockArg(v)) {
// hlfir.eval_in_mem block operands is allocated by the operation.
type = SourceKind::Allocate;
ty = v.getType();
}
}

if (type == SourceKind::Global) {
return {{global, instantiationPoint, followingData},
Expand Down
13 changes: 13 additions & 0 deletions flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,16 @@ bool hlfir::mayHaveAllocatableComponent(mlir::Type ty) {
return fir::isPolymorphicType(ty) || fir::isUnlimitedPolymorphicType(ty) ||
fir::isRecordWithAllocatableMember(hlfir::getFortranElementType(ty));
}

mlir::Type hlfir::getExprType(mlir::Type variableType) {
hlfir::ExprType::Shape typeShape;
bool isPolymorphic = fir::isPolymorphicType(variableType);
mlir::Type type = getFortranElementOrSequenceType(variableType);
if (auto seqType = mlir::dyn_cast<fir::SequenceType>(type)) {
assert(!seqType.hasUnknownShape() && "assumed-rank cannot be expressions");
typeShape.append(seqType.getShape().begin(), seqType.getShape().end());
type = seqType.getEleTy();
}
return hlfir::ExprType::get(variableType.getContext(), typeShape, type,
isPolymorphic);
}
11 changes: 1 addition & 10 deletions flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1427,16 +1427,7 @@ llvm::LogicalResult hlfir::EndAssociateOp::verify() {
void hlfir::AsExprOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result, mlir::Value var,
mlir::Value mustFree) {
hlfir::ExprType::Shape typeShape;
bool isPolymorphic = fir::isPolymorphicType(var.getType());
mlir::Type type = getFortranElementOrSequenceType(var.getType());
if (auto seqType = mlir::dyn_cast<fir::SequenceType>(type)) {
typeShape.append(seqType.getShape().begin(), seqType.getShape().end());
type = seqType.getEleTy();
}

auto resultType = hlfir::ExprType::get(builder.getContext(), typeShape, type,
isPolymorphic);
mlir::Type resultType = hlfir::getExprType(var.getType());
return build(builder, result, resultType, var, mustFree);
}

Expand Down
Loading
Loading