[flang] Expand SUM(DIM=CONSTANT) into an hlfir.elemental. #118556
Changes from 1 commit
@@ -10,6 +10,7 @@
// into the calling function.
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
@@ -90,13 +91,198 @@ class TransposeAsElementalConversion
  }
};

// Expand the SUM(DIM=CONSTANT) operation into an hlfir.elemental.
class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
public:
  using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(hlfir::SumOp sum,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = sum.getLoc();
    fir::FirOpBuilder builder{rewriter, sum.getOperation()};
    hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
    assert(expr && "expected an expression type for the result of hlfir.sum");
    mlir::Type elementType = expr.getElementType();
    hlfir::Entity array = hlfir::Entity{sum.getArray()};
    mlir::Value mask = sum.getMask();
    mlir::Value dim = sum.getDim();
    int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
    assert(dimVal > 0 && "DIM must be present and a positive constant");
    mlir::Value resultShape, dimExtent;
    std::tie(resultShape, dimExtent) =
        genResultShape(loc, builder, array, dimVal);

    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                         mlir::ValueRange inputIndices) -> hlfir::Entity {
      // Loop over all indices in the DIM dimension, and reduce all values.
      // We do not always need to create the reduction loop: if we can
      // slice the input array given the inputIndices, then we can
      // just apply a new SUM operation (total reduction) to the slice.
      // For the time being, generate the explicit loop, because the slicing
      // requires generating an elemental operation for the input array
      // (and the mask, if present).
      // TODO: produce the slices and a new SUM after adding a pattern
      // for expanding the total reduction SUM case.

      mlir::Type indexType = builder.getIndexType();
      auto one = builder.createIntegerConstant(loc, indexType, 1);
      auto ub = builder.createConvert(loc, indexType, dimExtent);

      // Initial value for the reduction.
      mlir::Value initValue = genInitValue(loc, builder, elementType);

      // The reduction loop may be unordered if FastMathFlags::reassoc
      // transformations are allowed. The integer reduction is always
      // unordered.
      bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
                         static_cast<bool>(sum.getFastmath() &
                                           mlir::arith::FastMathFlags::reassoc);

      // If the mask is present and is a scalar, load its value outside of
      // the reduction loop to make loop unswitching easier.
      // It may also be worth hoisting it out of the elemental operation.
      if (mask) {
        hlfir::Entity maskValue{mask};
        if (maskValue.isScalar())
          mask = hlfir::loadTrivialScalar(loc, builder, maskValue);
      }

      // NOTE: the outer elemental operation may be lowered into
      // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
      // loop may appear disjoint from the workshare loop nest.
      // Moreover, the inner loop is not strictly nested (due to the
      // initialization of the reduction's starting value), and the above
      // omp dialect operations cannot produce results.
      // It is unclear yet what we should do about this.
Review comment: Maybe we will need some reduction concept in HLFIR, and we can let later passes map it to whatever parallelism concept applies. Also, I do not know how OpenMP/ACC handle reductions of expressions (looking at OpenMP standard 5.2, section 5.5.8, it looks like the list items of an OpenMP reduction must be variables (arrays/array sections)), so I am not sure the MLIR operations will be able to represent reductions on expression operands (A+B), even though there is technically an opportunity to parallelize. Anyway, I think that is not in the scope of this patch.

Review comment: I like the idea of having a reduction concept. That might make it easier to share code across these various optimized intrinsic implementations. I think it could take a lot of work for OpenMP to take advantage of this. I'm happy to discuss more if you are interested.

Review comment: Yes, it is possible to represent the reduction using temporary reduction storage instead of the iter_args, but this requires making sure that the storage has proper data-sharing attributes with regard to the enclosing parallel constructs. I would prefer to keep it as-is right now, and then think about OpenMP cases as they arise.

Review comment: I think this is okay. Most intrinsics are going to be evaluated in a single thread in WORKSHARE for now (which is what some other compilers do too). In this case, I think SUM would be best implemented with a special rewrite pattern for OpenMP using a reduction clause. In general, implementing good multithreaded versions of these intrinsics that are useful on both CPU and offloading devices is quite hard. My opinion is that we should only attempt this when there is a concrete performance case to benchmark. I wouldn't want this relatively rare OpenMP construct (with historically poor compiler support) to make performance work in the rest of the compiler more difficult.

      auto doLoop = builder.create<fir::DoLoopOp>(
          loc, one, ub, one, isUnordered, /*finalCountValue=*/false,
          mlir::ValueRange{initValue});

      // Address the input array using the reduction loop's IV
      // for the DIM dimension.
      mlir::Value iv = doLoop.getInductionVar();
      llvm::SmallVector<mlir::Value> indices{inputIndices};
      indices.insert(indices.begin() + dimVal - 1, iv);

      mlir::OpBuilder::InsertionGuard guard(builder);
      builder.setInsertionPointToStart(doLoop.getBody());
      mlir::Value reductionValue = doLoop.getRegionIterArgs()[0];
      fir::IfOp ifOp;
      if (mask) {
        // Make the reduction value update conditional on the value
        // of the mask.
        hlfir::Entity maskValue{mask};
        if (!maskValue.isScalar()) {
          // If the mask is an array, use the elemental and the loop indices
          // to address the proper mask element.
          maskValue = hlfir::getElementAt(loc, builder, maskValue, indices);
          maskValue = hlfir::loadTrivialScalar(loc, builder, maskValue);
        }
        mlir::Value isUnmasked =
            builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
        ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
                                         /*withElseRegion=*/true);
        // In the 'else' block return the current reduction value.
        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
        builder.create<fir::ResultOp>(loc, reductionValue);

        // In the 'then' block do the actual addition.
        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
      }

      hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
      hlfir::Entity elementValue =
          hlfir::loadTrivialScalar(loc, builder, element);
      // NOTE: we could use Kahan summation the same way as the runtime
      // does (e.g. when fast-math is not allowed), but let's start with
      // the simple version.
      reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue);
      builder.create<fir::ResultOp>(loc, reductionValue);

      if (ifOp) {
        builder.setInsertionPointAfter(ifOp);
        builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
      }

      return hlfir::Entity{doLoop.getResult(0)};
    };
    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
        loc, builder, elementType, resultShape, {}, genKernel,
        /*isUnordered=*/true, /*polymorphicMold=*/nullptr,
        sum.getResult().getType());

    // It wouldn't be safe to replace block arguments with a different
    // hlfir.expr type. Types can differ due to differing amounts of shape
    // information.
    assert(elementalOp.getResult().getType() == sum.getResult().getType());

    rewriter.replaceOp(sum, elementalOp);
    return mlir::success();
  }

