Skip to content

Commit ad1083d

Browse files
author
Peiming Liu
authored
[mlir][sparse] introduce new pass to propagate sparse encodings. (llvm#92052)
1 parent 23f8fac commit ad1083d

File tree

3 files changed

+55
-0
lines changed

3 files changed

+55
-0
lines changed

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ void populateSparseAssembler(RewritePatternSet &patterns, bool directOut);
6565
std::unique_ptr<Pass> createSparseAssembler();
6666
std::unique_ptr<Pass> createSparseAssembler(bool directOut);
6767

68+
//===----------------------------------------------------------------------===//
69+
// The SparseEncodingPropagation pass.
70+
//===----------------------------------------------------------------------===//
71+
72+
std::unique_ptr<Pass> createSparseEncodingPropagationPass();
73+
6874
//===----------------------------------------------------------------------===//
6975
// The SparseReinterpretMap pass.
7076
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,42 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
4040
];
4141
}
4242

43+
def SparseEncodingPropagation : Pass<"sparse-encoding-propagation", "func::FuncOp"> {
44+
let summary = "Propagate sparse tensor encodings";
45+
let description = [{
46+
A pass that propagates sparse tensor encodings.
47+
48+
Background: To avoid introducing repetitive operations, sparse tensors
49+
in MLIR try to reuse tensor operations whenever available. However, most
50+
tensor operations are canonicalized/transformed without the knowledge
51+
of sparsity. The pass tries to propagate missing sparse encodings.
52+
53+
For example:
54+
```mlir
55+
%s = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
56+
: tensor<2x3xf32, #sparse> to tensor<2x1xf32, #sparse>
57+
58+
// After rank reducing (by tensor dialect transformation)
59+
%t = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
60+
: tensor<2x3xf32, #sparse> to tensor<2xf32>
61+
%s = tensor.expand_shape [[0, 1]] %t
62+
: tensor<2xf32> to tensor<2x1xf32, #sparse>
63+
64+
// After sparsity propagation
65+
%t = tensor.extract_slice %input[0, 0,] [2, 1] [1, 1]
66+
: tensor<2x3xf32, #sparse> to tensor<2xf32, #sparse1>
67+
%s = tensor.expand_shape [[0, 1]] %t
68+
: tensor<2xf32, #sparse1> to tensor<2x1xf32, #sparse>
69+
```
70+
}];
71+
72+
let constructor = "mlir::createSparseEncodingPropagationPass()";
73+
let dependentDialects = [
74+
"sparse_tensor::SparseTensorDialect",
75+
"tensor::TensorDialect",
76+
];
77+
}
78+
4379
def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
4480
let summary = "Reinterprets sparse tensor type mappings";
4581
let description = [{

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
namespace mlir {
2525
#define GEN_PASS_DEF_SPARSEASSEMBLER
26+
#define GEN_PASS_DEF_SPARSEENCODINGPROPAGATION
2627
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
2728
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
2829
#define GEN_PASS_DEF_SPARSIFICATIONPASS
@@ -60,6 +61,14 @@ struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
6061
}
6162
};
6263

64+
struct SparseEncodingPropagation
65+
: public impl::SparseEncodingPropagationBase<SparseEncodingPropagation> {
66+
SparseEncodingPropagation() = default;
67+
SparseEncodingPropagation(const SparseEncodingPropagation &pass) = default;
68+
69+
void runOnOperation() override {}
70+
};
71+
6372
struct SparseReinterpretMap
6473
: public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
6574
SparseReinterpretMap() = default;
@@ -398,6 +407,10 @@ std::unique_ptr<Pass> mlir::createSparseAssembler() {
398407
return std::make_unique<SparseAssembler>();
399408
}
400409

410+
std::unique_ptr<Pass> mlir::createSparseEncodingPropagationPass() {
411+
return std::make_unique<SparseEncodingPropagation>();
412+
}
413+
401414
std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
402415
return std::make_unique<SparseReinterpretMap>();
403416
}

0 commit comments

Comments
 (0)