-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir] Add forall canonicalization to replace constant induction vars #112764
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir Author: None (Max191) ChangesAdds a canonicalization pattern for scf.forall that replaces constant induction variables with a constant index. There is a similar canonicalization that completely removes constant induction variables from the loop, but that pattern does not apply on foralls with mappings, so this one is necessary for those cases. Full diff: https://github.com/llvm/llvm-project/pull/112764.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2582d4e0df1920..7789f21af00780 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1767,6 +1767,32 @@ struct ForallOpSingleOrZeroIterationDimsFolder
}
};
+struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
+ using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ForallOp op,
+ PatternRewriter &rewriter) const override {
+ // Replace all induction vars with a single trip count with their lower
+ // bound.
+ Location loc = op.getLoc();
+ bool replacedIv = false;
+ for (auto [lb, ub, step, iv] :
+ llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
+ op.getMixedStep(), op.getInductionVars())) {
+ if (iv.getUses().begin() == iv.getUses().end())
+ continue;
+ auto numIterations = constantTripCount(lb, ub, step);
+ if (!numIterations.has_value() || numIterations.value() != 1) {
+ continue;
+ }
+ rewriter.replaceAllUsesWith(
+ iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
+ return success();
+ }
+ return failure();
+ }
+};
+
struct FoldTensorCastOfOutputIntoForallOp
: public OpRewritePattern<scf::ForallOp> {
using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
@@ -1851,7 +1877,8 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
- ForallOpSingleOrZeroIterationDimsFolder>(context);
+ ForallOpSingleOrZeroIterationDimsFolder,
+ ForallOpReplaceConstantInductionVar>(context);
}
/// Given the region at `index`, or the parent operation if `index` is None,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c68369a8e4fce7..6f4703c04dc768 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1632,6 +1632,8 @@ func.func @do_not_inline_distributed_forall_loop(
}
// CHECK-LABEL: @do_not_inline_distributed_forall_loop
// CHECK: scf.forall
+// CHECK: tensor.extract_slice %{{.*}}[0, 0] [2, 3] [1, 1]
+// CHECK: tensor.parallel_insert_slice %{{.*}}[0, 0] [2, 3] [1, 1]
// -----
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just one suggestion about the comment.
Signed-off-by: Max Dawkins <[email protected]>
Signed-off-by: Max Dawkins <[email protected]>
Signed-off-by: Max Dawkins <[email protected]>
Signed-off-by: Max Dawkins <[email protected]>
ce38981
to
731a481
Compare
Adds a canonicalization pattern for scf.forall that replaces constant induction variables with a constant index. There is a similar canonicalization that completely removes constant induction variables from the loop, but that pattern does not apply on foralls with mappings, so this one is necessary for those cases.