Skip to content

[flang][hlfir] add hlfir.eval_in_mem operation #118067

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions flang/include/flang/Optimizer/Builder/HLFIRTools.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class AssociateOp;
class ElementalOp;
class ElementalOpInterface;
class ElementalAddrOp;
class EvaluateInMemoryOp;
class YieldElementOp;

/// Is this a Fortran variable for which the defining op carrying the Fortran
Expand Down Expand Up @@ -398,6 +399,24 @@ mlir::Value inlineElementalOp(
mlir::IRMapping &mapper,
const std::function<bool(hlfir::ElementalOp)> &mustRecursivelyInline);

/// Create a new temporary with the shape and parameters of the provided
/// hlfir.eval_in_mem operation and clone the body of the hlfir.eval_in_mem
/// operating on this new temporary. Returns the temporary and whether the
/// temporary is heap or stack allocated.
std::pair<hlfir::Entity, bool>
computeEvaluateOpInNewTemp(mlir::Location, fir::FirOpBuilder &,
hlfir::EvaluateInMemoryOp evalInMem,
mlir::Value shape, mlir::ValueRange typeParams);

/// Clone the body of the hlfir.eval_in_mem operating on the provided
/// storage. The provided storage must be a contiguous "raw" memory reference
/// (not a fir.box) big enough to hold the value computed by hlfir.eval_in_mem.
/// No runtime check is inserted by this utility to enforce that. It is also
/// usually invalid to provide some storage that is already addressed directly
/// or indirectly inside the hlfir.eval_in_mem body.
void computeEvaluateOpIn(mlir::Location, fir::FirOpBuilder &,
hlfir::EvaluateInMemoryOp, mlir::Value storage);

std::pair<fir::ExtendedValue, std::optional<hlfir::CleanupFunction>>
convertToValue(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity entity);
Expand Down
59 changes: 59 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1755,4 +1755,63 @@ def hlfir_CharExtremumOp : hlfir_Op<"char_extremum",
let hasVerifier = 1;
}

// hlfir.eval_in_mem: computes a Fortran expression value in anonymous memory.
// The single-block body operates on a raw memory block argument; memory
// effects and speculatability are derived recursively from the body
// operations (RecursiveMemoryEffects / RecursivelySpeculatable), and the
// block is implicitly terminated by fir::FirEndOp.
def hlfir_EvaluateInMemoryOp : hlfir_Op<"eval_in_mem", [AttrSizedOperandSegments,
    RecursiveMemoryEffects, RecursivelySpeculatable,
    SingleBlockImplicitTerminator<"fir::FirEndOp">]> {
  let summary = "Wrap an in-memory implementation that computes expression value";
  let description = [{
    Returns a Fortran expression value for which the computation is
    implemented inside the region operating on the block argument which
    is a raw memory reference corresponding to the expression type.

    The shape and type parameters of the expressions are operands of the
    operations.

    The memory cannot escape the region, and it is not described how it is
    allocated. This facilitates later elision of the temporary storage for the
    expression evaluation if it can be evaluated in some other storage (like a
    left-hand side variable).

    Example:

    A function returning an array can be represented as:
    ```
    %1 = fir.shape %c10 : (index) -> !fir.shape<1>
    %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
    ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
      %3 = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
      fir.save_result %3 to %arg0(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
    }
    ```
  }];

  // Optional shape (arrays) and variadic length parameters (characters and
  // parameterized derived types) of the computed expression.
  let arguments = (ins
    Optional<fir_ShapeType>:$shape,
    Variadic<AnyIntegerType>:$typeparams
  );

  let results = (outs hlfir_ExprType);
  let regions = (region SizedRegion<1>:$body);

  let assemblyFormat = [{
    (`shape` $shape^)? (`typeparams` $typeparams^)?
    attr-dict `:` functional-type(operands, results)
    $body}];

  // Default builders are skipped because the custom builder must also create
  // the body block and its raw-memory block argument.
  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins "mlir::Type":$result_type, "mlir::Value":$shape,
      CArg<"mlir::ValueRange", "{}">:$typeparams)>
  ];

  let extraClassDeclaration = [{
    // Return block argument representing the memory where the expression
    // is evaluated.
    mlir::Value getMemory() {return getBody().getArgument(0);}
  }];

  let hasVerifier = 1;
}


