Skip to content

Commit 7f71eba

Browse files
author
Peiming Liu
committed
[mlir][sparse] introduce a pass to stage complex sparse operations into simple steps
1 parent e9fa188 commit 7f71eba

File tree

5 files changed

+43
-0
lines changed

5 files changed

+43
-0
lines changed

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,15 @@ std::unique_ptr<Pass> createSparsificationPass();
8787
std::unique_ptr<Pass>
8888
createSparsificationPass(const SparsificationOptions &options);
8989

90+
//===----------------------------------------------------------------------===//
91+
// The StageSparseOperations pass.
92+
//===----------------------------------------------------------------------===//
93+
94+
/// Sets up StageSparseOperation rewriting rules.
95+
void populateStageSparseOperationsPatterns(RewritePatternSet &patterns);
96+
97+
std::unique_ptr<Pass> createStageSparseOperationsPass();
98+
9099
//===----------------------------------------------------------------------===//
91100
// The PostSparsificationRewriting pass.
92101
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,18 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
123123
];
124124
}
125125

126+
def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> {
127+
let summary = "Decompose a complex sparse operations into multiple stages";
128+
let description = [{
129+
A pass that decomposes a complex sparse operations into multiple stages.
130+
E.g., CSR -> CSC conversion is staged into CSR -> COO (unordered) -> sort -> CSC.
131+
}];
132+
let constructor = "mlir::createStageSparseOperationsPass()";
133+
let dependentDialects = [
134+
"sparse_tensor::SparseTensorDialect",
135+
];
136+
}
137+
126138
def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> {
127139
let summary = "Applies sparse tensor rewriting rules after sparsification";
128140
let description = [{

mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
1414
SparseVectorization.cpp
1515
Sparsification.cpp
1616
SparsificationAndBufferizationPass.cpp
17+
StageSparseOperations.cpp
1718

1819
ADDITIONAL_HEADER_DIRS
1920
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ namespace mlir {
3030
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
3131
#define GEN_PASS_DEF_SPARSEVECTORIZATION
3232
#define GEN_PASS_DEF_SPARSEGPUCODEGEN
33+
#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
3334
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
3435
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
3536
} // namespace mlir
@@ -92,6 +93,18 @@ struct SparsificationPass
9293
}
9394
};
9495

96+
struct StageSparseOperationsPass
97+
: public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
98+
StageSparseOperationsPass() = default;
99+
StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
100+
void runOnOperation() override {
101+
auto *ctx = &getContext();
102+
RewritePatternSet patterns(ctx);
103+
populateStageSparseOperationsPatterns(patterns);
104+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
105+
}
106+
};
107+
95108
struct PostSparsificationRewritePass
96109
: public impl::PostSparsificationRewriteBase<
97110
PostSparsificationRewritePass> {
@@ -384,6 +397,10 @@ mlir::createSparsificationPass(const SparsificationOptions &options) {
384397
return std::make_unique<SparsificationPass>(options);
385398
}
386399

400+
std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
401+
return std::make_unique<StageSparseOperationsPass>();
402+
}
403+
387404
std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
388405
return std::make_unique<PostSparsificationRewritePass>();
389406
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
2+
3+
void mlir::populateStageSparseOperationsPatterns(
4+
RewritePatternSet & /*patterns*/) {}

0 commit comments

Comments
 (0)