Skip to content

Commit 5eef9ba

Browse files
authored
[flang] Inline hlfir.cshift as hlfir.elemental. (#119480)
1 parent ba373a2 commit 5eef9ba

File tree

2 files changed

+426
-0
lines changed

2 files changed

+426
-0
lines changed

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "flang/Optimizer/Builder/Complex.h"
1414
#include "flang/Optimizer/Builder/FIRBuilder.h"
1515
#include "flang/Optimizer/Builder/HLFIRTools.h"
16+
#include "flang/Optimizer/Builder/IntrinsicCall.h"
1617
#include "flang/Optimizer/Dialect/FIRDialect.h"
1718
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
1819
#include "flang/Optimizer/HLFIR/HLFIROps.h"
@@ -331,6 +332,108 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
331332
}
332333
};
333334

335+
class CShiftAsElementalConversion
336+
: public mlir::OpRewritePattern<hlfir::CShiftOp> {
337+
public:
338+
using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;
339+
340+
explicit CShiftAsElementalConversion(mlir::MLIRContext *ctx)
341+
: OpRewritePattern(ctx) {
342+
setHasBoundedRewriteRecursion();
343+
}
344+
345+
llvm::LogicalResult
346+
matchAndRewrite(hlfir::CShiftOp cshift,
347+
mlir::PatternRewriter &rewriter) const override {
348+
using Fortran::common::maxRank;
349+
350+
mlir::Location loc = cshift.getLoc();
351+
fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
352+
hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType());
353+
assert(expr &&
354+
"expected an expression type for the result of hlfir.cshift");
355+
mlir::Type elementType = expr.getElementType();
356+
hlfir::Entity array = hlfir::Entity{cshift.getArray()};
357+
mlir::Value arrayShape = hlfir::genShape(loc, builder, array);
358+
llvm::SmallVector<mlir::Value> arrayExtents =
359+
hlfir::getExplicitExtentsFromShape(arrayShape, builder);
360+
unsigned arrayRank = expr.getRank();
361+
llvm::SmallVector<mlir::Value, 1> typeParams;
362+
hlfir::genLengthParameters(loc, builder, array, typeParams);
363+
hlfir::Entity shift = hlfir::Entity{cshift.getShift()};
364+
// The new index computation involves MODULO, which is not implemented
365+
// for IndexType, so use I64 instead.
366+
mlir::Type calcType = builder.getI64Type();
367+
368+
mlir::Value one = builder.createIntegerConstant(loc, calcType, 1);
369+
mlir::Value shiftVal;
370+
if (shift.isScalar()) {
371+
shiftVal = hlfir::loadTrivialScalar(loc, builder, shift);
372+
shiftVal = builder.createConvert(loc, calcType, shiftVal);
373+
}
374+
375+
int64_t dimVal = 1;
376+
if (arrayRank == 1) {
377+
// When it is a 1D CSHIFT, we may assume that the DIM argument
378+
// (whether it is present or absent) is equal to 1, otherwise,
379+
// the program is illegal.
380+
assert(shiftVal && "SHIFT must be scalar");
381+
} else {
382+
if (mlir::Value dim = cshift.getDim())
383+
dimVal = fir::getIntIfConstant(dim).value_or(0);
384+
assert(dimVal > 0 && dimVal <= arrayRank &&
385+
"DIM must be present and a positive constant not exceeding "
386+
"the array's rank");
387+
}
388+
389+
auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
390+
mlir::ValueRange inputIndices) -> hlfir::Entity {
391+
llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
392+
if (!shift.isScalar()) {
393+
// When the array is not a vector, section
394+
// (s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)
395+
// of the result has a value equal to:
396+
// CSHIFT(ARRAY(s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)),
397+
// SH, 1),
398+
// where SH is either SHIFT (if scalar) or
399+
// SHIFT(s(1), s(2), ..., s(dim-1), s(dim+1), ..., s(n)).
400+
llvm::SmallVector<mlir::Value, maxRank> shiftIndices{indices};
401+
shiftIndices.erase(shiftIndices.begin() + dimVal - 1);
402+
hlfir::Entity shiftElement =
403+
hlfir::getElementAt(loc, builder, shift, shiftIndices);
404+
shiftVal = hlfir::loadTrivialScalar(loc, builder, shiftElement);
405+
shiftVal = builder.createConvert(loc, calcType, shiftVal);
406+
}
407+
408+
// Element i of the result (1-based) is element
409+
// 'MODULO(i + SH - 1, SIZE(ARRAY)) + 1' (1-based) of the original
410+
// ARRAY (or its section, when ARRAY is not a vector).
411+
mlir::Value index =
412+
builder.createConvert(loc, calcType, inputIndices[dimVal - 1]);
413+
mlir::Value extent = arrayExtents[dimVal - 1];
414+
mlir::Value newIndex =
415+
builder.create<mlir::arith::AddIOp>(loc, index, shiftVal);
416+
newIndex = builder.create<mlir::arith::SubIOp>(loc, newIndex, one);
417+
newIndex = fir::IntrinsicLibrary{builder, loc}.genModulo(
418+
calcType, {newIndex, builder.createConvert(loc, calcType, extent)});
419+
newIndex = builder.create<mlir::arith::AddIOp>(loc, newIndex, one);
420+
newIndex = builder.createConvert(loc, builder.getIndexType(), newIndex);
421+
422+
indices[dimVal - 1] = newIndex;
423+
hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
424+
return hlfir::loadTrivialScalar(loc, builder, element);
425+
};
426+
427+
hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
428+
loc, builder, elementType, arrayShape, typeParams, genKernel,
429+
/*isUnordered=*/true,
430+
array.isPolymorphic() ? static_cast<mlir::Value>(array) : nullptr,
431+
cshift.getResult().getType());
432+
rewriter.replaceOp(cshift, elementalOp);
433+
return mlir::success();
434+
}
435+
};
436+
334437
class SimplifyHLFIRIntrinsics
335438
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
336439
public:
@@ -339,6 +442,7 @@ class SimplifyHLFIRIntrinsics
339442
mlir::RewritePatternSet patterns(context);
340443
patterns.insert<TransposeAsElementalConversion>(context);
341444
patterns.insert<SumAsElementalConversion>(context);
445+
patterns.insert<CShiftAsElementalConversion>(context);
342446
mlir::ConversionTarget target(*context);
343447
// don't transform transpose of polymorphic arrays (not currently supported
344448
// by hlfir.elemental)
@@ -375,6 +479,24 @@ class SimplifyHLFIRIntrinsics
375479
}
376480
return true;
377481
});
482+
target.addDynamicallyLegalOp<hlfir::CShiftOp>([](hlfir::CShiftOp cshift) {
483+
unsigned resultRank = hlfir::Entity{cshift}.getRank();
484+
if (resultRank == 1)
485+
return false;
486+
487+
mlir::Value dim = cshift.getDim();
488+
if (!dim)
489+
return false;
490+
491+
// If DIM is present, then it must be constant to please
492+
// the conversion. In addition, ignore cases with
493+
// illegal DIM values.
494+
if (auto dimVal = fir::getIntIfConstant(dim))
495+
if (*dimVal > 0 && *dimVal <= resultRank)
496+
return false;
497+
498+
return true;
499+
});
378500
target.markUnknownOpDynamicallyLegal(
379501
[](mlir::Operation *) { return true; });
380502
if (mlir::failed(mlir::applyFullConversion(getOperation(), target,

0 commit comments

Comments
 (0)