#endif // FORTRAN_DIALECT_HLFIR_OPS
47 changes: 47 additions & 0 deletions flang/lib/Optimizer/Builder/HLFIRTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,8 @@ static mlir::Value tryRetrievingShapeOrShift(hlfir::Entity entity) {
if (mlir::isa<hlfir::ExprType>(entity.getType())) {
if (auto elemental = entity.getDefiningOp<hlfir::ElementalOp>())
return elemental.getShape();
if (auto evalInMem = entity.getDefiningOp<hlfir::EvaluateInMemoryOp>())
return evalInMem.getShape();
return mlir::Value{};
}
if (auto varIface = entity.getIfVariableInterface())
Expand Down Expand Up @@ -642,6 +644,11 @@ void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
result.append(elemental.getTypeparams().begin(),
elemental.getTypeparams().end());
return;
} else if (auto evalInMem =
expr.getDefiningOp<hlfir::EvaluateInMemoryOp>()) {
result.append(evalInMem.getTypeparams().begin(),
evalInMem.getTypeparams().end());
return;
} else if (auto apply = expr.getDefiningOp<hlfir::ApplyOp>()) {
result.append(apply.getTypeparams().begin(), apply.getTypeparams().end());
return;
Expand Down Expand Up @@ -1313,3 +1320,43 @@ hlfir::genTypeAndKindConvert(mlir::Location loc, fir::FirOpBuilder &builder,
};
return {hlfir::Entity{convertedRhs}, cleanup};
}

/// Allocate a temporary sized by \p shape / \p typeParams, evaluate the
/// hlfir.eval_in_mem body into it, and return the declared temporary along
/// with a flag telling whether it was heap allocated.
std::pair<hlfir::Entity, bool> hlfir::computeEvaluateOpInNewTemp(
    mlir::Location loc, fir::FirOpBuilder &builder,
    hlfir::EvaluateInMemoryOp evalInMem, mlir::Value shape,
    mlir::ValueRange typeParams) {
  llvm::StringRef resultName{".tmp.expr_result"};
  mlir::Type allocType =
      hlfir::getFortranElementOrSequenceType(evalInMem.getType());
  llvm::SmallVector<mlir::Value> extents =
      hlfir::getIndexExtents(loc, builder, shape);
  // Dynamically sized results go on the heap. Everything else is stack
  // allocated (no stack save/restore needed) because flang has always stack
  // allocated function results.
  const bool onHeap = fir::hasDynamicSize(allocType);
  mlir::Value rawTemp;
  if (onHeap)
    rawTemp = builder.createHeapTemporary(loc, allocType, resultName, extents,
                                          typeParams);
  else
    rawTemp = builder.createTemporary(loc, allocType, resultName, extents,
                                      typeParams);
  mlir::Value innerMemory = evalInMem.getMemory();
  rawTemp = builder.createConvert(loc, innerMemory.getType(), rawTemp);
  auto declareOp = builder.create<hlfir::DeclareOp>(
      loc, rawTemp, resultName, shape, typeParams,
      /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{});
  // Evaluate the wrapped computation directly into the new storage.
  computeEvaluateOpIn(loc, builder, evalInMem, declareOp.getOriginalBase());
  return {hlfir::Entity{declareOp.getBase()}, /*heapAllocated=*/onHeap};
}

