1
- // ===- SparseTensorLevel .cpp - Tensor management class -------------------===//
1
+ // ===- SparseTensorIterator .cpp ------------------------ -------------------===//
2
2
//
3
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
4
// See https://llvm.org/LICENSE.txt for license information.
5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
9
- #include " SparseTensorLevel .h"
9
+ #include " SparseTensorIterator .h"
10
10
#include " CodegenUtils.h"
11
11
12
12
#include " mlir/Dialect/MemRef/IR/MemRef.h"
@@ -46,21 +46,41 @@ using ValueTuple = std::tuple<Value, Value, Value>;
46
46
47
47
namespace {
48
48
49
+ template <bool hasPosBuffer>
49
50
class SparseLevel : public SparseTensorLevel {
51
+ // It is either an array of size 2 or size 1 depending on whether the sparse
52
+ // level requires a position array.
53
+ using BufferT = std::conditional_t <hasPosBuffer, std::array<Value, 2 >,
54
+ std::array<Value, 1 >>;
55
+
50
56
public:
51
57
SparseLevel (unsigned tid, Level lvl, LevelType lt, Value lvlSize,
52
- Value crdBuffer)
53
- : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
58
+ BufferT buffers)
59
+ : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
60
+
61
+ ValueRange getLvlBuffers () const override { return buffers; }
54
62
55
63
Value peekCrdAt (OpBuilder &b, Location l, ValueRange batchPrefix,
56
64
Value iv) const override {
57
65
SmallVector<Value> memCrd (batchPrefix);
58
66
memCrd.push_back (iv);
59
- return genIndexLoad (b, l, crdBuffer , memCrd);
67
+ return genIndexLoad (b, l, getCrdBuf () , memCrd);
60
68
}
61
69
62
70
protected:
63
- const Value crdBuffer;
71
+ template <typename T = void , typename = std::enable_if_t <hasPosBuffer, T>>
72
+ Value getPosBuf () const {
73
+ return buffers[0 ];
74
+ }
75
+
76
+ Value getCrdBuf () const {
77
+ if constexpr (hasPosBuffer)
78
+ return buffers[1 ];
79
+ else
80
+ return buffers[0 ];
81
+ }
82
+
83
+ const BufferT buffers;
64
84
};
65
85
66
86
class DenseLevel : public SparseTensorLevel {
@@ -72,6 +92,8 @@ class DenseLevel : public SparseTensorLevel {
72
92
llvm_unreachable (" locate random-accessible level instead" );
73
93
}
74
94
95
+ ValueRange getLvlBuffers () const override { return {}; }
96
+
75
97
ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange, Value p,
76
98
Value max) const override {
77
99
Value posLo = MULI (p, lvlSize);
@@ -88,6 +110,8 @@ class BatchLevel : public SparseTensorLevel {
88
110
llvm_unreachable (" locate random-accessible level instead" );
89
111
}
90
112
113
+ ValueRange getLvlBuffers () const override { return {}; }
114
+
91
115
ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange, Value p,
92
116
Value max) const override {
93
117
assert (max == nullptr && " Dense level can not be non-unique." );
@@ -96,11 +120,11 @@ class BatchLevel : public SparseTensorLevel {
96
120
}
97
121
};
98
122
99
- class CompressedLevel : public SparseLevel {
123
+ class CompressedLevel : public SparseLevel < /* hasPosBuf= */ true > {
100
124
public:
101
125
CompressedLevel (unsigned tid, Level lvl, LevelType lt, Value lvlSize,
102
126
Value posBuffer, Value crdBuffer)
103
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer ) {}
127
+ : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer} ) {}
104
128
105
129
ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange batchPrefix,
106
130
Value p, Value max) const override {
@@ -109,21 +133,18 @@ class CompressedLevel : public SparseLevel {
109
133
110
134
SmallVector<Value> memCrd (batchPrefix);
111
135
memCrd.push_back (p);
112
- Value pLo = genIndexLoad (b, l, posBuffer , memCrd);
136
+ Value pLo = genIndexLoad (b, l, getPosBuf () , memCrd);
113
137
memCrd.back () = ADDI (p, C_IDX (1 ));
114
- Value pHi = genIndexLoad (b, l, posBuffer , memCrd);
138
+ Value pHi = genIndexLoad (b, l, getPosBuf () , memCrd);
115
139
return {pLo, pHi};
116
140
}
117
-
118
- private:
119
- const Value posBuffer;
120
141
};
121
142
122
- class LooseCompressedLevel : public SparseLevel {
143
+ class LooseCompressedLevel : public SparseLevel < /* hasPosBuf= */ true > {
123
144
public:
124
145
LooseCompressedLevel (unsigned tid, Level lvl, LevelType lt, Value lvlSize,
125
146
Value posBuffer, Value crdBuffer)
126
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer ) {}
147
+ : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer} ) {}
127
148
128
149
ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange batchPrefix,
129
150
Value p, Value max) const override {
@@ -133,21 +154,18 @@ class LooseCompressedLevel : public SparseLevel {
133
154
134
155
p = MULI (p, C_IDX (2 ));
135
156
memCrd.push_back (p);
136
- Value pLo = genIndexLoad (b, l, posBuffer , memCrd);
157
+ Value pLo = genIndexLoad (b, l, getPosBuf () , memCrd);
137
158
memCrd.back () = ADDI (p, C_IDX (1 ));
138
- Value pHi = genIndexLoad (b, l, posBuffer , memCrd);
159
+ Value pHi = genIndexLoad (b, l, getPosBuf () , memCrd);
139
160
return {pLo, pHi};
140
161
}
141
-
142
- private:
143
- const Value posBuffer;
144
162
};
145
163
146
- class SingletonLevel : public SparseLevel {
164
+ class SingletonLevel : public SparseLevel < /* hasPosBuf= */ false > {
147
165
public:
148
166
SingletonLevel (unsigned tid, Level lvl, LevelType lt, Value lvlSize,
149
167
Value crdBuffer)
150
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
168
+ : SparseLevel(tid, lvl, lt, lvlSize, { crdBuffer} ) {}
151
169
152
170
ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange batchPrefix,
153
171
Value p, Value segHi) const override {
@@ -159,11 +177,11 @@ class SingletonLevel : public SparseLevel {
159
177
}
160
178
};
161
179
162
- class NOutOfMLevel : public SparseLevel {
180
+ class NOutOfMLevel : public SparseLevel < /* hasPosBuf= */ false > {
163
181
public:
164
182
NOutOfMLevel (unsigned tid, Level lvl, LevelType lt, Value lvlSize,
165
183
Value crdBuffer)
166
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
184
+ : SparseLevel(tid, lvl, lt, lvlSize, { crdBuffer} ) {}
167
185
168
186
ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange batchPrefix,
169
187
Value p, Value max) const override {
0 commit comments