Skip to content

Commit 98aa694

Browse files
[mlir][scf] Add general affine.min canonicalization pattern
This canonicalization simplifies affine.min operations inside "for loop"-like operations (e.g., scf.for and scf.parallel) based on two invariants: * iv >= lb * iv < lb + step * ((ub - lb - 1) floorDiv step) + 1 This commit adds a new pass `canonicalize-scf-affine-min` (instead of being a canonicalization pattern) to avoid dependencies between the Affine dialect and the SCF dialect. Differential Revision: https://reviews.llvm.org/D107731
1 parent 120d97b commit 98aa694

File tree

6 files changed

+355
-1
lines changed

6 files changed

+355
-1
lines changed

mlir/include/mlir/Dialect/SCF/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ std::unique_ptr<Pass> createForLoopSpecializationPass();
2828
/// better vectorization.
2929
std::unique_ptr<Pass> createForLoopPeelingPass();
3030

31+
/// Creates a pass that canonicalizes affine.min ops in scf.for loops with
32+
/// known lower and upper bounds.
33+
std::unique_ptr<Pass> createAffineMinSCFCanonicalizationPass();
34+
3135
/// Creates a loop fusion pass which fuses parallel loops.
3236
std::unique_ptr<Pass> createParallelLoopFusionPass();
3337

mlir/include/mlir/Dialect/SCF/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,16 @@ def SCFBufferize : FunctionPass<"scf-bufferize"> {
1717
let dependentDialects = ["memref::MemRefDialect"];
1818
}
1919

20+
// Note: Making this a canonicalization pattern would require a dependency
21+
// of the SCF dialect on the Affine dialect or vice versa.
22+
def AffineMinSCFCanonicalization
23+
: FunctionPass<"canonicalize-scf-affine-min"> {
24+
let summary = "Canonicalize affine.min ops in the context of SCF loops with "
25+
"known bounds";
26+
let constructor = "mlir::createAffineMinSCFCanonicalizationPass()";
27+
let dependentDialects = ["AffineDialect"];
28+
}
29+
2030
def SCFForLoopPeeling
2131
: FunctionPass<"for-loop-peeling"> {
2232
let summary = "Peel `for` loops at their upper bounds.";

mlir/include/mlir/Dialect/SCF/Transforms.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
#ifndef MLIR_DIALECT_SCF_TRANSFORMS_H_
1414
#define MLIR_DIALECT_SCF_TRANSFORMS_H_
1515

16+
#include "mlir/Support/LLVM.h"
1617
#include "llvm/ADT/ArrayRef.h"
1718

1819
namespace mlir {
1920

21+
class AffineMinOp;
2022
class ConversionTarget;
2123
struct LogicalResult;
2224
class MLIRContext;
@@ -26,6 +28,7 @@ class TypeConverter;
2628
class RewritePatternSet;
2729
using OwningRewritePatternList = RewritePatternSet;
2830
class Operation;
31+
class Value;
2932

3033
namespace scf {
3134

@@ -34,6 +37,21 @@ class ForOp;
3437
class ParallelOp;
3538
class ForOp;
3639

40+
/// Try to canonicalize an affine.min operation in the context of `for` loops
41+
/// with a known range.
42+
///
43+
/// `loopMatcher` is used to retrieve loop bounds and step size for a given
44+
/// iteration variable: If the first parameter is an iteration variable, return
45+
/// lower/upper bounds via the second/third parameter and the step size via the
46+
/// last parameter. The function should return `success` in that case. If the
47+
/// first parameter is not an iteration variable, return `failure`.
48+
///
49+
/// Note: `loopMatcher` allows this function to be used with any "for loop"-like
50+
/// operation (scf.for, scf.parallel and even ops defined in other dialects).
51+
LogicalResult canonicalizeAffineMinOpInLoop(
52+
AffineMinOp minOp, RewriterBase &rewriter,
53+
function_ref<LogicalResult(Value, Value &, Value &, Value &)> loopMatcher);
54+
3755
/// Fuses all adjacent scf.parallel operations with identical bounds and step
3856
/// into one scf.parallel operations. Uses a naive aliasing and dependency
3957
/// analysis.
@@ -149,6 +167,11 @@ struct PipeliningOption {
149167
void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
150168
const PipeliningOption &options);
151169

170+
/// Populate patterns for canonicalizing operations inside SCF loop bodies.
171+
/// At the moment, only affine.min computations with iteration variables,
172+
/// loop bounds and loop steps are canonicalized.
173+
void populateSCFLoopBodyCanonicalizationPatterns(RewritePatternSet &patterns);
174+
152175
} // namespace scf
153176
} // namespace mlir
154177

mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,77 @@ LogicalResult mlir::scf::peelAndCanonicalizeForLoop(RewriterBase &rewriter,
382382
return success();
383383
}
384384

385+
/// Canonicalize AffineMinOp operations in the context of for loops with a known
386+
/// range. Call `canonicalizeAffineMinOp` and add the following constraints to
387+
/// the constraint system (along with the missing dimensions):
388+
///
389+
/// * iv >= lb
390+
/// * iv < lb + step * ((ub - lb - 1) floorDiv step) + 1
391+
///
392+
/// Note: Due to limitations of FlatAffineConstraints, only constant step sizes
393+
/// are currently supported.
394+
LogicalResult mlir::scf::canonicalizeAffineMinOpInLoop(
395+
AffineMinOp minOp, RewriterBase &rewriter,
396+
function_ref<LogicalResult(Value, Value &, Value &, Value &)> loopMatcher) {
397+
FlatAffineValueConstraints constraints;
398+
DenseSet<Value> allIvs;
399+
400+
// Find all iteration variables among `minOp`'s operands add constrain them.
401+
for (Value operand : minOp.operands()) {
402+
// Skip duplicate ivs.
403+
if (llvm::find(allIvs, operand) != allIvs.end())
404+
continue;
405+
406+
// If `operand` is an iteration variable: Find corresponding loop
407+
// bounds and step.
408+
Value iv = operand;
409+
Value lb, ub, step;
410+
if (failed(loopMatcher(operand, lb, ub, step)))
411+
continue;
412+
allIvs.insert(iv);
413+
414+
// FlatAffineConstraints does not support semi-affine expressions.
415+
// Therefore, only constant step values are supported.
416+
auto stepInt = getConstantIntValue(step);
417+
if (!stepInt)
418+
continue;
419+
420+
unsigned dimIv = constraints.addDimId(iv);
421+
unsigned dimLb = constraints.addDimId(lb);
422+
unsigned dimUb = constraints.addDimId(ub);
423+
424+
// If loop lower/upper bounds are constant: Add EQ constraint.
425+
Optional<int64_t> lbInt = getConstantIntValue(lb);
426+
Optional<int64_t> ubInt = getConstantIntValue(ub);
427+
if (lbInt)
428+
constraints.addBound(FlatAffineConstraints::EQ, dimLb, *lbInt);
429+
if (ubInt)
430+
constraints.addBound(FlatAffineConstraints::EQ, dimUb, *ubInt);
431+
432+
// iv >= lb (equiv.: iv - lb >= 0)
433+
SmallVector<int64_t> ineqLb(constraints.getNumCols(), 0);
434+
ineqLb[dimIv] = 1;
435+
ineqLb[dimLb] = -1;
436+
constraints.addInequality(ineqLb);
437+
438+
// iv < lb + step * ((ub - lb - 1) floorDiv step) + 1
439+
AffineExpr exprLb = lbInt ? rewriter.getAffineConstantExpr(*lbInt)
440+
: rewriter.getAffineDimExpr(dimLb);
441+
AffineExpr exprUb = ubInt ? rewriter.getAffineConstantExpr(*ubInt)
442+
: rewriter.getAffineDimExpr(dimUb);
443+
AffineExpr ivUb =
444+
exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1).floorDiv(*stepInt)));
445+
auto map = AffineMap::get(
446+
/*dimCount=*/constraints.getNumDimIds(),
447+
/*symbolCount=*/constraints.getNumSymbolIds(), /*result=*/ivUb);
448+
449+
if (failed(constraints.addBound(FlatAffineConstraints::UB, dimIv, map)))
450+
return failure();
451+
}
452+
453+
return canonicalizeAffineMinOp(rewriter, minOp, constraints);
454+
}
455+
385456
static constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
386457
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
387458

