Skip to content

Commit a454d92

Browse files
author
Peiming Liu
authored
[mlir][sparse] rename files and unifies APIs (#88162)
1 parent f48895a commit a454d92

File tree

4 files changed

+49
-30
lines changed

4 files changed

+49
-30
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
2020
Utils/IterationGraphSorter.cpp
2121
Utils/LoopEmitter.cpp
2222
Utils/SparseTensorDescriptor.cpp
23-
Utils/SparseTensorLevel.cpp
23+
Utils/SparseTensorIterator.cpp
2424

2525
ADDITIONAL_HEADER_DIRS
2626
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
#include <vector>
1313

14-
#include "SparseTensorLevel.h"
14+
#include "SparseTensorIterator.h"
1515

1616
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
1717
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp renamed to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
//===- SparseTensorLevel.cpp - Tensor management class -------------------===//
1+
//===- SparseTensorIterator.cpp -------------------------------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "SparseTensorLevel.h"
9+
#include "SparseTensorIterator.h"
1010
#include "CodegenUtils.h"
1111

1212
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -46,21 +46,41 @@ using ValueTuple = std::tuple<Value, Value, Value>;
4646

4747
namespace {
4848

49+
template <bool hasPosBuffer>
4950
class SparseLevel : public SparseTensorLevel {
51+
// It is either an array of size 2 or size 1 depending on whether the sparse
52+
// level requires a position array.
53+
using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
54+
std::array<Value, 1>>;
55+
5056
public:
5157
SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
52-
Value crdBuffer)
53-
: SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
58+
BufferT buffers)
59+
: SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
60+
61+
ValueRange getLvlBuffers() const override { return buffers; }
5462

5563
Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
5664
Value iv) const override {
5765
SmallVector<Value> memCrd(batchPrefix);
5866
memCrd.push_back(iv);
59-
return genIndexLoad(b, l, crdBuffer, memCrd);
67+
return genIndexLoad(b, l, getCrdBuf(), memCrd);
6068
}
6169

6270
protected:
63-
const Value crdBuffer;
71+
template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
72+
Value getPosBuf() const {
73+
return buffers[0];
74+
}
75+
76+
Value getCrdBuf() const {
77+
if constexpr (hasPosBuffer)
78+
return buffers[1];
79+
else
80+
return buffers[0];
81+
}
82+
83+
const BufferT buffers;
6484
};
6585

6686
class DenseLevel : public SparseTensorLevel {
@@ -72,6 +92,8 @@ class DenseLevel : public SparseTensorLevel {
7292
llvm_unreachable("locate random-accessible level instead");
7393
}
7494

95+
ValueRange getLvlBuffers() const override { return {}; }
96+
7597
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
7698
Value max) const override {
7799
Value posLo = MULI(p, lvlSize);
@@ -88,6 +110,8 @@ class BatchLevel : public SparseTensorLevel {
88110
llvm_unreachable("locate random-accessible level instead");
89111
}
90112

113+
ValueRange getLvlBuffers() const override { return {}; }
114+
91115
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
92116
Value max) const override {
93117
assert(max == nullptr && "Dense level can not be non-unique.");
@@ -96,11 +120,11 @@ class BatchLevel : public SparseTensorLevel {
96120
}
97121
};
98122

99-
class CompressedLevel : public SparseLevel {
123+
class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
100124
public:
101125
CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
102126
Value posBuffer, Value crdBuffer)
103-
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
127+
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
104128

105129
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
106130
Value p, Value max) const override {
@@ -109,21 +133,18 @@ class CompressedLevel : public SparseLevel {
109133

110134
SmallVector<Value> memCrd(batchPrefix);
111135
memCrd.push_back(p);
112-
Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
136+
Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
113137
memCrd.back() = ADDI(p, C_IDX(1));
114-
Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
138+
Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
115139
return {pLo, pHi};
116140
}
117-
118-
private:
119-
const Value posBuffer;
120141
};
121142

122-
class LooseCompressedLevel : public SparseLevel {
143+
class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
123144
public:
124145
LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
125146
Value posBuffer, Value crdBuffer)
126-
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
147+
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
127148

128149
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
129150
Value p, Value max) const override {
@@ -133,21 +154,18 @@ class LooseCompressedLevel : public SparseLevel {
133154

134155
p = MULI(p, C_IDX(2));
135156
memCrd.push_back(p);
136-
Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
157+
Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
137158
memCrd.back() = ADDI(p, C_IDX(1));
138-
Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
159+
Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
139160
return {pLo, pHi};
140161
}
141-
142-
private:
143-
const Value posBuffer;
144162
};
145163

146-
class SingletonLevel : public SparseLevel {
164+
class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
147165
public:
148166
SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
149167
Value crdBuffer)
150-
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
168+
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
151169

152170
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
153171
Value p, Value segHi) const override {
@@ -159,11 +177,11 @@ class SingletonLevel : public SparseLevel {
159177
}
160178
};
161179

162-
class NOutOfMLevel : public SparseLevel {
180+
class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
163181
public:
164182
NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
165183
Value crdBuffer)
166-
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
184+
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
167185

168186
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
169187
Value p, Value max) const override {

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h renamed to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
//===- SparseTensorLevel.h --------------------------------------*- C++ -*-===//
1+
//===- SparseTensorIterator.h ---------------------------------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
10-
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
9+
#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
10+
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
1111

1212
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
1313
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -55,6 +55,7 @@ class SparseTensorLevel {
5555
Level getLevel() const { return lvl; }
5656
LevelType getLT() const { return lt; }
5757
Value getSize() const { return lvlSize; }
58+
virtual ValueRange getLvlBuffers() const = 0;
5859

5960
//
6061
// Level properties
@@ -321,4 +322,4 @@ std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
321322
} // namespace sparse_tensor
322323
} // namespace mlir
323324

324-
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
325+
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_

0 commit comments

Comments
 (0)