Skip to content

Commit c42bbda

Browse files
author
Peiming Liu
authored
[mlir][sparse] implement lowering rules for ExtractIterSpaceOp. (#89143)
**DO NOT MERGE** until #89003
1 parent bf7c505 commit c42bbda

File tree

10 files changed

+342
-51
lines changed

10 files changed

+342
-51
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

+4
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,10 @@ struct LevelType {
357357
return hasSparseSemantic();
358358
}
359359

360+
constexpr unsigned getNumBuffer() const {
361+
return hasDenseSemantic() ? 0 : (isWithPosLT() ? 2 : 1);
362+
}
363+
360364
std::string toMLIRString() const {
361365
std::string lvlStr = toFormatString(getLvlFmt());
362366
std::string propStr = "";

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

+15
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/IR/PatternMatch.h"
1717
#include "mlir/Pass/Pass.h"
1818
#include "mlir/Transforms/DialectConversion.h"
19+
#include "mlir/Transforms/OneToNTypeConversion.h"
1920

2021
//===----------------------------------------------------------------------===//
2122
// Include the generated pass header (which needs some early definitions).
@@ -143,6 +144,20 @@ void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
143144

144145
std::unique_ptr<Pass> createLowerForeachToSCFPass();
145146

147+
//===----------------------------------------------------------------------===//
148+
// The LowerSparseIterationToSCF pass.
149+
//===----------------------------------------------------------------------===//
150+
151+
/// Type converter for iter_space and iterator.
152+
struct SparseIterationTypeConverter : public OneToNTypeConverter {
153+
SparseIterationTypeConverter();
154+
};
155+
156+
void populateLowerSparseIterationToSCFPatterns(TypeConverter &converter,
157+
RewritePatternSet &patterns);
158+
159+
std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();
160+
146161
//===----------------------------------------------------------------------===//
147162
// The SparseTensorConversion pass.
148163
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

+16-1
Original file line numberDiff line numberDiff line change
@@ -484,12 +484,27 @@ def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
484484
let summary = "sparse space collapsing pass";
485485
let description = [{
486486
This pass collapses consecutive sparse spaces (extracted from the same tensor)
487-
into one multi-dimensional space. The pass is not yet stablized.
487+
into one multi-dimensional space. The pass is not yet stabilized.
488488
}];
489489
let constructor = "mlir::createSparseSpaceCollapsePass()";
490490
let dependentDialects = [
491491
"sparse_tensor::SparseTensorDialect",
492492
];
493493
}
494494

495+
def LowerSparseIterationToSCF : Pass<"lower-sparse-iteration-to-scf", "func::FuncOp"> {
496+
let summary = "lower sparse_tensor.iterate/coiterate into scf loops";
497+
let description = [{
498+
This pass lowers `sparse_tensor.iterate` operations into `scf.for/while` operations.
499+
The pass is not yet stabilized.
500+
}];
501+
let constructor = "mlir::createLowerSparseIterationToSCFPass()";
502+
let dependentDialects = [
503+
"memref::MemRefDialect",
504+
"scf::SCFDialect",
505+
"sparse_tensor::SparseTensorDialect",
506+
];
507+
}
508+
509+
495510
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
44
SparseAssembler.cpp
55
SparseBufferRewriting.cpp
66
SparseGPUCodegen.cpp
7+
SparseIterationToScf.cpp
78
SparseReinterpretMap.cpp
89
SparseStorageSpecifierToLLVM.cpp
910
SparseSpaceCollapse.cpp
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
2+
#include "Utils/CodegenUtils.h"
3+
#include "Utils/SparseTensorIterator.h"
4+
5+
#include "mlir/Dialect/SCF/IR/SCF.h"
6+
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
7+
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
8+
#include "mlir/Transforms/OneToNTypeConversion.h"
9+
10+
using namespace mlir;
11+
using namespace mlir::sparse_tensor;
12+
13+
void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
14+
SmallVectorImpl<Type> &fields) {
15+
// Position and coordinate buffer in the sparse structure.
16+
if (enc.getLvlType(lvl).isWithPosLT())
17+
fields.push_back(enc.getPosMemRefType());
18+
if (enc.getLvlType(lvl).isWithCrdLT())
19+
fields.push_back(enc.getCrdMemRefType());
20+
// One index for shape bound (result from lvlOp).
21+
fields.push_back(IndexType::get(enc.getContext()));
22+
}
23+
24+
static std::optional<LogicalResult>
25+
convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
26+
27+
auto idxTp = IndexType::get(itSp.getContext());
28+
for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
29+
convertLevelType(itSp.getEncoding(), l, fields);
30+
31+
// Two indices for lower and upper bound (we only need one pair for the last
32+
// iteration space).
33+
fields.append({idxTp, idxTp});
34+
return success();
35+
}
36+
37+
namespace {
38+
39+
/// Sparse codegen rule for number of entries operator.
40+
class ExtractIterSpaceConverter
41+
: public OneToNOpConversionPattern<ExtractIterSpaceOp> {
42+
public:
43+
using OneToNOpConversionPattern::OneToNOpConversionPattern;
44+
LogicalResult
45+
matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
46+
OneToNPatternRewriter &rewriter) const override {
47+
Location loc = op.getLoc();
48+
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
49+
50+
// Construct the iteration space.
51+
SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
52+
op.getLvlRange(), adaptor.getParentIter());
53+
54+
SmallVector<Value> result = space.toValues();
55+
rewriter.replaceOp(op, result, resultMapping);
56+
return success();
57+
}
58+
};
59+
60+
} // namespace
61+
62+
mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
63+
addConversion([](Type type) { return type; });
64+
addConversion(convertIterSpaceType);
65+
66+
addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
67+
ValueRange inputs,
68+
Location loc) -> std::optional<Value> {
69+
return builder
70+
.create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
71+
.getResult(0);
72+
});
73+
}
74+
75+
void mlir::populateLowerSparseIterationToSCFPatterns(
76+
TypeConverter &converter, RewritePatternSet &patterns) {
77+
patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
78+
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

+28
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ namespace mlir {
2626
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
2727
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
2828
#define GEN_PASS_DEF_SPARSIFICATIONPASS
29+
#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
2930
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
3031
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
3132
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
@@ -157,6 +158,29 @@ struct LowerForeachToSCFPass
157158
}
158159
};
159160

