
Commit cf4dd91

[mlir][sparse] initialize slice-driven loop-related fields in one place (#76099)
1 parent 8fdc3b9 commit cf4dd91

3 files changed: +76, -85 lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp

Lines changed: 2 additions & 12 deletions
@@ -85,7 +85,6 @@ void CodegenEnv::startEmit() {
     for (Level lvl = 0; lvl < lvlRank; lvl++)
       sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl));
   }
-
   loopEmitter.initialize(
       tensors,
       StringAttr::get(linalgOp.getContext(),
@@ -95,17 +94,8 @@ void CodegenEnv::startEmit() {
       // TODO: compute the map and pass it to loop emitter directly instead of
       // passing in a callback.
      /*dependentLvlGetter=*/
-      [this](TensorId t,
-             Level lvl) -> std::vector<std::pair<TensorLevel, unsigned>> {
-        // Translates from a list of loop indices to a list of [tid, lvl] pair.
-        std::vector<LoopCoeffPair> &rLoops = merger().getDependentLoops(t, lvl);
-        std::vector<std::pair<TensorLevel, unsigned>> ret;
-        ret.reserve(rLoops.size());
-        for (auto [loop, coeff] : rLoops) {
-          TensorLevel tl = makeTensorLevel(merger().getLoopDefiningLvl(loop));
-          ret.emplace_back(tl, coeff);
-        };
-        return ret;
+      [this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> {
+        return merger().getDependentLoops(t, lvl);
       });
 }
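As an aside on the change above: the callback now returns the merger's LoopCoeffPair list verbatim instead of translating loops into [tid, lvl] pairs, and the loop emitter stores those (LoopId, coefficient) pairs directly (see dependentLvlMap in LoopEmitter.h below). A minimal standalone sketch of the shape of that data; the LoopId/LoopCoeffPair aliases and the sample subscript are simplified stand-ins, not the real MLIR headers:

#include <cstdio>
#include <utility>
#include <vector>

using LoopId = unsigned;                           // stand-in for the MLIR typedef
using LoopCoeffPair = std::pair<LoopId, unsigned>; // (loop, coefficient)

// For an affine subscript such as 2 * d0 + d1 on some [tid, lvl], the
// merger's dependent-loop list is {(d0, 2), (d1, 1)}; the simplified getter
// above forwards exactly such a list.
std::vector<LoopCoeffPair> dependentLoopsFor2d0PlusD1() {
  return {{/*d0=*/0u, 2u}, {/*d1=*/1u, 1u}};
}

int main() {
  for (auto [loop, coeff] : dependentLoopsFor2d0PlusD1())
    std::printf("d%u has coefficient %u\n", loop, coeff);
  return 0;
}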

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp

Lines changed: 62 additions & 56 deletions
@@ -391,13 +391,18 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                          /*posTupleNum=*/Value(), std::nullopt, 0);
     if (dimGetter && !isSynTensor(tid)) {
       for (Level l = 0; l < lvlRank; l++) {
-        dependentLvlMap[tid][l] = dimGetter(tid, l);
+        std::vector<std::pair<LoopId, unsigned>> deps = dimGetter(tid, l);
+        // Sort the loops by loop order.
+        std::sort(deps.begin(), deps.end(),
+                  [](auto &lhs, auto &rhs) { return lhs.first < rhs.first; });
+
+        dependentLvlMap[tid][l] = std::move(deps);
         unsigned depends = dependentLvlMap[tid][l].size();
         if (depends == 0)
           continue;
-        sliceMeta[tid][l].assign(depends, std::make_pair(nullptr, 0));
+        sliceMeta[tid][l].reserve(depends);
         // We need `depends - 1` slices to fully reduce the affine expression.
-        slicePosBuffer[tid][l].assign(depends - 1, nullptr);
+        slicePosBuffer[tid][l].reserve(depends - 1);
       }
     }
   }
@@ -487,35 +492,70 @@ void LoopEmitter::initializeLoopEmit(
       // hoist the code ouside if-conditions.
     }
 
-  Type indexType = builder.getIndexType();
-  Value c0 = constantZero(builder, loc, indexType);
+  initSliceDriven(builder, loc);
+}
+
+void LoopEmitter::initSliceDriven(OpBuilder &builder, Location loc) {
+  Value c0 = C_IDX(0);
   for (TensorId t = 0, e = tensors.size(); t < e; t++) {
     auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
     if (!rtp)
       continue;
 
     Level lvlRank = SparseTensorType(rtp).getLvlRank();
+
+    // Compute the dependency reduction order.
+    auto remDepStack = dependentLvlMap;
+    std::vector<std::tuple<LoopId, TensorId, Level>> depRedOrder;
     for (Level lvl = 0; lvl < lvlRank; lvl++) {
-      if (!dependentLvlMap[t][lvl].empty()) {
-        ArrayRef<std::pair<TensorLevel, unsigned>> depLvls =
-            dependentLvlMap[t][lvl];
-        // Needs at least two operands to form a non-trivial affine expression.
-        assert(depLvls.size() == sliceMeta[t][lvl].size());
-
-        Value size = c0;
-        for (int e = depLvls.size() - 1; e >= 0; e--) {
-          auto [dt, dl] = unpackTensorLevel(depLvls[e].first);
-          unsigned stride = depLvls[e].second;
-          Value stridedSize = lvlSizes[dt][dl];
-          if (stride != 1)
-            stridedSize = MULI(stridedSize, C_IDX(stride));
-          size = ADDI(size, stridedSize);
-          sliceMeta[t][lvl][e] = std::make_pair(size, stride);
-        }
+      // Reverse queue into a stack.
+      std::reverse(remDepStack[t][lvl].begin(), remDepStack[t][lvl].end());
+      for (auto [loop, coeff] : dependentLvlMap[t][lvl])
+        depRedOrder.emplace_back(std::make_tuple(loop, t, lvl));
+    }
+
+    if (depRedOrder.empty())
+      continue;
+    std::sort(depRedOrder.begin(), depRedOrder.end(),
+              [](auto &l, auto &r) { return std::get<0>(l) < std::get<0>(r); });
+
+    for (auto [loop, t, lvl] : depRedOrder) {
+      std::pair<LoopId, unsigned> curDep = remDepStack[t][lvl].back();
+      assert(curDep.first == loop);
+      Value size = c0;
+      for (auto [loop, stride] : remDepStack[t][lvl]) {
+        // The synthetic tensor high defines the loop upper bound.
+        Value loopHi = highs[getSynTensorId()][loop];
+        size = ADDI(size, MULI(loopHi, C_IDX(stride)));
       }
+      sliceMeta[t][lvl].emplace_back(size, curDep.second);
+      remDepStack[t][lvl].pop_back();
+
+      // Generate caches required to fast compute next-non-empty slices with
+      // increasing offset for slice-base loop.
+      // We do not need cache for dense levels.
+      if (!remDepStack[t][lvl].empty() && !isDenseLT(lvls[t][lvl]->getLT())) {
+        Value cnt = C_IDX(1);
+        for (int preLvl = lvl - 1; preLvl >= 0; preLvl--) {
+          if (remDepStack[t][preLvl].empty())
+            break;
+          assert(remDepStack[t][preLvl].size() == 1 && "Not implemented");
+          auto [loop, stride] = remDepStack[t][preLvl].back();
+          assert(stride == 1 && "Not yet implemented");
+          // Accumulate the size required to cache the pLo for the slice.
+          // E.g., if we want to cache the pIdx for slice<d0xd1xf64> on the
+          // second level, we at most need a memref<d0xindex>.
+          //
+          // NOTE: this is apparently an over-approximation when the previous
+          // level is compressed, and we can compute a precise memory size
+          // inside the loops. But that would also require us to allocate/free
+          // memory in loops.
+          cnt = MULI(highs[getSynTensorId()][loop], cnt);
+        }
+        slicePosBuffer[t][lvl].push_back(allocSlicePosBuf(builder, loc, cnt));
+      } // else fully resolved.
     }
   }
-  localInsertPos = builder.getInsertionPoint()->getPrevNode();
 }
 
 void LoopEmitter::categorizeLoopCondition(
@@ -1878,9 +1918,6 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
   // simple dim expression in between).
   assert(lvl == *sliceStack[tid].back().slicedOnLvl + 1);
 
-  // Check slice stack integrity.
-  assert(slicePosBuffer[tid][lvl - 1].size() == sliceStack[tid].back().depth);
-
   SmallVector<const SliceInfo *> unResSlices;
   std::optional<std::pair<TensorId, Level>> firstResLvl;
   for (Level curLvl = lvl; curLvl >= 1; curLvl--) {
@@ -2006,37 +2043,6 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
   if (baseEnc.isSlice())
     llvm_unreachable("TODO: not yet implemented");
 
-  // Generate caches required to fast compute next-non-empty slices with
-  // increasing offset for slice-base loop.
-  // We do not need cache for dense levels.
-  if (slicePosBuffer[tid][lvl][0] == nullptr && !isDenseLT(lvlType)) {
-    OpBuilder::InsertionGuard guard(builder);
-    // The buffer can be reused, and the size is loop invariant: it only
-    // depends on the iteration graph's toposort.
-    builder.setInsertionPointAfter(localInsertPos);
-    Value tupleCnt = C_IDX(1);
-    // Accumlates the size required to cache the pLo for the slice.
-    // E.g., if we want to cache the pIdx for slice<d0xd1xf64> on the second
-    // level. We at most need to a memref<d0xindex>.
-    // NOTE: this is apperantly an over-approximation when the previous
-    // level is compressed, and we can compute a precise memory size
-    // inside the loops. But that would also requires us to allocate/free
-    // memorys in loops.
-    // TODO: Maybe using allocaScopeOp inside the loop to resolve the issue?
-    for (Level curLevel = lvl;
-         curLevel >= 1 && !lvlFullyResolved(tid, curLevel - 1); curLevel--) {
-      // We only handle cases when all the previously unresolved levels are
-      // fully reduced.
-      assert(depFullyReduced(tid, curLevel - 1));
-      assert(!sliceMeta[tid][curLevel - 1].empty());
-      auto [sz, stride] = sliceMeta[tid][curLevel - 1].back();
-      assert(stride == 1 && "Not yet implemented");
-      tupleCnt = MULI(tupleCnt, sz);
-    }
-    for (Value &cache : slicePosBuffer[tid][lvl])
-      cache = allocSlicePosBuf(builder, loc, tupleCnt);
-  }
-
   if (sliceInfo.isInitialTensor() ||
       (lvl >= 1 && lvlFullyResolved(tid, lvl - 1))) {
     // First level or previous level has been full resolved.
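To make the new size bookkeeping in initSliceDriven concrete: before each dependence of a level is reduced, the slice size recorded in sliceMeta is the sum of stride * loop-upper-bound over the dependences of that level that are still unresolved. A small arithmetic sketch under assumed loop bounds; plain integers stand in for the Value-level ADDI/MULI on the synthetic tensor's highs, and the bounds 10 and 20 are made up:

#include <cstdio>
#include <utility>
#include <vector>

using LoopId = unsigned; // stand-in for the MLIR typedef

int main() {
  // Dependences of one [tid, lvl] for the subscript 2 * d0 + d1, already
  // sorted by loop order: (loop, stride).
  std::vector<std::pair<LoopId, unsigned>> deps = {{0u, 2u}, {1u, 1u}};
  // Assumed loop upper bounds, playing the role of highs[synTensor][loop].
  unsigned loopHi[] = {10u, 20u}; // hi(d0) = 10, hi(d1) = 20

  // Mirrors the accumulation in initSliceDriven: before each dependence is
  // reduced, the recorded slice size is the sum of stride * loopHi over all
  // dependences of the level that are still unresolved.
  std::vector<std::pair<LoopId, unsigned>> remaining = deps;
  for (auto [loop, stride] : deps) {
    unsigned size = 0;
    for (auto [l, s] : remaining)
      size += s * loopHi[l];
    std::printf("reducing d%u: slice size %u (stride %u)\n", loop, size, stride);
    remaining.erase(remaining.begin()); // this dependence is now resolved
  }
  // Prints: reducing d0: slice size 40 (stride 2)
  //         reducing d1: slice size 20 (stride 1)
  return 0;
}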

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h

Lines changed: 12 additions & 17 deletions
@@ -63,22 +63,18 @@ class LoopEmitter {
   using SynTensorBoundSetter =
       function_ref<Value(OpBuilder &builder, Location loc, Level lvl)>;
 
-  // Map from [tid, lvl] to a list of dependent [tidlvl, coeffecient] for
+  // Map from [tid, lvl] to a list of dependent [LoopId, coefficient] for
   // subscript expressions on sparse tensors.
   //
-  // E.g., for affine index (2 * d0 + d1), it depends on two tidlvls that
-  // defines d0 and d1 (for affine expression reduction) and uses 2 and 1 for
-  // cofficients on d0, d1 respectively.
-  // If the list is empty, it means that there is no affine expression on the
-  // input [tid, lvl].
+  // E.g., for affine index (2 * d0 + d1), it depends on loops d0 and d1 (for
+  // affine expression reduction) and uses 2 and 1 for coefficients on d0, d1
+  // respectively. If the list is empty, it means that there is no affine
+  // expression on the input [tid, lvl].
   //
-  // NOTE: The caller is responsible to ensure that the order of the returned
-  // list to be consistent with the topological order of the iteration graph,
-  // otherwise the loop emitter might reduce a wrong dependent index variable
-  // when generating slice-driven loops.
+  // NOTE: LoopEmitter assumes that the loop id is consistent with the loop
+  // order, i.e., loop `d0` will be generated before loop `d1`.
   using DependentLvlGetter =
-      function_ref<std::vector<std::pair<TensorLevel, unsigned>>(TensorId,
-                                                                 Level)>;
+      function_ref<std::vector<std::pair<LoopId, unsigned>>(TensorId, Level)>;
 
   LoopEmitter() = default;
 
@@ -534,6 +530,8 @@ class LoopEmitter {
   // Slice-driven loop related methods.
   //
 
+  void initSliceDriven(OpBuilder &builder, Location loc);
+
   /// Retrieves the most recent slice on lvl. To reduce affine expression like
   /// d0 + d1 + d2, we need two slices (one of size d1 + d2, and the other of
   /// size d2). This methods returns the latter slice (of size d2).
@@ -621,9 +619,6 @@ class LoopEmitter {
   bool hasOutput;
   bool isSparseOut;
 
-  /// The insertion point to allocate top level local variables.
-  Operation *localInsertPos;
-
   //
   // Fields which have `numTensor` many entries.
   //
@@ -645,7 +640,7 @@ class LoopEmitter {
   std::vector<std::vector<Value>> highs;
   std::vector<std::vector<Value>> lvlSizes;
   std::vector<std::vector<std::unique_ptr<SparseTensorLevel>>> lvls;
-  std::vector<Value> valBuffer; // to_value
+  std::vector<Value> valBuffer;        // to_value
 
   //
   // Slice-driven loops related fields.
@@ -659,7 +654,7 @@ class LoopEmitter {
 
   // Map from [tid, level] to a list of dependent [tidlevel, coefficient].
   // See comments for `DependentLvlGetter`.
-  std::vector<std::vector<std::vector<std::pair<TensorLevel, unsigned>>>>
+  std::vector<std::vector<std::vector<std::pair<LoopId, unsigned>>>>
       dependentLvlMap;
 
   // The cached position buffer for the slices, they serve the same purpose as
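Since the header now states that LoopEmitter assumes loop ids follow the loop order, LoopEmitter::initialize sorts each dependent-loop list by LoopId before storing it into dependentLvlMap (the std::sort added above). A small sketch of that ordering step with simplified stand-in types; the alias and sample values are hypothetical:

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

using LoopId = unsigned; // stand-in for the MLIR typedef

int main() {
  // A dependent-loop list as a getter might hand it back, not yet in loop
  // order: (loop, coefficient).
  std::vector<std::pair<LoopId, unsigned>> deps = {{2u, 1u}, {0u, 3u}, {1u, 2u}};

  // The same ordering step LoopEmitter::initialize applies before storing the
  // list into dependentLvlMap: sort by LoopId, since the emitter assumes d0
  // is generated before d1, d1 before d2, and so on.
  std::sort(deps.begin(), deps.end(),
            [](auto &lhs, auto &rhs) { return lhs.first < rhs.first; });

  for (auto [loop, coeff] : deps)
    std::printf("d%u * %u\n", loop, coeff); // d0 * 3, d1 * 2, d2 * 1
  return 0;
}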
