@@ -168,7 +168,7 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
168
168
ValueRange posRange = posRangeIf.getResults ();
169
169
return {posRange.front (), posRange.back ()};
170
170
}
171
- };
171
+ }; // namespace
172
172
173
173
class LooseCompressedLevel : public SparseLevel </* hasPosBuf=*/ true > {
174
174
public:
@@ -190,7 +190,7 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
190
190
Value pHi = genIndexLoad (b, l, getPosBuf (), memCrd);
191
191
return {pLo, pHi};
192
192
}
193
- };
193
+ }; // namespace
194
194
195
195
class SingletonLevel : public SparseLevel </* hasPosBuf=*/ false > {
196
196
public:
@@ -210,6 +210,13 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
210
210
// Use the segHi as the loop upper bound.
211
211
return {p, segHi};
212
212
}
213
+
214
+ ValuePair
215
+ collapseRangeBetween (OpBuilder &b, Location l, ValueRange batchPrefix,
216
+ std::pair<Value, Value> parentRange) const override {
217
+ // Singleton level keeps the same range after collapsing.
218
+ return parentRange;
219
+ };
213
220
};
214
221
215
222
class NOutOfMLevel : public SparseLevel </* hasPosBuf=*/ false > {
@@ -1474,10 +1481,85 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
1474
1481
return getCursor ();
1475
1482
}
1476
1483
1484
+ // ===----------------------------------------------------------------------===//
1485
+ // SparseIterationSpace Implementation
1486
+ // ===----------------------------------------------------------------------===//
1487
+
1488
+ mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace (
1489
+ Location l, OpBuilder &b, Value t, unsigned tid,
1490
+ std::pair<Level, Level> lvlRange, ValueRange parentPos)
1491
+ : lvls() {
1492
+ auto [lvlLo, lvlHi] = lvlRange;
1493
+
1494
+ Value c0 = C_IDX (0 );
1495
+ if (parentPos.empty ())
1496
+ parentPos = c0;
1497
+
1498
+ for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
1499
+ lvls.emplace_back (makeSparseTensorLevel (b, l, t, tid, lvl));
1500
+
1501
+ bound = lvls.front ()->peekRangeAt (b, l, /* batchPrefix=*/ {}, parentPos);
1502
+ for (auto &lvl : getLvlRef ().drop_front ())
1503
+ bound = lvl->collapseRangeBetween (b, l, /* batchPrefix=*/ {}, bound);
1504
+ }
1505
+
1506
+ SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues (
1507
+ IterSpaceType dstTp, ValueRange values, unsigned int tid) {
1508
+ // Reconstruct every sparse tensor level.
1509
+ SparseIterationSpace space;
1510
+ for (auto [i, lt] : llvm::enumerate (dstTp.getLvlTypes ())) {
1511
+ unsigned bufferCnt = 0 ;
1512
+ if (lt.isWithPosLT ())
1513
+ bufferCnt++;
1514
+ if (lt.isWithCrdLT ())
1515
+ bufferCnt++;
1516
+ // Sparse tensor buffers.
1517
+ ValueRange buffers = values.take_front (bufferCnt);
1518
+ values = values.drop_front (bufferCnt);
1519
+
1520
+ // Level size.
1521
+ Value sz = values.front ();
1522
+ values = values.drop_front ();
1523
+ space.lvls .push_back (
1524
+ makeSparseTensorLevel (lt, sz, buffers, tid, i + dstTp.getLoLvl ()));
1525
+ }
1526
+ // Two bounds.
1527
+ space.bound = std::make_pair (values[0 ], values[1 ]);
1528
+ values = values.drop_front (2 );
1529
+
1530
+ // Must have consumed all values.
1531
+ assert (values.empty ());
1532
+ return space;
1533
+ }
1534
+
1477
1535
// ===----------------------------------------------------------------------===//
1478
1536
// SparseIterator factory functions.
1479
1537
// ===----------------------------------------------------------------------===//
1480
1538
1539
+ // / Helper function to create a TensorLevel object from given `tensor`.
1540
+ std::unique_ptr<SparseTensorLevel>
1541
+ sparse_tensor::makeSparseTensorLevel (LevelType lt, Value sz, ValueRange b,
1542
+ unsigned t, Level l) {
1543
+ assert (lt.getNumBuffer () == b.size ());
1544
+ switch (lt.getLvlFmt ()) {
1545
+ case LevelFormat::Dense:
1546
+ return std::make_unique<DenseLevel>(t, l, sz);
1547
+ case LevelFormat::Batch:
1548
+ return std::make_unique<BatchLevel>(t, l, sz);
1549
+ case LevelFormat::Compressed:
1550
+ return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0 ], b[1 ]);
1551
+ case LevelFormat::LooseCompressed:
1552
+ return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0 ], b[1 ]);
1553
+ case LevelFormat::Singleton:
1554
+ return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0 ]);
1555
+ case LevelFormat::NOutOfM:
1556
+ return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0 ]);
1557
+ case LevelFormat::Undef:
1558
+ llvm_unreachable (" undefined level format" );
1559
+ }
1560
+ llvm_unreachable (" unrecognizable level format" );
1561
+ }
1562
+
1481
1563
std::unique_ptr<SparseTensorLevel>
1482
1564
sparse_tensor::makeSparseTensorLevel (OpBuilder &b, Location l, Value t,
1483
1565
unsigned tid, Level lvl) {
@@ -1487,33 +1569,16 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
1487
1569
Value sz = stt.hasEncoding () ? b.create <LvlOp>(l, t, lvl).getResult ()
1488
1570
: b.create <tensor::DimOp>(l, t, lvl).getResult ();
1489
1571
1490
- switch (lt.getLvlFmt ()) {
1491
- case LevelFormat::Dense:
1492
- return std::make_unique<DenseLevel>(tid, lvl, sz);
1493
- case LevelFormat::Batch:
1494
- return std::make_unique<BatchLevel>(tid, lvl, sz);
1495
- case LevelFormat::Compressed: {
1496
- Value pos = b.create <ToPositionsOp>(l, t, lvl);
1497
- Value crd = b.create <ToCoordinatesOp>(l, t, lvl);
1498
- return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
1499
- }
1500
- case LevelFormat::LooseCompressed: {
1572
+ SmallVector<Value, 2 > buffers;
1573
+ if (lt.isWithPosLT ()) {
1501
1574
Value pos = b.create <ToPositionsOp>(l, t, lvl);
1502
- Value crd = b.create <ToCoordinatesOp>(l, t, lvl);
1503
- return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
1504
- }
1505
- case LevelFormat::Singleton: {
1506
- Value crd = b.create <ToCoordinatesOp>(l, t, lvl);
1507
- return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
1575
+ buffers.push_back (pos);
1508
1576
}
1509
- case LevelFormat::NOutOfM: {
1510
- Value crd = b.create <ToCoordinatesOp>(l, t, lvl);
1511
- return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd );
1577
+ if (lt. isWithCrdLT ()) {
1578
+ Value pos = b.create <ToCoordinatesOp>(l, t, lvl);
1579
+ buffers. push_back (pos );
1512
1580
}
1513
- case LevelFormat::Undef:
1514
- llvm_unreachable (" undefined level format" );
1515
- }
1516
- llvm_unreachable (" unrecognizable level format" );
1581
+ return makeSparseTensorLevel (lt, sz, buffers, tid, lvl);
1517
1582
}
1518
1583
1519
1584
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
0 commit comments