161+
struct LowerSparseIterationToSCFPass
162+
: public impl::LowerSparseIterationToSCFBase<
163+
LowerSparseIterationToSCFPass> {
164+
LowerSparseIterationToSCFPass() = default;
165+
LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
166+
default;
167+
168+
void runOnOperation() override {
169+
auto *ctx = &getContext();
170+
RewritePatternSet patterns(ctx);
171+
SparseIterationTypeConverter converter;
172+
ConversionTarget target(*ctx);
173+
174+
// The actual conversion.
175+
target.addIllegalOp<ExtractIterSpaceOp, IterateOp>();
176+
populateLowerSparseIterationToSCFPatterns(converter, patterns);
177+
178+
if (failed(applyPartialOneToNConversion(getOperation(), converter,
179+
std::move(patterns))))
180+
signalPassFailure();
181+
}
182+
};
183+
160184
struct SparseTensorConversionPass
161185
: public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
162186
SparseTensorConversionPass() = default;
@@ -439,6 +463,10 @@ std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
439463
return std::make_unique<LowerForeachToSCFPass>();
440464
}
441465

466+
std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
467+
return std::make_unique<LowerSparseIterationToSCFPass>();
468+
}
469+
442470
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
443471
return std::make_unique<SparseTensorConversionPass>();
444472
}

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp

+91-26
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
168168
ValueRange posRange = posRangeIf.getResults();
169169
return {posRange.front(), posRange.back()};
170170
}
171-
};
171+
}; // namespace
172172

173173
class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
174174
public:
@@ -190,7 +190,7 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
190190
Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
191191
return {pLo, pHi};
192192
}
193-
};
193+
}; // namespace
194194

195195
class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
196196
public:
@@ -210,6 +210,13 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
210210
// Use the segHi as the loop upper bound.
211211
return {p, segHi};
212212
}
213+
214+
ValuePair
215+
collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
216+
std::pair<Value, Value> parentRange) const override {
217+
// Singleton level keeps the same range after collapsing.
218+
return parentRange;
219+
};
213220
};
214221

215222
class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
@@ -1474,10 +1481,85 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
14741481
return getCursor();
14751482
}
14761483