private:
  // Return a fir.shape specifying the shape of the result
  // of a SUM reduction with DIM=dimVal. The second return value
  // is the extent of the DIM dimension.
  static std::tuple<mlir::Value, mlir::Value>
  genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
                 hlfir::Entity array, int64_t dimVal) {
    mlir::Value inShape = hlfir::genShape(loc, builder, array);
    llvm::SmallVector<mlir::Value> inExtents =
        hlfir::getExplicitExtentsFromShape(inShape, builder);
    if (inShape.getUses().empty())
      inShape.getDefiningOp()->erase();

    mlir::Value dimExtent = inExtents[dimVal - 1];
    inExtents.erase(inExtents.begin() + dimVal - 1);
    return {builder.create<fir::ShapeOp>(loc, inExtents), dimExtent};
  }

  // Generate the initial value for a SUM reduction with the given
  // data type.
  static mlir::Value genInitValue(mlir::Location loc,
                                  fir::FirOpBuilder &builder,
                                  mlir::Type elementType) {
    if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
      const llvm::fltSemantics &sem = ty.getFloatSemantics();
      return builder.createRealConstant(loc, elementType,
                                        llvm::APFloat::getZero(sem));
    } else if (auto ty = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
      mlir::Value initValue = genInitValue(loc, builder, ty.getElementType());
      return fir::factory::Complex{builder, loc}.createComplex(ty, initValue,
                                                               initValue);
    } else if (mlir::isa<mlir::IntegerType>(elementType)) {
      return builder.createIntegerConstant(loc, elementType, 0);
    }

    llvm_unreachable("unsupported SUM reduction type");
  }

  // Generate scalar addition of the two values (of the same data type).
  static mlir::Value genScalarAdd(mlir::Location loc,
                                  fir::FirOpBuilder &builder,
                                  mlir::Value value1, mlir::Value value2) {
    mlir::Type ty = value1.getType();
    assert(ty == value2.getType() && "reduction values' types do not match");
    if (mlir::isa<mlir::FloatType>(ty))
      return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
    else if (mlir::isa<mlir::ComplexType>(ty))
      return builder.create<fir::AddcOp>(loc, value1, value2);
    else if (mlir::isa<mlir::IntegerType>(ty))
      return builder.create<mlir::arith::AddIOp>(loc, value1, value2);

    llvm_unreachable("unsupported SUM reduction type");
  }
};

class SimplifyHLFIRIntrinsics
    : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
public:
  void runOnOperation() override {
    mlir::MLIRContext *context = &getContext();
    mlir::RewritePatternSet patterns(context);
    patterns.insert<TransposeAsElementalConversion>(context);
    patterns.insert<SumAsElementalConversion>(context);
    mlir::ConversionTarget target(*context);
    // Don't transform transposes of polymorphic arrays (not currently
    // supported by hlfir.elemental).
@@ -105,6 +291,24 @@ class SimplifyHLFIRIntrinsics
      return mlir::cast<hlfir::ExprType>(transpose.getType())
          .isPolymorphic();
    });
    // Handle only the SUM(DIM=CONSTANT) case for now.
    // It may be beneficial to expand the non-DIM case as well.
    // E.g. when the input array is an elemental array expression,
    // expanding the SUM into a total-reduction loop nest
    // would avoid creating a temporary for the elemental array expression.
    target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
      if (mlir::Value dim = sum.getDim()) {
        if (fir::getIntIfConstant(dim)) {
          if (!fir::isa_trivial(sum.getType())) {
            // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
            // It is only legal when X is 1, and it should probably be
            // canonicalized into SUM(a).
            return false;
          }
        }
      }
      return true;
    });
    target.markUnknownOpDynamicallyLegal(
        [](mlir::Operation *) { return true; });
    if (mlir::failed(mlir::applyFullConversion(getOperation(), target,

Review comment (nit): cast<>() already has the assertion built in:
https://llvm.org/docs/ProgrammersManual.html#the-isa-cast-and-dyn-cast-templates

Reply: I am usually in favor of more explanatory assertion messages, so I would prefer to leave it here.