Skip to content

Commit a545cf5

Browse files
authored
[flang][hlfir] add hlfir.eval_in_mem operation (#118067)
See the HLFIROps.td change for the description of the operation. The goal is to ease temporary storage elision for expression evaluation (typically evaluating the RHS directly inside the LHS) for expressions that do not have abstractions in HLFIR and for which it is not clear that adding one would bring much benefit. The case that is implemented in [the following lowering patch](#118070) is the array call case, where adding a new hlfir.call would add complexity (it would need to deal with dispatch, inlining, ...).
1 parent 770adc5 commit a545cf5

File tree

8 files changed

+454
-20
lines changed

8 files changed

+454
-20
lines changed

flang/include/flang/Optimizer/Builder/HLFIRTools.h

+19
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class AssociateOp;
3333
class ElementalOp;
3434
class ElementalOpInterface;
3535
class ElementalAddrOp;
36+
class EvaluateInMemoryOp;
3637
class YieldElementOp;
3738

3839
/// Is this a Fortran variable for which the defining op carrying the Fortran
@@ -398,6 +399,24 @@ mlir::Value inlineElementalOp(
398399
mlir::IRMapping &mapper,
399400
const std::function<bool(hlfir::ElementalOp)> &mustRecursivelyInline);
400401

402+
/// Create a new temporary with the shape and parameters of the provided
403+
/// hlfir.eval_in_mem operation and clone the body of the hlfir.eval_in_mem
404+
/// operating on this new temporary. Returns the temporary and whether the
405+
/// temporary is heap or stack allocated.
406+
std::pair<hlfir::Entity, bool>
407+
computeEvaluateOpInNewTemp(mlir::Location, fir::FirOpBuilder &,
408+
hlfir::EvaluateInMemoryOp evalInMem,
409+
mlir::Value shape, mlir::ValueRange typeParams);
410+
411+
// Clone the body of the hlfir.eval_in_mem operating on the provided
412+
// storage. The provided storage must be a contiguous "raw" memory reference
413+
// (not a fir.box) big enough to hold the value computed by hlfir.eval_in_mem.
414+
// No runtime check is inserted by this utility to enforce that. It is also
415+
// usually invalid to provide some storage that is already addressed directly
416+
// or indirectly inside the hlfir.eval_in_mem body.
417+
void computeEvaluateOpIn(mlir::Location, fir::FirOpBuilder &,
418+
hlfir::EvaluateInMemoryOp, mlir::Value storage);
419+
401420
std::pair<fir::ExtendedValue, std::optional<hlfir::CleanupFunction>>
402421
convertToValue(mlir::Location loc, fir::FirOpBuilder &builder,
403422
hlfir::Entity entity);

flang/include/flang/Optimizer/HLFIR/HLFIROps.td

+59
Original file line numberDiff line numberDiff line change
@@ -1755,4 +1755,63 @@ def hlfir_CharExtremumOp : hlfir_Op<"char_extremum",
17551755
let hasVerifier = 1;
17561756
}
17571757

1758+
def hlfir_EvaluateInMemoryOp : hlfir_Op<"eval_in_mem", [AttrSizedOperandSegments,
1759+
RecursiveMemoryEffects, RecursivelySpeculatable,
1760+
SingleBlockImplicitTerminator<"fir::FirEndOp">]> {
1761+
let summary = "Wrap an in-memory implementation that computes expression value";
1762+
let description = [{
1763+
Returns a Fortran expression value for which the computation is
1764+
implemented inside the region operating on the block argument which
1765+
is a raw memory reference corresponding to the expression type.
1766+
1767+
The shape and type parameters of the expressions are operands of the
1768+
operations.
1769+
1770+
The memory cannot escape the region, and it is not described how it is
1771+
allocated. This facilitates later elision of the temporary storage for the
1772+
expression evaluation if it can be evaluated in some other storage (like a
1773+
left-hand side variable).
1774+
1775+
Example:
1776+
1777+
A function returning an array can be represented as:
1778+
```
1779+
%1 = fir.shape %c10 : (index) -> !fir.shape<1>
1780+
%2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
1781+
^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
1782+
%3 = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
1783+
fir.save_result %3 to %arg0(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
1784+
}
1785+
```
1786+
}];
1787+
1788+
let arguments = (ins
1789+
Optional<fir_ShapeType>:$shape,
1790+
Variadic<AnyIntegerType>:$typeparams
1791+
);
1792+
1793+
let results = (outs hlfir_ExprType);
1794+
let regions = (region SizedRegion<1>:$body);
1795+
1796+
let assemblyFormat = [{
1797+
(`shape` $shape^)? (`typeparams` $typeparams^)?
1798+
attr-dict `:` functional-type(operands, results)
1799+
$body}];
1800+
1801+
let skipDefaultBuilders = 1;
1802+
let builders = [
1803+
OpBuilder<(ins "mlir::Type":$result_type, "mlir::Value":$shape,
1804+
CArg<"mlir::ValueRange", "{}">:$typeparams)>
1805+
];
1806+
1807+
let extraClassDeclaration = [{
1808+
// Return block argument representing the memory where the expression
1809+
// is evaluated.
1810+
mlir::Value getMemory() {return getBody().getArgument(0);}
1811+
}];
1812+
1813+
let hasVerifier = 1;
1814+
}
1815+
1816+
17581817
#endif // FORTRAN_DIALECT_HLFIR_OPS

flang/lib/Optimizer/Builder/HLFIRTools.cpp

+47
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,8 @@ static mlir::Value tryRetrievingShapeOrShift(hlfir::Entity entity) {
535535
if (mlir::isa<hlfir::ExprType>(entity.getType())) {
536536
if (auto elemental = entity.getDefiningOp<hlfir::ElementalOp>())
537537
return elemental.getShape();
538+
if (auto evalInMem = entity.getDefiningOp<hlfir::EvaluateInMemoryOp>())
539+
return evalInMem.getShape();
538540
return mlir::Value{};
539541
}
540542
if (auto varIface = entity.getIfVariableInterface())
@@ -642,6 +644,11 @@ void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
642644
result.append(elemental.getTypeparams().begin(),
643645
elemental.getTypeparams().end());
644646
return;
647+
} else if (auto evalInMem =
648+
expr.getDefiningOp<hlfir::EvaluateInMemoryOp>()) {
649+
result.append(evalInMem.getTypeparams().begin(),
650+
evalInMem.getTypeparams().end());
651+
return;
645652
} else if (auto apply = expr.getDefiningOp<hlfir::ApplyOp>()) {
646653
result.append(apply.getTypeparams().begin(), apply.getTypeparams().end());
647654
return;
@@ -1313,3 +1320,43 @@ hlfir::genTypeAndKindConvert(mlir::Location loc, fir::FirOpBuilder &builder,
13131320
};
13141321
return {hlfir::Entity{convertedRhs}, cleanup};
13151322
}
1323+
1324+
std::pair<hlfir::Entity, bool> hlfir::computeEvaluateOpInNewTemp(
1325+
mlir::Location loc, fir::FirOpBuilder &builder,
1326+
hlfir::EvaluateInMemoryOp evalInMem, mlir::Value shape,
1327+
mlir::ValueRange typeParams) {
1328+
llvm::StringRef tmpName{".tmp.expr_result"};
1329+
llvm::SmallVector<mlir::Value> extents =
1330+
hlfir::getIndexExtents(loc, builder, shape);
1331+
mlir::Type baseType =
1332+
hlfir::getFortranElementOrSequenceType(evalInMem.getType());
1333+
bool heapAllocated = fir::hasDynamicSize(baseType);
1334+
// Note: temporaries are stack allocated here when possible (do not require
1335+
// stack save/restore) because flang has always stack allocated function
1336+
// results.
1337+
mlir::Value temp = heapAllocated
1338+
? builder.createHeapTemporary(loc, baseType, tmpName,
1339+
extents, typeParams)
1340+
: builder.createTemporary(loc, baseType, tmpName,
1341+
extents, typeParams);
1342+
mlir::Value innerMemory = evalInMem.getMemory();
1343+
temp = builder.createConvert(loc, innerMemory.getType(), temp);
1344+
auto declareOp = builder.create<hlfir::DeclareOp>(
1345+
loc, temp, tmpName, shape, typeParams,
1346+
/*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{});
1347+
computeEvaluateOpIn(loc, builder, evalInMem, declareOp.getOriginalBase());
1348+
return {hlfir::Entity{declareOp.getBase()}, /*heapAllocated=*/heapAllocated};
1349+
}
1350+
1351+
void hlfir::computeEvaluateOpIn(mlir::Location loc, fir::FirOpBuilder &builder,
1352+
hlfir::EvaluateInMemoryOp evalInMem,
1353+
mlir::Value storage) {
1354+
mlir::Value innerMemory = evalInMem.getMemory();
1355+
mlir::Value storageCast =
1356+
builder.createConvert(loc, innerMemory.getType(), storage);
1357+
mlir::IRMapping mapper;
1358+
mapper.map(innerMemory, storageCast);
1359+
for (auto &op : evalInMem.getBody().front().without_terminator())
1360+
builder.clone(op, mapper);
1361+
return;
1362+
}

flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

+62-14
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,25 @@ static void printDesignatorComplexPart(mlir::OpAsmPrinter &p,
333333
p << "real";
334334
}
335335
}
336+
template <typename Op>
337+
static llvm::LogicalResult verifyTypeparams(Op &op, mlir::Type elementType,
338+
unsigned numLenParam) {
339+
if (mlir::isa<fir::CharacterType>(elementType)) {
340+
if (numLenParam != 1)
341+
return op.emitOpError("must be provided one length parameter when the "
342+
"result is a character");
343+
} else if (fir::isRecordWithTypeParameters(elementType)) {
344+
if (numLenParam !=
345+
mlir::cast<fir::RecordType>(elementType).getNumLenParams())
346+
return op.emitOpError("must be provided the same number of length "
347+
"parameters as in the result derived type");
348+
} else if (numLenParam != 0) {
349+
return op.emitOpError(
350+
"must not be provided length parameters if the result "
351+
"type does not have length parameters");
352+
}
353+
return mlir::success();
354+
}
336355