/// Clone the body of \p evalInMem so that it computes its value directly into
/// \p storage. The storage must be a contiguous "raw" memory reference (not a
/// fir.box) big enough to hold the computed value; no runtime check is
/// inserted to enforce that, and the storage must not already be addressed
/// (directly or indirectly) inside the hlfir.eval_in_mem body.
void hlfir::computeEvaluateOpIn(mlir::Location loc, fir::FirOpBuilder &builder,
                                hlfir::EvaluateInMemoryOp evalInMem,
                                mlir::Value storage) {
  mlir::Value innerMemory = evalInMem.getMemory();
  // Cast the storage to the type of the region's block argument so the cloned
  // operations can use it as a drop-in replacement.
  mlir::Value storageCast =
      builder.createConvert(loc, innerMemory.getType(), storage);
  mlir::IRMapping mapper;
  mapper.map(innerMemory, storageCast);
  // Clone everything but the fir.end terminator, remapping the block argument
  // to the provided storage. (Removed a redundant trailing `return;`.)
  for (auto &op : evalInMem.getBody().front().without_terminator())
    builder.clone(op, mapper);
}
76 changes: 62 additions & 14 deletions flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,25 @@ static void printDesignatorComplexPart(mlir::OpAsmPrinter &p,
p << "real";
}
}
/// Verify that the number of length parameters supplied to \p op matches what
/// its element type requires: exactly one for characters, the declared count
/// for parameterized derived types, and none otherwise.
template <typename Op>
static llvm::LogicalResult verifyTypeparams(Op &op, mlir::Type elementType,
                                            unsigned numLenParam) {
  if (mlir::isa<fir::CharacterType>(elementType)) {
    if (numLenParam != 1)
      return op.emitOpError("must be provided one length parameter when the "
                            "result is a character");
    return mlir::success();
  }
  if (fir::isRecordWithTypeParameters(elementType)) {
    auto recordType = mlir::cast<fir::RecordType>(elementType);
    if (numLenParam != recordType.getNumLenParams())
      return op.emitOpError("must be provided the same number of length "
                            "parameters as in the result derived type");
    return mlir::success();
  }
  if (numLenParam != 0)
    return op.emitOpError("must not be provided length parameters if the result "
                          "type does not have length parameters");
  return mlir::success();
}

llvm::LogicalResult hlfir::DesignateOp::verify() {
mlir::Type memrefType = getMemref().getType();
Expand Down Expand Up @@ -462,20 +481,10 @@ llvm::LogicalResult hlfir::DesignateOp::verify() {
return emitOpError("shape must be a fir.shape or fir.shapeshift with "
"the rank of the result");
}
auto numLenParam = getTypeparams().size();
if (mlir::isa<fir::CharacterType>(outputElementType)) {
if (numLenParam != 1)
return emitOpError("must be provided one length parameter when the "
"result is a character");
} else if (fir::isRecordWithTypeParameters(outputElementType)) {
if (numLenParam !=
mlir::cast<fir::RecordType>(outputElementType).getNumLenParams())
return emitOpError("must be provided the same number of length "
"parameters as in the result derived type");
} else if (numLenParam != 0) {
return emitOpError("must not be provided length parameters if the result "
"type does not have length parameters");
}
if (auto res =
verifyTypeparams(*this, outputElementType, getTypeparams().size());
failed(res))
return res;
}
return mlir::success();
}
Expand Down Expand Up @@ -1989,6 +1998,45 @@ hlfir::GetLengthOp::canonicalize(GetLengthOp getLength,
return mlir::success();
}

//===----------------------------------------------------------------------===//
// EvaluateInMemoryOp
//===----------------------------------------------------------------------===//

// Custom builder for hlfir.eval_in_mem. The op sets skipDefaultBuilders, so
// this is the only builder: it must create the single-block body region with
// its raw-memory block argument and the implicit terminator.
void hlfir::EvaluateInMemoryOp::build(mlir::OpBuilder &builder,
                                      mlir::OperationState &odsState,
                                      mlir::Type resultType, mlir::Value shape,
                                      mlir::ValueRange typeparams) {
  odsState.addTypes(resultType);
  if (shape)
    odsState.addOperands(shape);
  odsState.addOperands(typeparams);
  // AttrSizedOperandSegments: record how the operands split between the
  // optional shape (0 or 1) and the variadic type parameters.
  odsState.addAttribute(
      getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr(
          {shape ? 1 : 0, static_cast<int32_t>(typeparams.size())}));
  mlir::Region *bodyRegion = odsState.addRegion();
  bodyRegion->push_back(new mlir::Block{});
  // The block argument is a raw reference to memory holding a value of the
  // expression type (e.g. !fir.ref<!fir.array<10xf32>> for
  // !hlfir.expr<10xf32>).
  mlir::Type memType = fir::ReferenceType::get(
      hlfir::getFortranElementOrSequenceType(resultType));
  bodyRegion->front().addArgument(memType, odsState.location);
  // Insert the implicit fir::FirEndOp terminator.
  EvaluateInMemoryOp::ensureTerminator(*bodyRegion, builder, odsState.location);
}