1484+
//===----------------------------------------------------------------------===//
1485+
// SparseIterationSpace Implementation
1486+
//===----------------------------------------------------------------------===//
1487+
1488+
mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace(
1489+
Location l, OpBuilder &b, Value t, unsigned tid,
1490+
std::pair<Level, Level> lvlRange, ValueRange parentPos)
1491+
: lvls() {
1492+
auto [lvlLo, lvlHi] = lvlRange;
1493+
1494+
Value c0 = C_IDX(0);
1495+
if (parentPos.empty())
1496+
parentPos = c0;
1497+
1498+
for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
1499+
lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));
1500+
1501+
bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
1502+
for (auto &lvl : getLvlRef().drop_front())
1503+
bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
1504+
}
1505+
1506+
SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
1507+
IterSpaceType dstTp, ValueRange values, unsigned int tid) {
1508+
// Reconstruct every sparse tensor level.
1509+
SparseIterationSpace space;
1510+
for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
1511+
unsigned bufferCnt = 0;
1512+
if (lt.isWithPosLT())
1513+
bufferCnt++;
1514+
if (lt.isWithCrdLT())
1515+
bufferCnt++;
1516+
// Sparse tensor buffers.
1517+
ValueRange buffers = values.take_front(bufferCnt);
1518+
values = values.drop_front(bufferCnt);
1519+
1520+
// Level size.
1521+
Value sz = values.front();
1522+
values = values.drop_front();
1523+
space.lvls.push_back(
1524+
makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
1525+
}
1526+
// Two bounds.
1527+
space.bound = std::make_pair(values[0], values[1]);
1528+
values = values.drop_front(2);
1529+
1530+
// Must have consumed all values.
1531+
assert(values.empty());
1532+
return space;
1533+
}
1534+
14771535
//===----------------------------------------------------------------------===//
14781536
// SparseIterator factory functions.
14791537
//===----------------------------------------------------------------------===//
14801538

1539+
/// Helper function to create a TensorLevel object from given `tensor`.
1540+
std::unique_ptr<SparseTensorLevel>
1541+
sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
1542+
unsigned t, Level l) {
1543+
assert(lt.getNumBuffer() == b.size());
1544+
switch (lt.getLvlFmt()) {
1545+
case LevelFormat::Dense:
1546+
return std::make_unique<DenseLevel>(t, l, sz);
1547+
case LevelFormat::Batch:
1548+
return std::make_unique<BatchLevel>(t, l, sz);
1549+
case LevelFormat::Compressed:
1550+
return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
1551+
case LevelFormat::LooseCompressed:
1552+
return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
1553+
case LevelFormat::Singleton:
1554+
return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
1555+
case LevelFormat::NOutOfM:
1556+
return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
1557+
case LevelFormat::Undef:
1558+
llvm_unreachable("undefined level format");
1559+
}
1560+
llvm_unreachable("unrecognizable level format");
1561+
}
1562+
14811563
std::unique_ptr<SparseTensorLevel>
14821564
sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
14831565
unsigned tid, Level lvl) {
@@ -1487,33 +1569,16 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
14871569
Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
14881570
: b.create<tensor::DimOp>(l, t, lvl).getResult();
14891571

1490-
switch (lt.getLvlFmt()) {
1491-
case LevelFormat::Dense:
1492-
return std::make_unique<DenseLevel>(tid, lvl, sz);
1493-
case LevelFormat::Batch:
1494-
return std::make_unique<BatchLevel>(tid, lvl, sz);
1495-
case LevelFormat::Compressed: {
1496-
Value pos = b.create<ToPositionsOp>(l, t, lvl);
1497-
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1498-
return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
1499-
}
1500-
case LevelFormat::LooseCompressed: {
1572+
SmallVector<Value, 2> buffers;
1573+
if (lt.isWithPosLT()) {
15011574
Value pos = b.create<ToPositionsOp>(l, t, lvl);
1502-
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1503-
return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
1504-
}
1505-
case LevelFormat::Singleton: {
1506-
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1507-
return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
1575+
buffers.push_back(pos);
15081576
}
1509-
case LevelFormat::NOutOfM: {
1510-
Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1511-
return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
1577+
if (lt.isWithCrdLT()) {
1578+
Value pos = b.create<ToCoordinatesOp>(l, t, lvl);
1579+
buffers.push_back(pos);
15121580
}
1513-
case LevelFormat::Undef:
1514-
llvm_unreachable("undefined level format");
1515-
}
1516-
llvm_unreachable("unrecognizable level format");
1581+
return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
15171582
}
15181583

15191584
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>

0 commit comments

Comments
 (0)