337356
llvm::LogicalResult hlfir::DesignateOp::verify() {
338357
mlir::Type memrefType = getMemref().getType();
@@ -462,20 +481,10 @@ llvm::LogicalResult hlfir::DesignateOp::verify() {
462481
return emitOpError("shape must be a fir.shape or fir.shapeshift with "
463482
"the rank of the result");
464483
}
465-
auto numLenParam = getTypeparams().size();
466-
if (mlir::isa<fir::CharacterType>(outputElementType)) {
467-
if (numLenParam != 1)
468-
return emitOpError("must be provided one length parameter when the "
469-
"result is a character");
470-
} else if (fir::isRecordWithTypeParameters(outputElementType)) {
471-
if (numLenParam !=
472-
mlir::cast<fir::RecordType>(outputElementType).getNumLenParams())
473-
return emitOpError("must be provided the same number of length "
474-
"parameters as in the result derived type");
475-
} else if (numLenParam != 0) {
476-
return emitOpError("must not be provided length parameters if the result "
477-
"type does not have length parameters");
478-
}
484+
if (auto res =
485+
verifyTypeparams(*this, outputElementType, getTypeparams().size());
486+
failed(res))
487+
return res;
479488
}
480489
return mlir::success();
481490
}
@@ -1989,6 +1998,45 @@ hlfir::GetLengthOp::canonicalize(GetLengthOp getLength,
19891998
return mlir::success();
19901999
}
19912000