/// Check that the shape operand rank and the length parameters are consistent
/// with the result !hlfir.expr type.
llvm::LogicalResult hlfir::EvaluateInMemoryOp::verify() {
  auto exprType = mlir::cast<hlfir::ExprType>(getResult().getType());
  // Without a shape operand (or with one that is not a fir.shape), the
  // provided rank is considered zero.
  unsigned providedRank = 0;
  if (mlir::Value shape = getShape())
    if (auto shapeTy = mlir::dyn_cast<fir::ShapeType>(shape.getType()))
      providedRank = shapeTy.getRank();
  if (providedRank != exprType.getRank())
    return emitOpError("`shape` rank must match the result rank");
  auto res =
      verifyTypeparams(*this, exprType.getElementType(), getTypeparams().size());
  if (failed(res))
    return res;
  return mlir::success();
}

#include "flang/Optimizer/HLFIR/HLFIROpInterfaces.cpp.inc"
#define GET_OP_CLASSES
#include "flang/Optimizer/HLFIR/HLFIREnums.cpp.inc"
Expand Down
33 changes: 27 additions & 6 deletions flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,26 @@ struct CharExtremumOpConversion
}
};

/// Bufferize hlfir.eval_in_mem: materialize the expression into a new
/// temporary and replace the op with the pass's packaged bufferized
/// expression value.
struct EvaluateInMemoryOpConversion
    : public mlir::OpConversionPattern<hlfir::EvaluateInMemoryOp> {
  // The inherited constructors are sufficient; the previous explicit
  // constructor taking MLIRContext* duplicated what this using-declaration
  // already provides and has been removed.
  using mlir::OpConversionPattern<
      hlfir::EvaluateInMemoryOp>::OpConversionPattern;
  llvm::LogicalResult
  matchAndRewrite(hlfir::EvaluateInMemoryOp evalInMemOp, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Location loc = evalInMemOp->getLoc();
    fir::FirOpBuilder builder(rewriter, evalInMemOp.getOperation());
    // Evaluate into a fresh temporary; isHeapAlloc tells whether the
    // temporary must later be freed.
    auto [temp, isHeapAlloc] = hlfir::computeEvaluateOpInNewTemp(
        loc, builder, evalInMemOp, adaptor.getShape(), adaptor.getTypeparams());
    mlir::Value bufferizedExpr =
        packageBufferizedExpr(loc, builder, temp, isHeapAlloc);
    rewriter.replaceOp(evalInMemOp, bufferizedExpr);
    return mlir::success();
  }
};

class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
public:
void runOnOperation() override {
Expand All @@ -918,12 +938,13 @@ class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
auto module = this->getOperation();
auto *context = &getContext();
mlir::RewritePatternSet patterns(context);
patterns.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
AssociateOpConversion, CharExtremumOpConversion,
ConcatOpConversion, DestroyOpConversion,
ElementalOpConversion, EndAssociateOpConversion,
NoReassocOpConversion, SetLengthOpConversion,
ShapeOfOpConversion, GetLengthOpConversion>(context);
patterns
.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
AssociateOpConversion, CharExtremumOpConversion,
ConcatOpConversion, DestroyOpConversion, ElementalOpConversion,
EndAssociateOpConversion, EvaluateInMemoryOpConversion,
NoReassocOpConversion, SetLengthOpConversion,
ShapeOfOpConversion, GetLengthOpConversion>(context);
mlir::ConversionTarget target(*context);
// Note that YieldElementOp is not marked as an illegal operation.
// It must be erased by its parent converter and there is no explicit
Expand Down
Loading