13
13
#include " flang/Optimizer/Builder/Complex.h"
14
14
#include " flang/Optimizer/Builder/FIRBuilder.h"
15
15
#include " flang/Optimizer/Builder/HLFIRTools.h"
16
+ #include " flang/Optimizer/Builder/IntrinsicCall.h"
16
17
#include " flang/Optimizer/Dialect/FIRDialect.h"
17
18
#include " flang/Optimizer/HLFIR/HLFIRDialect.h"
18
19
#include " flang/Optimizer/HLFIR/HLFIROps.h"
@@ -331,6 +332,108 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
331
332
}
332
333
};
333
334
335
+ class CShiftAsElementalConversion
336
+ : public mlir::OpRewritePattern<hlfir::CShiftOp> {
337
+ public:
338
+ using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;
339
+
340
+ explicit CShiftAsElementalConversion (mlir::MLIRContext *ctx)
341
+ : OpRewritePattern(ctx) {
342
+ setHasBoundedRewriteRecursion ();
343
+ }
344
+
345
+ llvm::LogicalResult
346
+ matchAndRewrite (hlfir::CShiftOp cshift,
347
+ mlir::PatternRewriter &rewriter) const override {
348
+ using Fortran::common::maxRank;
349
+
350
+ mlir::Location loc = cshift.getLoc ();
351
+ fir::FirOpBuilder builder{rewriter, cshift.getOperation ()};
352
+ hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType ());
353
+ assert (expr &&
354
+ " expected an expression type for the result of hlfir.cshift" );
355
+ mlir::Type elementType = expr.getElementType ();
356
+ hlfir::Entity array = hlfir::Entity{cshift.getArray ()};
357
+ mlir::Value arrayShape = hlfir::genShape (loc, builder, array);
358
+ llvm::SmallVector<mlir::Value> arrayExtents =
359
+ hlfir::getExplicitExtentsFromShape (arrayShape, builder);
360
+ unsigned arrayRank = expr.getRank ();
361
+ llvm::SmallVector<mlir::Value, 1 > typeParams;
362
+ hlfir::genLengthParameters (loc, builder, array, typeParams);
363
+ hlfir::Entity shift = hlfir::Entity{cshift.getShift ()};
364
+ // The new index computation involves MODULO, which is not implemented
365
+ // for IndexType, so use I64 instead.
366
+ mlir::Type calcType = builder.getI64Type ();
367
+
368
+ mlir::Value one = builder.createIntegerConstant (loc, calcType, 1 );
369
+ mlir::Value shiftVal;
370
+ if (shift.isScalar ()) {
371
+ shiftVal = hlfir::loadTrivialScalar (loc, builder, shift);
372
+ shiftVal = builder.createConvert (loc, calcType, shiftVal);
373
+ }
374
+
375
+ int64_t dimVal = 1 ;
376
+ if (arrayRank == 1 ) {
377
+ // When it is a 1D CSHIFT, we may assume that the DIM argument
378
+ // (whether it is present or absent) is equal to 1, otherwise,
379
+ // the program is illegal.
380
+ assert (shiftVal && " SHIFT must be scalar" );
381
+ } else {
382
+ if (mlir::Value dim = cshift.getDim ())
383
+ dimVal = fir::getIntIfConstant (dim).value_or (0 );
384
+ assert (dimVal > 0 && dimVal <= arrayRank &&
385
+ " DIM must be present and a positive constant not exceeding "
386
+ " the array's rank" );
387
+ }
388
+
389
+ auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
390
+ mlir::ValueRange inputIndices) -> hlfir::Entity {
391
+ llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
392
+ if (!shift.isScalar ()) {
393
+ // When the array is not a vector, section
394
+ // (s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)
395
+ // of the result has a value equal to:
396
+ // CSHIFT(ARRAY(s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)),
397
+ // SH, 1),
398
+ // where SH is either SHIFT (if scalar) or
399
+ // SHIFT(s(1), s(2), ..., s(dim-1), s(dim+1), ..., s(n)).
400
+ llvm::SmallVector<mlir::Value, maxRank> shiftIndices{indices};
401
+ shiftIndices.erase (shiftIndices.begin () + dimVal - 1 );
402
+ hlfir::Entity shiftElement =
403
+ hlfir::getElementAt (loc, builder, shift, shiftIndices);
404
+ shiftVal = hlfir::loadTrivialScalar (loc, builder, shiftElement);
405
+ shiftVal = builder.createConvert (loc, calcType, shiftVal);
406
+ }
407
+
408
+ // Element i of the result (1-based) is element
409
+ // 'MODULO(i + SH - 1, SIZE(ARRAY)) + 1' (1-based) of the original
410
+ // ARRAY (or its section, when ARRAY is not a vector).
411
+ mlir::Value index =
412
+ builder.createConvert (loc, calcType, inputIndices[dimVal - 1 ]);
413
+ mlir::Value extent = arrayExtents[dimVal - 1 ];
414
+ mlir::Value newIndex =
415
+ builder.create <mlir::arith::AddIOp>(loc, index, shiftVal);
416
+ newIndex = builder.create <mlir::arith::SubIOp>(loc, newIndex, one);
417
+ newIndex = fir::IntrinsicLibrary{builder, loc}.genModulo (
418
+ calcType, {newIndex, builder.createConvert (loc, calcType, extent)});
419
+ newIndex = builder.create <mlir::arith::AddIOp>(loc, newIndex, one);
420
+ newIndex = builder.createConvert (loc, builder.getIndexType (), newIndex);
421
+
422
+ indices[dimVal - 1 ] = newIndex;
423
+ hlfir::Entity element = hlfir::getElementAt (loc, builder, array, indices);
424
+ return hlfir::loadTrivialScalar (loc, builder, element);
425
+ };
426
+
427
+ hlfir::ElementalOp elementalOp = hlfir::genElementalOp (
428
+ loc, builder, elementType, arrayShape, typeParams, genKernel,
429
+ /* isUnordered=*/ true ,
430
+ array.isPolymorphic () ? static_cast <mlir::Value>(array) : nullptr ,
431
+ cshift.getResult ().getType ());
432
+ rewriter.replaceOp (cshift, elementalOp);
433
+ return mlir::success ();
434
+ }
435
+ };
436
+
334
437
class SimplifyHLFIRIntrinsics
335
438
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
336
439
public:
@@ -339,6 +442,7 @@ class SimplifyHLFIRIntrinsics
339
442
mlir::RewritePatternSet patterns (context);
340
443
patterns.insert <TransposeAsElementalConversion>(context);
341
444
patterns.insert <SumAsElementalConversion>(context);
445
+ patterns.insert <CShiftAsElementalConversion>(context);
342
446
mlir::ConversionTarget target (*context);
343
447
// don't transform transpose of polymorphic arrays (not currently supported
344
448
// by hlfir.elemental)
@@ -375,6 +479,24 @@ class SimplifyHLFIRIntrinsics
375
479
}
376
480
return true ;
377
481
});
482
+ target.addDynamicallyLegalOp <hlfir::CShiftOp>([](hlfir::CShiftOp cshift) {
483
+ unsigned resultRank = hlfir::Entity{cshift}.getRank ();
484
+ if (resultRank == 1 )
485
+ return false ;
486
+
487
+ mlir::Value dim = cshift.getDim ();
488
+ if (!dim)
489
+ return false ;
490
+
491
+ // If DIM is present, then it must be constant to please
492
+ // the conversion. In addition, ignore cases with
493
+ // illegal DIM values.
494
+ if (auto dimVal = fir::getIntIfConstant (dim))
495
+ if (*dimVal > 0 && *dimVal <= resultRank)
496
+ return false ;
497
+
498
+ return true ;
499
+ });
378
500
target.markUnknownOpDynamicallyLegal (
379
501
[](mlir::Operation *) { return true ; });
380
502
if (mlir::failed (mlir::applyFullConversion (getOperation (), target,
0 commit comments