Skip to content

Commit 1ae2446

Browse files
authored
[mlir] Add forall canonicalization to replace constant induction vars (#112764)
Adds a canonicalization pattern for scf.forall that replaces constant induction variables with a constant index. There is a similar canonicalization that completely removes constant induction variables from the loop, but that pattern does not apply on foralls with mappings, so this one is necessary for those cases. --------- Signed-off-by: Max Dawkins <[email protected]>
1 parent 952dafb commit 1ae2446

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1767,6 +1767,31 @@ struct ForallOpSingleOrZeroIterationDimsFolder
17671767
}
17681768
};
17691769

1770+
/// Replace all induction vars with a single trip count with their lower bound.
1771+
struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1772+
using OpRewritePattern<ForallOp>::OpRewritePattern;
1773+
1774+
LogicalResult matchAndRewrite(ForallOp op,
1775+
PatternRewriter &rewriter) const override {
1776+
Location loc = op.getLoc();
1777+
bool changed = false;
1778+
for (auto [lb, ub, step, iv] :
1779+
llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1780+
op.getMixedStep(), op.getInductionVars())) {
1781+
if (iv.getUses().begin() == iv.getUses().end())
1782+
continue;
1783+
auto numIterations = constantTripCount(lb, ub, step);
1784+
if (!numIterations.has_value() || numIterations.value() != 1) {
1785+
continue;
1786+
}
1787+
rewriter.replaceAllUsesWith(
1788+
iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1789+
changed = true;
1790+
}
1791+
return success(changed);
1792+
}
1793+
};
1794+
17701795
struct FoldTensorCastOfOutputIntoForallOp
17711796
: public OpRewritePattern<scf::ForallOp> {
17721797
using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
@@ -1851,7 +1876,8 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
18511876
MLIRContext *context) {
18521877
results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
18531878
ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1854-
ForallOpSingleOrZeroIterationDimsFolder>(context);
1879+
ForallOpSingleOrZeroIterationDimsFolder,
1880+
ForallOpReplaceConstantInductionVar>(context);
18551881
}
18561882

18571883
/// Given the region at `index`, or the parent operation if `index` is None,

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1617,7 +1617,7 @@ func.func @do_not_inline_distributed_forall_loop(
16171617
%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
16181618
%cst = arith.constant 0.000000e+00 : f32
16191619
%0 = tensor.empty() : tensor<8x8xf32>
1620-
%1 = scf.forall (%i, %j) = (0, 0) to (1, 1) step (8, 8)
1620+
%1 = scf.forall (%i, %j) = (0, 4) to (1, 5) step (8, 8)
16211621
shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
16221622
%slice = tensor.extract_slice %out_[%i, %j] [2, 3] [1, 1]
16231623
: tensor<8x8xf32> to tensor<2x3xf32>
@@ -1632,6 +1632,8 @@ func.func @do_not_inline_distributed_forall_loop(
16321632
}
16331633
// CHECK-LABEL: @do_not_inline_distributed_forall_loop
16341634
// CHECK: scf.forall
1635+
// CHECK: tensor.extract_slice %{{.*}}[0, 4] [2, 3] [1, 1]
1636+
// CHECK: tensor.parallel_insert_slice %{{.*}}[0, 4] [2, 3] [1, 1]
16351637

16361638
// -----
16371639

0 commit comments

Comments
 (0)