Skip to content

[mlir] Add forall canonicalization to replace constant induction vars #112764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 18, 2024

Conversation

Max191
Copy link
Contributor

@Max191 Max191 commented Oct 17, 2024

Adds a canonicalization pattern for scf.forall that replaces constant induction variables with a constant index. There is a similar canonicalization that completely removes constant induction variables from the loop, but that pattern does not apply on foralls with mappings, so this one is necessary for those cases.

@llvmbot
Copy link
Member

llvmbot commented Oct 17, 2024

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: None (Max191)

Changes

Adds a canonicalization pattern for scf.forall that replaces constant induction variables with a constant index. There is a similar canonicalization that completely removes constant induction variables from the loop, but that pattern does not apply on foralls with mappings, so this one is necessary for those cases.


Full diff: https://github.com/llvm/llvm-project/pull/112764.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+28-1)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+2)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2582d4e0df1920..7789f21af00780 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1767,6 +1767,32 @@ struct ForallOpSingleOrZeroIterationDimsFolder
   }
 };
 
+struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
+  using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForallOp op,
+                                PatternRewriter &rewriter) const override {
+    // Replace all induction vars with a single trip count with their lower
+    // bound.
+    Location loc = op.getLoc();
+    bool replacedIv = false;
+    for (auto [lb, ub, step, iv] :
+         llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
+                   op.getMixedStep(), op.getInductionVars())) {
+      if (iv.getUses().begin() == iv.getUses().end())
+        continue;
+      auto numIterations = constantTripCount(lb, ub, step);
+      if (!numIterations.has_value() || numIterations.value() != 1) {
+        continue;
+      }
+      rewriter.replaceAllUsesWith(
+          iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
+      return success();
+    }
+    return failure();
+  }
+};
+
 struct FoldTensorCastOfOutputIntoForallOp
     : public OpRewritePattern<scf::ForallOp> {
   using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
@@ -1851,7 +1877,8 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
   results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
               ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
-              ForallOpSingleOrZeroIterationDimsFolder>(context);
+              ForallOpSingleOrZeroIterationDimsFolder,
+              ForallOpReplaceConstantInductionVar>(context);
 }
 
 /// Given the region at `index`, or the parent operation if `index` is None,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c68369a8e4fce7..6f4703c04dc768 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1632,6 +1632,8 @@ func.func @do_not_inline_distributed_forall_loop(
 }
 // CHECK-LABEL: @do_not_inline_distributed_forall_loop
 // CHECK: scf.forall
+// CHECK:   tensor.extract_slice %{{.*}}[0, 0] [2, 3] [1, 1]
+// CHECK:   tensor.parallel_insert_slice %{{.*}}[0, 0] [2, 3] [1, 1]
 
 // -----
 

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just one suggestion about the comment.

@Max191 Max191 force-pushed the forall-const-iv-canonicalization branch from ce38981 to 731a481 Compare October 18, 2024 13:39
@Max191 Max191 merged commit 1ae2446 into llvm:main Oct 18, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants