Skip to content

Commit bfc234a

Browse files
author
Peiming Liu
committed
[mlir][sparse] implement lowering rules for ExtractIterSpaceOp.
1 parent 1f2857b commit bfc234a

File tree

10 files changed

+325
-35
lines changed

10 files changed

+325
-35
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,10 @@ struct LevelType {
357357
return hasSparseSemantic();
358358
}
359359

360+
constexpr unsigned getNumBuffer() const {
361+
return hasDenseSemantic() ? 0 : (isWithPosLT() ? 2 : 1);
362+
}
363+
360364
std::string toMLIRString() const {
361365
std::string lvlStr = toFormatString(getLvlFmt());
362366
std::string propStr = "";

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/IR/PatternMatch.h"
1717
#include "mlir/Pass/Pass.h"
1818
#include "mlir/Transforms/DialectConversion.h"
19+
#include "mlir/Transforms/OneToNTypeConversion.h"
1920

2021
//===----------------------------------------------------------------------===//
2122
// Include the generated pass header (which needs some early definitions).
@@ -149,6 +150,20 @@ void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
149150

150151
std::unique_ptr<Pass> createLowerForeachToSCFPass();
151152

153+
//===----------------------------------------------------------------------===//
154+
// The LowerSparseIterationToSCF pass.
155+
//===----------------------------------------------------------------------===//
156+
157+
/// Type converter for iter_space and iterator.
158+
struct SparseIterationTypeConverter : public OneToNTypeConverter {
159+
SparseIterationTypeConverter();
160+
};
161+
162+
void populateLowerSparseIterationToSCFPatterns(TypeConverter &converter,
163+
RewritePatternSet &patterns);
164+
165+
std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();
166+
152167
//===----------------------------------------------------------------------===//
153168
// The SparseTensorConversion pass.
154169
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,4 +516,19 @@ def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
516516
];
517517
}
518518

519+
def LowerSparseIterationToSCF : Pass<"lower-sparse-iteration-to-scf", "func::FuncOp"> {
520+
let summary = "lower sparse_tensor.iterate/coiterate into scf loops";
521+
let description = [{
522+
This pass lowers `sparse_tensor.iterate` operations into `scf.for/while` operations.
523+
The pass is not yet stablized.
524+
}];
525+
let constructor = "mlir::createLowerSparseIterationToSCFPass()";
526+
let dependentDialects = [
527+
"memref::MemRefDialect",
528+
"scf::SCFDialect",
529+
"sparse_tensor::SparseTensorDialect",
530+
];
531+
}
532+
533+
519534
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
44
SparseAssembler.cpp
55
SparseBufferRewriting.cpp
66
SparseGPUCodegen.cpp
7+
SparseIterationToScf.cpp
78
SparseReinterpretMap.cpp
89
SparseStorageSpecifierToLLVM.cpp
910
SparseSpaceCollapse.cpp
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
2+
#include "Utils/CodegenUtils.h"
3+
#include "Utils/SparseTensorIterator.h"
4+
5+
#include "mlir/Dialect/SCF/IR/SCF.h"
6+
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
7+
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
8+
#include "mlir/Transforms/OneToNTypeConversion.h"
9+
10+
using namespace mlir;
11+
using namespace mlir::sparse_tensor;
12+
13+
void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
14+
SmallVectorImpl<Type> &fields) {
15+
// Position and coordinate buffer in the sparse structure.
16+
if (enc.getLvlType(lvl).isWithPosLT())
17+
fields.push_back(enc.getPosMemRefType());
18+
if (enc.getLvlType(lvl).isWithCrdLT())
19+
fields.push_back(enc.getCrdMemRefType());
20+
// One index for shape bound (result from lvlOp)
21+
fields.push_back(IndexType::get(enc.getContext()));
22+
}
23+
24+
static std::optional<LogicalResult>
25+
convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
26+
27+
auto idxTp = IndexType::get(itSp.getContext());
28+
for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
29+
convertLevelType(itSp.getEncoding(), l, fields);
30+
31+
// Two indices for lower and upper bound (we only need one pair for the last
32+
// iteration space).
33+
fields.append({idxTp, idxTp});
34+
return success();
35+
}
36+
37+
namespace {
38+
39+
/// Sparse codegen rule for number of entries operator.
40+
class ExtractIterSpaceConverter
41+
: public OneToNOpConversionPattern<ExtractIterSpaceOp> {
42+
public:
43+
using OneToNOpConversionPattern::OneToNOpConversionPattern;
44+
LogicalResult
45+
matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
46+
OneToNPatternRewriter &rewriter) const override {
47+
Location loc = op.getLoc();
48+
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
49+
50+
// Construct the iteration space.
51+
SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
52+
op.getLvlRange(), adaptor.getParentIter());
53+
54+
SmallVector<Value> result = space.toValues();
55+
rewriter.replaceOp(op, result, resultMapping);
56+
return success();
57+
}
58+
};
59+
60+
} // namespace
61+
62+
mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
63+
addConversion([](Type type) { return type; });
64+
addConversion(convertIterSpaceType);
65+
66+
addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
67+
ValueRange inputs,
68+
Location loc) -> std::optional<Value> {
69+
return builder
70+
.create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
71+
.getResult(0);
72+
});
73+
}
74+
75+
void mlir::populateLowerSparseIterationToSCFPatterns(
76+
TypeConverter &converter, RewritePatternSet &patterns) {
77+
patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
78+
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ namespace mlir {
2727
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
2828
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
2929
#define GEN_PASS_DEF_SPARSIFICATIONPASS
30+
#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
3031
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
3132
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
3233
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
@@ -162,10 +163,34 @@ struct LowerForeachToSCFPass
162163
auto *ctx = &getContext();
163164
RewritePatternSet patterns(ctx);
164165
populateLowerForeachToSCFPatterns(patterns);
166+
165167
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
166168
}
167169
};
168170

171+
struct LowerSparseIterationToSCFPass
172+
: public impl::LowerSparseIterationToSCFBase<
173+
LowerSparseIterationToSCFPass> {
174+
LowerSparseIterationToSCFPass() = default;
175+
LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
176+
default;
177+
178+
void runOnOperation() override {
179+
auto *ctx = &getContext();
180+
RewritePatternSet patterns(ctx);
181+
SparseIterationTypeConverter converter;
182+
ConversionTarget target(*ctx);
183+
184+
// The actual conversion.
185+
target.addIllegalOp<ExtractIterSpaceOp, IterateOp>();
186+
populateLowerSparseIterationToSCFPatterns(converter, patterns);
187+
188+
if (failed(applyPartialOneToNConversion(getOperation(), converter,
189+
std::move(patterns))))
190+
signalPassFailure();
191+
}
192+
};
193+
169194
struct SparseTensorConversionPass
170195
: public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
171196
SparseTensorConversionPass() = default;
@@ -452,6 +477,10 @@ std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
452477
return std::make_unique<LowerForeachToSCFPass>();
453478
}
454479

480+
std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
481+
return std::make_unique<LowerSparseIterationToSCFPass>();
482+
}
483+
455484
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
456485
return std::make_unique<SparseTensorConversionPass>();
457486
}

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp

Lines changed: 91 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
168168
ValueRange posRange = posRangeIf.getResults();
169169
return {posRange.front(), posRange.back()};
170170
}
171-
};
171+
}; // namespace
172172

173173
class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
174174
public:
@@ -190,7 +190,7 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
190190
Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
191191
return {pLo, pHi};
192192
}
193-
};
193+
}; // namespace
194194

195195
class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
196196
public:
@@ -210,6 +210,13 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
210210
// Use the segHi as the loop upper bound.
211211
return {p, segHi};
212212
}
213+
214+
ValuePair
215+
collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
216+
std::pair<Value, Value> parentRange) const override {
217+
// Singleton level keeps the same range after collapsing.
218+
return parentRange;
219+
};
213220
};
214221

215222
class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
@@ -1474,10 +1481,85 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
14741481
return getCursor();
14751482
}
14761483

