Commit d37affb
[mlir][sparse] add a sparse_tensor.print operation (#83321)

This operation is mainly used for testing and debugging purposes, but provides
a very convenient way to quickly inspect the contents of a sparse tensor
(all components over all stored levels).

Example: the matrix

  [ [ 1, 0, 2, 0, 0, 0, 0, 0 ],
    [ 0, 0, 0, 0, 0, 0, 0, 0 ],
    [ 0, 0, 0, 0, 0, 0, 0, 0 ],
    [ 0, 0, 3, 4, 0, 5, 0, 0 ] ]

when stored sparse as DCSC prints as

  ---- Sparse Tensor ----
  nse = 5
  pos[0] : ( 0, 4, )
  crd[0] : ( 0, 2, 3, 5, )
  pos[1] : ( 0, 1, 3, 4, 5, )
  crd[1] : ( 0, 0, 3, 3, 3, )
  values : ( 1, 2, 3, 4, 5, )
  ----
1 parent b339c88 commit d37affb
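
For quick reference, here is a minimal end-to-end usage sketch, distilled from
the new integration test added by this commit (the #DCSC encoding and the
single-tensor driver mirror the example above):

```mlir
#DCSC = #sparse_tensor.encoding<{
  map = (i, j) -> (
    j : compressed,
    i : compressed
  )
}>

func.func @main() {
  %x = arith.constant dense<[
      [ 1, 0, 2, 0, 0, 0, 0, 0 ],
      [ 0, 0, 0, 0, 0, 0, 0, 0 ],
      [ 0, 0, 0, 0, 0, 0, 0, 0 ],
      [ 0, 0, 3, 4, 0, 5, 0, 0 ] ]> : tensor<4x8xi32>
  // Convert the dense constant to DCSC storage and dump all components.
  %a = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #DCSC>
  sparse_tensor.print %a : tensor<4x8xi32, #DCSC>
  bufferization.dealloc_tensor %a : tensor<4x8xi32, #DCSC>
  return
}
```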

3 files changed: +331 −1 lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 22 additions & 0 deletions
@@ -1453,4 +1453,26 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Debugging Operations.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_PrintOp : SparseTensor_Op<"print">,
+    Arguments<(ins AnySparseTensor:$tensor)> {
+  string summary = "Prints a sparse tensor (for testing and debugging)";
+  string description = [{
+    Prints the individual components of a sparse tensor (the positions,
+    coordinates, and values components) to stdout for testing and debugging
+    purposes. This operation lowers to just a few primitives in a light-weight
+    runtime support library, which simplifies supporting this operation on
+    new platforms.
+
+    Example:
+
+    ```mlir
+    sparse_tensor.print %tensor : tensor<1024x1024xf64, #CSR>
+    ```
+  }];
+  let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
+}
+
 #endif // SPARSETENSOR_OPS
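
The description's "just a few primitives" can be made concrete. Below is a
hand-written sketch of roughly what the lowered IR looks like for the header
of a CSR tensor %a; it is illustrative only, not verbatim compiler output
(in particular, the string escape and punctuation details are simplified):

```mlir
// Print the banner and the number of stored entries (nse).
%nse = sparse_tensor.number_of_entries %a : tensor<4x8xi32, #CSR>
vector.print str "---- Sparse Tensor ----"
vector.print %nse : index
// Then, for each stored level, the positions/coordinates buffers are
// queried and printed element-wise (see PrintRewriter in the next file).
%pos = sparse_tensor.positions %a { level = 1 : index }
    : tensor<4x8xi32, #CSR> to memref<?xindex>
```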

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 94 additions & 1 deletion
@@ -21,9 +21,11 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
@@ -598,6 +600,96 @@ struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
   }
 };
 
+/// Sparse rewriting rule for the print operator. This operation is mainly
+/// used for debugging and testing. As such, it lowers to the vector.print
+/// operation which only requires very light-weight runtime support.
+struct PrintRewriter : public OpRewritePattern<PrintOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrintOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto tensor = op.getTensor();
+    auto stt = getSparseTensorType(tensor);
+    // Header with NSE.
+    auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
+    rewriter.create<vector::PrintOp>(
+        loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
+    rewriter.create<vector::PrintOp>(loc, nse);
+    // Use the "codegen" foreach loop construct to iterate over
+    // all typical sparse tensor components for printing.
+    foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc,
+                                            &tensor](Type tp, FieldIndex,
+                                                     SparseTensorFieldKind kind,
+                                                     Level l, LevelType) {
+      switch (kind) {
+      case SparseTensorFieldKind::StorageSpec: {
+        break;
+      }
+      case SparseTensorFieldKind::PosMemRef: {
+        auto lvl = constantIndex(rewriter, loc, l);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
+        rewriter.create<vector::PrintOp>(
+            loc, lvl, vector::PrintPunctuation::NoPunctuation);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+        auto pos = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
+        printContents(rewriter, loc, tp, pos);
+        break;
+      }
+      case SparseTensorFieldKind::CrdMemRef: {
+        auto lvl = constantIndex(rewriter, loc, l);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
+        rewriter.create<vector::PrintOp>(
+            loc, lvl, vector::PrintPunctuation::NoPunctuation);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+        auto crd = rewriter.create<ToCoordinatesOp>(loc, tp, tensor, l);
+        printContents(rewriter, loc, tp, crd);
+        break;
+      }
+      case SparseTensorFieldKind::ValMemRef: {
+        rewriter.create<vector::PrintOp>(loc,
+                                         rewriter.getStringAttr("values : "));
+        auto val = rewriter.create<ToValuesOp>(loc, tp, tensor);
+        printContents(rewriter, loc, tp, val);
+        break;
+      }
+      }
+      return true;
+    });
+    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  // Helper to print contents of a single memref. Note that for the "push_back"
+  // vectors, this prints the full capacity, not just the size. This is done
+  // on purpose, so that clients see how much storage has been allocated in
+  // total. Contents of the extra capacity in the buffer may be uninitialized
+  // (unless the flag enable-buffer-initialization is set to true).
+  //
+  // Generates code to print:
+  //    ( a0, a1, ... )
+  static void printContents(PatternRewriter &rewriter, Location loc, Type tp,
+                            Value vec) {
+    // Open bracket.
+    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+    // For loop over elements.
+    auto zero = constantIndex(rewriter, loc, 0);
+    auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
+    auto step = constantIndex(rewriter, loc, 1);
+    auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
+    rewriter.setInsertionPointToStart(forOp.getBody());
+    auto idx = forOp.getInductionVar();
+    auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
+    rewriter.create<vector::PrintOp>(loc, val, vector::PrintPunctuation::Comma);
+    rewriter.setInsertionPointAfter(forOp);
+    // Close bracket and end of line.
+    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
+    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
+  }
+};
+
 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
 struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
 public:
@@ -1284,7 +1376,8 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
   patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
-               GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
+               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
+      patterns.getContext());
 }
 
 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
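
To visualize what printContents generates for a single component, here is a
hand-written MLIR approximation (the actual pattern additionally emits the
open/close bracket and comma punctuation via vector::PrintPunctuation; %pos
is assumed to be a previously queried positions memref):

```mlir
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// The loop bound is the buffer dimension, i.e. the full capacity of a
// "push_back" vector, not just its size (on purpose, see comment above).
%cap = memref.dim %pos, %c0 : memref<?xindex>
scf.for %i = %c0 to %cap step %c1 {
  %v = memref.load %pos[%i] : memref<?xindex>
  vector.print %v : index
}
```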
Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
//--------------------------------------------------------------------------------------------------
// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
//
// Set-up that's shared across all tests in this directory. In principle, this
// config could be moved to lit.local.cfg. However, there are downstream users that
// do not use these LIT config files. Hence why this is kept inline.
//
// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
// DEFINE: %{run_opts} = -e main -entry-point-result=void
// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
//
// DEFINE: %{env} =
//--------------------------------------------------------------------------------------------------

// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with direct IR generation.
// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
// RUN: %{compile} | %{run} | FileCheck %s
//

#AllDense = #sparse_tensor.encoding<{
  map = (i, j) -> (
    i : dense,
    j : dense
  )
}>

#AllDenseT = #sparse_tensor.encoding<{
  map = (i, j) -> (
    j : dense,
    i : dense
  )
}>

#CSR = #sparse_tensor.encoding<{
  map = (i, j) -> (
    i : dense,
    j : compressed
  )
}>

#DCSR = #sparse_tensor.encoding<{
  map = (i, j) -> (
    i : compressed,
    j : compressed
  )
}>

#CSC = #sparse_tensor.encoding<{
  map = (i, j) -> (
    j : dense,
    i : compressed
  )
}>

#DCSC = #sparse_tensor.encoding<{
  map = (i, j) -> (
    j : compressed,
    i : compressed
  )
}>

#BSR = #sparse_tensor.encoding<{
  map = (i, j) -> (
    i floordiv 2 : compressed,
    j floordiv 4 : compressed,
    i mod 2 : dense,
    j mod 4 : dense
  )
}>

#BSRC = #sparse_tensor.encoding<{
  map = (i, j) -> (
    i floordiv 2 : compressed,
    j floordiv 4 : compressed,
    j mod 4 : dense,
    i mod 2 : dense
  )
}>

#BSC = #sparse_tensor.encoding<{
  map = (i, j) -> (
    j floordiv 4 : compressed,
    i floordiv 2 : compressed,
    i mod 2 : dense,
    j mod 4 : dense
  )
}>

#BSCC = #sparse_tensor.encoding<{
  map = (i, j) -> (
    j floordiv 4 : compressed,
    i floordiv 2 : compressed,
    j mod 4 : dense,
    i mod 2 : dense
  )
}>

module {

  //
  // Main driver that tests sparse tensor storage.
  //
  func.func @main() {
    %x = arith.constant dense <[
       [ 1, 0, 2, 0, 0, 0, 0, 0 ],
       [ 0, 0, 0, 0, 0, 0, 0, 0 ],
       [ 0, 0, 0, 0, 0, 0, 0, 0 ],
       [ 0, 0, 3, 4, 0, 5, 0, 0 ] ]> : tensor<4x8xi32>

    %a = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #CSR>
    %b = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #DCSR>
    %c = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #CSC>
    %d = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #DCSC>
    %e = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSR>
    %f = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSRC>
    %g = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSC>
    %h = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSCC>

    //
    // CHECK: ---- Sparse Tensor ----
    // CHECK-NEXT: nse = 5
    // CHECK-NEXT: pos[1] : ( 0, 2, 2, 2, 5,
    // CHECK-NEXT: crd[1] : ( 0, 2, 2, 3, 5,
    // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
    // CHECK-NEXT: ----
    sparse_tensor.print %a : tensor<4x8xi32, #CSR>

    // CHECK-NEXT: ---- Sparse Tensor ----
    // CHECK-NEXT: nse = 5
    // CHECK-NEXT: pos[0] : ( 0, 2,
    // CHECK-NEXT: crd[0] : ( 0, 3,
    // CHECK-NEXT: pos[1] : ( 0, 2, 5,
    // CHECK-NEXT: crd[1] : ( 0, 2, 2, 3, 5,
    // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
    // CHECK-NEXT: ----
    sparse_tensor.print %b : tensor<4x8xi32, #DCSR>

    // CHECK-NEXT: ---- Sparse Tensor ----
    // CHECK-NEXT: nse = 5
    // CHECK-NEXT: pos[1] : ( 0, 1, 1, 3, 4, 4, 5, 5, 5,
    // CHECK-NEXT: crd[1] : ( 0, 0, 3, 3, 3,
    // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
    // CHECK-NEXT: ----
    sparse_tensor.print %c : tensor<4x8xi32, #CSC>

    // CHECK-NEXT: ---- Sparse Tensor ----
    // CHECK-NEXT: nse = 5
    // CHECK-NEXT: pos[0] : ( 0, 4,
    // CHECK-NEXT: crd[0] : ( 0, 2, 3, 5,
    // CHECK-NEXT: pos[1] : ( 0, 1, 3, 4, 5,
    // CHECK-NEXT: crd[1] : ( 0, 0, 3, 3, 3,
    // CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
    // CHECK-NEXT: ----
    sparse_tensor.print %d : tensor<4x8xi32, #DCSC>

    // CHECK-NEXT: ---- Sparse Tensor ----
    // CHECK-NEXT: nse = 24
    // CHECK-NEXT: pos[0] : ( 0, 2,
    // CHECK-NEXT: crd[0] : ( 0, 1,
    // CHECK-NEXT: pos[1] : ( 0, 1, 3,
    // CHECK-NEXT: crd[1] : ( 0, 0, 1,
    // CHECK-NEXT: values : ( 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 5, 0, 0,
    // CHECK-NEXT: ----
    sparse_tensor.print %e : tensor<4x8xi32, #BSR>

    // CHECK-NEXT: ---- Sparse Tensor ----
    // CHECK-NEXT: nse = 24
    // CHECK-NEXT: pos[0] : ( 0, 2,
    // CHECK-NEXT: crd[0] : ( 0, 1,
    // CHECK-NEXT: pos[1] : ( 0, 1, 3,
    // CHECK-NEXT: crd[1] : ( 0, 0, 1,
    // CHECK-NEXT: values : ( 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 5, 0, 0, 0, 0,
    // CHECK-NEXT: ----
    sparse_tensor.print %f : tensor<4x8xi32, #BSRC>

    // CHECK-NEXT: ---- Sparse Tensor ----
    // CHECK-NEXT: nse = 24
    // CHECK-NEXT: pos[0] : ( 0, 2,
    // CHECK-NEXT: crd[0] : ( 0, 1,
    // CHECK-NEXT: pos[1] : ( 0, 2, 3,
    // CHECK-NEXT: crd[1] : ( 0, 1, 1,
    // CHECK-NEXT: values : ( 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 5, 0, 0,
    // CHECK-NEXT: ----
    sparse_tensor.print %g : tensor<4x8xi32, #BSC>

    // CHECK-NEXT: ---- Sparse Tensor ----
    // CHECK-NEXT: nse = 24
    // CHECK-NEXT: pos[0] : ( 0, 2,
    // CHECK-NEXT: crd[0] : ( 0, 1,
    // CHECK-NEXT: pos[1] : ( 0, 2, 3,
    // CHECK-NEXT: crd[1] : ( 0, 1, 1,
    // CHECK-NEXT: values : ( 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 5, 0, 0, 0, 0,
    // CHECK-NEXT: ----
    sparse_tensor.print %h : tensor<4x8xi32, #BSCC>

    // Release the resources.
    bufferization.dealloc_tensor %a : tensor<4x8xi32, #CSR>
    bufferization.dealloc_tensor %b : tensor<4x8xi32, #DCSR>
    bufferization.dealloc_tensor %c : tensor<4x8xi32, #CSC>
    bufferization.dealloc_tensor %d : tensor<4x8xi32, #DCSC>
    bufferization.dealloc_tensor %e : tensor<4x8xi32, #BSR>
    bufferization.dealloc_tensor %f : tensor<4x8xi32, #BSRC>
    bufferization.dealloc_tensor %g : tensor<4x8xi32, #BSC>
    bufferization.dealloc_tensor %h : tensor<4x8xi32, #BSCC>

    return
  }
}