@@ -423,6 +494,39 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
423494
/// the direct parent.
424495
bool skipPartial;
425496
};
497+
498+
/// Canonicalize AffineMinOp operations in the context of scf.for and
499+
/// scf.parallel loops with a known range.
500+
struct AffineMinSCFCanonicalizationPattern
501+
: public OpRewritePattern<AffineMinOp> {
502+
using OpRewritePattern<AffineMinOp>::OpRewritePattern;
503+
504+
LogicalResult matchAndRewrite(AffineMinOp minOp,
505+
PatternRewriter &rewriter) const override {
506+
auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) {
507+
if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
508+
lb = forOp.lowerBound();
509+
ub = forOp.upperBound();
510+
step = forOp.step();
511+
return success();
512+
}
513+
if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
514+
for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
515+
if (parOp.getInductionVars()[idx] == iv) {
516+
lb = parOp.lowerBound()[idx];
517+
ub = parOp.upperBound()[idx];
518+
step = parOp.step()[idx];
519+
return success();
520+
}
521+
}
522+
return failure();
523+
}
524+
return failure();
525+
};
526+
527+
return scf::canonicalizeAffineMinOpInLoop(minOp, rewriter, loopMatcher);
528+
}
529+
};
426530
} // namespace
427531

428532
namespace {
@@ -456,8 +560,24 @@ struct ForLoopPeeling : public SCFForLoopPeelingBase<ForLoopPeeling> {
456560
});
457561
}
458562
};
563+
564+
struct AffineMinSCFCanonicalization
565+
: public AffineMinSCFCanonicalizationBase<AffineMinSCFCanonicalization> {
566+
void runOnFunction() override {
567+
FuncOp funcOp = getFunction();
568+
MLIRContext *ctx = funcOp.getContext();
569+
RewritePatternSet patterns(ctx);
570+
patterns.add<AffineMinSCFCanonicalizationPattern>(ctx);
571+
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
572+
signalPassFailure();
573+
}
574+
};
459575
} // namespace
460576

577+
std::unique_ptr<Pass> mlir::createAffineMinSCFCanonicalizationPass() {
578+
return std::make_unique<AffineMinSCFCanonicalization>();
579+
}
580+
461581
std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
462582
return std::make_unique<ParallelLoopSpecialization>();
463583
}
@@ -469,3 +589,8 @@ std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
469589
std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
470590
return std::make_unique<ForLoopPeeling>();
471591
}
592+
593+
void mlir::scf::populateSCFLoopBodyCanonicalizationPatterns(
594+
RewritePatternSet &patterns) {
595+
patterns.insert<AffineMinSCFCanonicalizationPattern>(patterns.getContext());
596+
}

0 commit comments

Comments
 (0)