2001+
//===----------------------------------------------------------------------===//
2002+
// EvaluateInMemoryOp
2003+
//===----------------------------------------------------------------------===//
2004+
2005+
void hlfir::EvaluateInMemoryOp::build(mlir::OpBuilder &builder,
2006+
mlir::OperationState &odsState,
2007+
mlir::Type resultType, mlir::Value shape,
2008+
mlir::ValueRange typeparams) {
2009+
odsState.addTypes(resultType);
2010+
if (shape)
2011+
odsState.addOperands(shape);
2012+
odsState.addOperands(typeparams);
2013+
odsState.addAttribute(
2014+
getOperandSegmentSizeAttr(),
2015+
builder.getDenseI32ArrayAttr(
2016+
{shape ? 1 : 0, static_cast<int32_t>(typeparams.size())}));
2017+
mlir::Region *bodyRegion = odsState.addRegion();
2018+
bodyRegion->push_back(new mlir::Block{});
2019+
mlir::Type memType = fir::ReferenceType::get(
2020+
hlfir::getFortranElementOrSequenceType(resultType));
2021+
bodyRegion->front().addArgument(memType, odsState.location);
2022+
EvaluateInMemoryOp::ensureTerminator(*bodyRegion, builder, odsState.location);
2023+
}
2024+
2025+
llvm::LogicalResult hlfir::EvaluateInMemoryOp::verify() {
2026+
unsigned shapeRank = 0;
2027+
if (mlir::Value shape = getShape())
2028+
if (auto shapeTy = mlir::dyn_cast<fir::ShapeType>(shape.getType()))
2029+
shapeRank = shapeTy.getRank();
2030+
auto exprType = mlir::cast<hlfir::ExprType>(getResult().getType());
2031+
if (shapeRank != exprType.getRank())
2032+
return emitOpError("`shape` rank must match the result rank");
2033+
mlir::Type elementType = exprType.getElementType();
2034+
if (auto res = verifyTypeparams(*this, elementType, getTypeparams().size());
2035+
failed(res))
2036+
return res;
2037+
return mlir::success();
2038+
}
2039+
19922040
#include "flang/Optimizer/HLFIR/HLFIROpInterfaces.cpp.inc"
19932041
#define GET_OP_CLASSES
19942042
#include "flang/Optimizer/HLFIR/HLFIREnums.cpp.inc"

flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp

+27-6
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,26 @@ struct CharExtremumOpConversion
905905
}
906906
};
907907

908+
struct EvaluateInMemoryOpConversion
909+
: public mlir::OpConversionPattern<hlfir::EvaluateInMemoryOp> {
910+
using mlir::OpConversionPattern<
911+
hlfir::EvaluateInMemoryOp>::OpConversionPattern;
912+
explicit EvaluateInMemoryOpConversion(mlir::MLIRContext *ctx)
913+
: mlir::OpConversionPattern<hlfir::EvaluateInMemoryOp>{ctx} {}
914+
llvm::LogicalResult
915+
matchAndRewrite(hlfir::EvaluateInMemoryOp evalInMemOp, OpAdaptor adaptor,
916+
mlir::ConversionPatternRewriter &rewriter) const override {
917+
mlir::Location loc = evalInMemOp->getLoc();
918+
fir::FirOpBuilder builder(rewriter, evalInMemOp.getOperation());
919+
auto [temp, isHeapAlloc] = hlfir::computeEvaluateOpInNewTemp(
920+
loc, builder, evalInMemOp, adaptor.getShape(), adaptor.getTypeparams());
921+
mlir::Value bufferizedExpr =
922+
packageBufferizedExpr(loc, builder, temp, isHeapAlloc);
923+
rewriter.replaceOp(evalInMemOp, bufferizedExpr);
924+
return mlir::success();
925+
}
926+
};
927+
908928
class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
909929
public:
910930
void runOnOperation() override {
@@ -918,12 +938,13 @@ class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
918938
auto module = this->getOperation();
919939
auto *context = &getContext();
920940
mlir::RewritePatternSet patterns(context);
921-
patterns.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
922-
AssociateOpConversion, CharExtremumOpConversion,
923-
ConcatOpConversion, DestroyOpConversion,
924-
ElementalOpConversion, EndAssociateOpConversion,
925-
NoReassocOpConversion, SetLengthOpConversion,
926-
ShapeOfOpConversion, GetLengthOpConversion>(context);
941+
patterns
942+
.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
943+
AssociateOpConversion, CharExtremumOpConversion,
944+
ConcatOpConversion, DestroyOpConversion, ElementalOpConversion,
945+
EndAssociateOpConversion, EvaluateInMemoryOpConversion,
946+
NoReassocOpConversion, SetLengthOpConversion,
947+
ShapeOfOpConversion, GetLengthOpConversion>(context);
927948
mlir::ConversionTarget target(*context);
928949
// Note that YieldElementOp is not marked as an illegal operation.
929950
// It must be erased by its parent converter and there is no explicit

0 commit comments

Comments
 (0)