1484+
//===----------------------------------------------------------------------===//
1485+
// SparseIterationSpace Implementation
1486+
//===----------------------------------------------------------------------===//
1487+
1488+
mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace(
1489+
Location l, OpBuilder &b, Value t, unsigned tid,
1490+
std::pair<Level, Level> lvlRange, ValueRange parentPos)
1491+
: lvls() {
1492+
auto [lvlLo, lvlHi] = lvlRange;
1493+
1494+
Value c0 = C_IDX(0);
1495+
if (parentPos.empty())
1496+
parentPos = c0;
1497+
1498+
for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
1499+
lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));
1500+
1501+
bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
1502+
for (auto &lvl : getLvlRef().drop_front())
1503+
bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
1504+
}
1505+
1506+
SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
1507+
IterSpaceType dstTp, ValueRange values, unsigned int tid) {
1508+
// Reconstruct every sparse tensor level.
1509+
SparseIterationSpace space;
1510+
for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
1511+
unsigned bufferCnt = 0;
1512+
if (lt.isWithPosLT())
1513+
bufferCnt++;
1514+
if (lt.isWithCrdLT())
1515+
bufferCnt++;
1516+
// Sparse tensor buffers.
1517+
ValueRange buffers = values.take_front(bufferCnt);
1518+
values = values.drop_front(bufferCnt);
1519+
1520+
// Level size.
1521+
Value sz = values.front();
1522+
values = values.drop_front();
1523+
space.lvls.push_back(
1524+
makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
1525+
}
1526+
// Two bounds.
1527+
space.bound = std::make_pair(values[0], values[1]);
1528+
values = values.drop_front(2);
1529+
1530+
// Must have consumed all values.
1531+
assert(values.empty());
1532+
return space;
1533+
}
1534+
14771535
//===----------------------------------------------------------------------===//
14781536
// SparseIterator factory functions.
14791537
//===----------------------------------------------------------------------===//
14801538

1539+
/// Helper function to create a TensorLevel object from given `tensor`.
1540+
std::unique_ptr<SparseTensorLevel>
1541+
sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
1542+
unsigned t, Level l) {
1543+
assert(lt.getNumBuffer() == b.size());
1544+
switch (lt.getLvlFmt()) {
1545+
case LevelFormat::Dense:
1546+
return std::make_unique<DenseLevel>(t, l, sz);
1547+
case LevelFormat::Batch:
1548+
return std::make_unique<BatchLevel>(t, l, sz);
1549+
case LevelFormat::Compressed:
1550+
return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
1551+
case LevelFormat::LooseCompressed:
1552+
return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
1553+
case LevelFormat::Singleton:
1554+
return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
1555+
case LevelFormat::NOutOfM:
1556+
return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
1557+
case LevelFormat::Undef:
1558+
llvm_unreachable("undefined level format");
1559+
}
1560+
llvm_unreachable("unrecognizable level format");
1561+
}
1562+
14811563
std::unique_ptr<SparseTensorLevel>
14821564
sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
14831565
unsigned tid, Level lvl) {
@@ -1487,33 +1569,16 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
14871569
Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
14881570
: b.create<tensor::DimOp>(l, t, lvl).getResult();
14891571

1490-
switch (lt.getLvlFmt()) {
1491-
case LevelFormat::Dense:
1492-
return std::make_unique<DenseLevel>(tid, lvl, sz);
1493-
case LevelFormat::Batch:
1494-
return std::make_unique<BatchLevel>(tid, lvl, sz);
1495-
case LevelFormat::Compressed: {
1496-
Value pos = b.create<ToPositionsOp>(l, t, lvl);
1497-
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1498-
return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
1499-
}
1500-
case LevelFormat::LooseCompressed: {
1572+
SmallVector<Value, 2> buffers;
1573+
if (lt.isWithPosLT()) {
15011574
Value pos = b.create<ToPositionsOp>(l, t, lvl);
1502-
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1503-
return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
1504-
}
1505-
case LevelFormat::Singleton: {
1506-
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1507-
return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
1575+
buffers.push_back(pos);
15081576
}
1509-
case LevelFormat::NOutOfM: {
1510-
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1511-
return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
1577+
if (lt.isWithCrdLT()) {
1578+
Value pos = b.create<ToCoordinatesOp>(l, t, lvl);
1579+
buffers.push_back(pos);
15121580
}
1513-
case LevelFormat::Undef:
1514-
llvm_unreachable("undefined level format");
1515-
}
1516-
llvm_unreachable("unrecognizable level format");
1581+
return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
15171582
}
15181583

15191584
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>

0 commit comments

Comments
 (0)