@@ -382,6 +382,77 @@ LogicalResult mlir::scf::peelAndCanonicalizeForLoop(RewriterBase &rewriter,
382
382
return success ();
383
383
}
384
384
385
+ // / Canonicalize AffineMinOp operations in the context of for loops with a known
386
+ // / range. Call `canonicalizeAffineMinOp` and add the following constraints to
387
+ // / the constraint system (along with the missing dimensions):
388
+ // /
389
+ // / * iv >= lb
390
+ // / * iv < lb + step * ((ub - lb - 1) floorDiv step) + 1
391
+ // /
392
+ // / Note: Due to limitations of FlatAffineConstraints, only constant step sizes
393
+ // / are currently supported.
394
+ LogicalResult mlir::scf::canonicalizeAffineMinOpInLoop (
395
+ AffineMinOp minOp, RewriterBase &rewriter,
396
+ function_ref<LogicalResult(Value, Value &, Value &, Value &)> loopMatcher) {
397
+ FlatAffineValueConstraints constraints;
398
+ DenseSet<Value> allIvs;
399
+
400
+ // Find all iteration variables among `minOp`'s operands add constrain them.
401
+ for (Value operand : minOp.operands ()) {
402
+ // Skip duplicate ivs.
403
+ if (llvm::find (allIvs, operand) != allIvs.end ())
404
+ continue ;
405
+
406
+ // If `operand` is an iteration variable: Find corresponding loop
407
+ // bounds and step.
408
+ Value iv = operand;
409
+ Value lb, ub, step;
410
+ if (failed (loopMatcher (operand, lb, ub, step)))
411
+ continue ;
412
+ allIvs.insert (iv);
413
+
414
+ // FlatAffineConstraints does not support semi-affine expressions.
415
+ // Therefore, only constant step values are supported.
416
+ auto stepInt = getConstantIntValue (step);
417
+ if (!stepInt)
418
+ continue ;
419
+
420
+ unsigned dimIv = constraints.addDimId (iv);
421
+ unsigned dimLb = constraints.addDimId (lb);
422
+ unsigned dimUb = constraints.addDimId (ub);
423
+
424
+ // If loop lower/upper bounds are constant: Add EQ constraint.
425
+ Optional<int64_t > lbInt = getConstantIntValue (lb);
426
+ Optional<int64_t > ubInt = getConstantIntValue (ub);
427
+ if (lbInt)
428
+ constraints.addBound (FlatAffineConstraints::EQ, dimLb, *lbInt);
429
+ if (ubInt)
430
+ constraints.addBound (FlatAffineConstraints::EQ, dimUb, *ubInt);
431
+
432
+ // iv >= lb (equiv.: iv - lb >= 0)
433
+ SmallVector<int64_t > ineqLb (constraints.getNumCols (), 0 );
434
+ ineqLb[dimIv] = 1 ;
435
+ ineqLb[dimLb] = -1 ;
436
+ constraints.addInequality (ineqLb);
437
+
438
+ // iv < lb + step * ((ub - lb - 1) floorDiv step) + 1
439
+ AffineExpr exprLb = lbInt ? rewriter.getAffineConstantExpr (*lbInt)
440
+ : rewriter.getAffineDimExpr (dimLb);
441
+ AffineExpr exprUb = ubInt ? rewriter.getAffineConstantExpr (*ubInt)
442
+ : rewriter.getAffineDimExpr (dimUb);
443
+ AffineExpr ivUb =
444
+ exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1 ).floorDiv (*stepInt)));
445
+ auto map = AffineMap::get (
446
+ /* dimCount=*/ constraints.getNumDimIds (),
447
+ /* symbolCount=*/ constraints.getNumSymbolIds (), /* result=*/ ivUb);
448
+
449
+ if (failed (constraints.addBound (FlatAffineConstraints::UB, dimIv, map)))
450
+ return failure ();
451
+ }
452
+
453
+ return canonicalizeAffineMinOp (rewriter, minOp, constraints);
454
+ }
// Attribute labels used by the peeling pattern. The scrape-injected leading
// space inside the literals is removed; these must match the attribute names
// exactly or peeled loops would never be recognized on re-visits.
static constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
@@ -423,6 +494,39 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
423
494
// / the direct parent.
424
495
bool skipPartial;
425
496
};
497
+
498
+ // / Canonicalize AffineMinOp operations in the context of scf.for and
499
+ // / scf.parallel loops with a known range.
500
+ struct AffineMinSCFCanonicalizationPattern
501
+ : public OpRewritePattern<AffineMinOp> {
502
+ using OpRewritePattern<AffineMinOp>::OpRewritePattern;
503
+
504
+ LogicalResult matchAndRewrite (AffineMinOp minOp,
505
+ PatternRewriter &rewriter) const override {
506
+ auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) {
507
+ if (scf::ForOp forOp = scf::getForInductionVarOwner (iv)) {
508
+ lb = forOp.lowerBound ();
509
+ ub = forOp.upperBound ();
510
+ step = forOp.step ();
511
+ return success ();
512
+ }
513
+ if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner (iv)) {
514
+ for (unsigned idx = 0 ; idx < parOp.getNumLoops (); ++idx) {
515
+ if (parOp.getInductionVars ()[idx] == iv) {
516
+ lb = parOp.lowerBound ()[idx];
517
+ ub = parOp.upperBound ()[idx];
518
+ step = parOp.step ()[idx];
519
+ return success ();
520
+ }
521
+ }
522
+ return failure ();
523
+ }
524
+ return failure ();
525
+ };
526
+
527
+ return scf::canonicalizeAffineMinOpInLoop (minOp, rewriter, loopMatcher);
528
+ }
529
+ };
426
530
} // namespace
427
531
428
532
namespace {
@@ -456,8 +560,24 @@ struct ForLoopPeeling : public SCFForLoopPeelingBase<ForLoopPeeling> {
456
560
});
457
561
}
458
562
};
563
+
564
+ struct AffineMinSCFCanonicalization
565
+ : public AffineMinSCFCanonicalizationBase<AffineMinSCFCanonicalization> {
566
+ void runOnFunction () override {
567
+ FuncOp funcOp = getFunction ();
568
+ MLIRContext *ctx = funcOp.getContext ();
569
+ RewritePatternSet patterns (ctx);
570
+ patterns.add <AffineMinSCFCanonicalizationPattern>(ctx);
571
+ if (failed (applyPatternsAndFoldGreedily (funcOp, std::move (patterns))))
572
+ signalPassFailure ();
573
+ }
574
+ };
459
575
} // namespace
460
576
577
+ std::unique_ptr<Pass> mlir::createAffineMinSCFCanonicalizationPass () {
578
+ return std::make_unique<AffineMinSCFCanonicalization>();
579
+ }
580
+
461
581
/// Create an instance of the parallel-loop specialization pass.
std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
  return std::make_unique<ParallelLoopSpecialization>();
}
@@ -469,3 +589,8 @@ std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
469
589
/// Create an instance of the scf.for loop peeling pass.
std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
  return std::make_unique<ForLoopPeeling>();
}
592
+
593
+ void mlir::scf::populateSCFLoopBodyCanonicalizationPatterns (
594
+ RewritePatternSet &patterns) {
595
+ patterns.insert <AffineMinSCFCanonicalizationPattern>(patterns.getContext ());
596
+ }
0 commit comments