Skip to content

[mlir][sparse] add canonicalization patterns for IterateOp. #95569

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 14, 2024

Conversation

PeimingLiu
Copy link
Member

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jun 14, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/95569.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+8)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+34)
  • (modified) mlir/test/Dialect/SparseTensor/canonicalize.mlir (+19-1)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 5ae6f9f3443f8..a20de92d2d3ed 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1601,6 +1601,13 @@ def IterateOp : SparseTensor_Op<"iterate",
     BlockArgument getIterator() {
       return getRegion().getArguments().front();
     }
+    std::optional<BlockArgument> getLvlCrd(Level lvl) {
+      if (getCrdUsedLvls()[lvl]) {
+        uint64_t mask = (1 << lvl) - 1;
+        return getCrds()[llvm::popcount(mask & getCrdUsedLvls())];
+      }
+      return std::nullopt;
+    }
     Block::BlockArgListType getCrds() {
       // The first block argument is iterator, the remaining arguments are
       // referenced coordinates.
@@ -1613,6 +1620,7 @@ def IterateOp : SparseTensor_Op<"iterate",
 
   let hasVerifier = 1;
   let hasRegionVerifier = 1;
+  let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 232d25d718c65..ac711769ed2ea 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/Bitset.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 
@@ -2266,6 +2267,39 @@ LogicalResult ExtractIterSpaceOp::verify() {
   return success();
 }
 
+struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IterateOp iterateOp,
+                                PatternRewriter &rewriter) const override {
+    LevelSet newUsedLvls(0);
+    llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
+    for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
+      if (auto crd = iterateOp.getLvlCrd(i)) {
+        if (crd->getUsers().empty())
+          toRemove.set(crd->getArgNumber());
+        else
+          newUsedLvls.set(i);
+      }
+    }
+
+    // All coordinates are used.
+    if (toRemove.none())
+      return failure();
+
+    rewriter.startOpModification(iterateOp);
+    iterateOp.setCrdUsedLvls(newUsedLvls);
+    iterateOp.getBody()->eraseArguments(toRemove);
+    rewriter.finalizeOpModification(iterateOp);
+    return success();
+  }
+};
+
+void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
+                                            mlir::MLIRContext *context) {
+  results.add<RemoveUnusedLvlCrds>(context);
+}
+
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::Argument iterator;
   OpAsmParser::UnresolvedOperand iterSpace;
diff --git a/mlir/test/Dialect/SparseTensor/canonicalize.mlir b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
index b1d3d7916c142..37b6c89f43be1 100644
--- a/mlir/test/Dialect/SparseTensor/canonicalize.mlir
+++ b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
 
 #BCOO = #sparse_tensor.encoding<{
   map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton)
@@ -21,3 +21,21 @@ func.func @sparse_slice_canonicalize(%arg0 : tensor<?x?x?xf32, #BCOO>, %arg1 : i
   %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32, #BCOO> to tensor<?x?x?xf32, #BCOO>
   return %0 : tensor<?x?x?xf32, #BCOO>
 }
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (i : dense, j : compressed)
+}>
+
+// Make sure that the first unused coordinate is optimized.
+// CHECK-LABEL: @sparse_iterate_canonicalize
+// CHECK:         sparse_tensor.iterate {{.*}} at(_, %{{.*}})
+func.func @sparse_iterate_canonicalize(%sp : tensor<?x?xf64, #CSR>) {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 to 2
+      : tensor<?x?xf64, #CSR> -> !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
+  sparse_tensor.iterate %it1 in %l1 at (%coord0, %coord1) : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2> {
+    "test.op"(%coord1) : (index) -> ()
+  }
+  return
+}

@aartbik aartbik changed the title [mlir][sparse] add conanicalization patterns for IterateOp. [mlir][sparse] add canonicalization patterns for IterateOp. Jun 14, 2024
@PeimingLiu PeimingLiu merged commit a43d79a into llvm:main Jun 14, 2024
4 of 6 checks passed
@PeimingLiu PeimingLiu deleted the iter-canonical branch June 14, 2024 17:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants