Skip to content

Commit 4e2f152

Browse files
authored
[mlir][sparse] code cleanup, remove FIXMEs (#73575)
1 parent c0fe071 commit 4e2f152

File tree

8 files changed

+29
-81
lines changed

8 files changed

+29
-81
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,15 @@ bool isBlockSparsity(AffineMap dimToLvl);
163163
// Reordering.
164164
//
165165

166-
/// [deprecated] Convenience method to translate the given level to the
167-
/// corresponding dimension. Requires: `0 <= l < lvlRank`.
168-
Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l);
169-
170-
/// [deprecated] Convenience method to translate the given dimension to
171-
/// the corresponding level. Requires: `0 <= d < dimRank`.
172-
Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d);
166+
/// Convenience method to translate the given level to the corresponding
167+
/// dimension.
168+
/// Requires: `enc` has a permuted dim2lvl map and `0 <= l < lvlRank`.
169+
Dimension toDim(SparseTensorEncodingAttr enc, Level l);
170+
171+
/// Convenience method to translate the given dimension to the corresponding
172+
/// level.
173+
/// Requires: `enc` has a permuted dim2lvl map and `0 <= d < dimRank`.
174+
Level toLvl(SparseTensorEncodingAttr enc, Dimension d);
173175

174176
} // namespace sparse_tensor
175177
} // namespace mlir

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -375,14 +375,12 @@ SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
375375

376376
std::optional<uint64_t>
377377
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
378-
// FIXME: `toOrigDim` is deprecated.
379-
return getStaticDimSliceOffset(toOrigDim(*this, lvl));
378+
return getStaticDimSliceOffset(toDim(*this, lvl));
380379
}
381380

382381
std::optional<uint64_t>
383382
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
384-
// FIXME: `toOrigDim` is deprecated.
385-
return getStaticDimSliceStride(toOrigDim(*this, lvl));
383+
return getStaticDimSliceStride(toDim(*this, lvl));
386384
}
387385

388386
SmallVector<int64_t>
@@ -398,10 +396,8 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
398396

399397
if (isPermutation()) {
400398
for (unsigned r = 0; r < rank; r++) {
401-
// FIXME: `toOrigDim` and `toStoredDim` are deprecated.
402-
unsigned trans = dir == CrdTransDirectionKind::dim2lvl
403-
? toOrigDim(*this, r)
404-
: toStoredDim(*this, r);
399+
unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
400+
: toLvl(*this, r);
405401
ret.push_back(srcShape[trans]);
406402
}
407403
return ret;
@@ -922,31 +918,20 @@ RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src,
922918
ordered);
923919
}
924920

925-
// TODO: Remove this definition once all use-sites have been fixed to
926-
// properly handle non-permutations.
927-
Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
928-
Level l) {
921+
Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
929922
if (enc) {
930-
if (const auto dimToLvl = enc.getDimToLvl()) {
931-
assert(enc.isPermutation());
923+
assert(enc.isPermutation() && "Non permutation map not supported");
924+
if (const auto dimToLvl = enc.getDimToLvl())
932925
return dimToLvl.getDimPosition(l);
933-
}
934926
}
935927
return l;
936928
}
937929

938-
// TODO: Remove this definition once all use-sites have been fixed to
939-
// properly handle non-permutations.
940-
Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
941-
Dimension d) {
930+
Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
942931
if (enc) {
943-
if (const auto dimToLvl = enc.getDimToLvl()) {
944-
assert(enc.isPermutation());
945-
auto maybePos =
946-
dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext()));
947-
assert(maybePos.has_value());
948-
return *maybePos;
949-
}
932+
assert(enc.isPermutation() && "Non permutation map not supported");
933+
if (const auto lvlToDim = enc.getLvlToDim())
934+
return lvlToDim.getDimPosition(d);
950935
}
951936
return d;
952937
}

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -546,32 +546,6 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
546546
}
547547
}
548548

549-
Value sparse_tensor::reshapeValuesToLevels(OpBuilder &builder, Location loc,
550-
SparseTensorEncodingAttr enc,
551-
ValueRange dimSizes,
552-
Value valuesBuffer,
553-
Value lvlCoords) {
554-
// Reuse the `lvlCoords` buffer to store the level-sizes.
555-
const Level lvlRank = enc.getLvlRank();
556-
SmallVector<Value> lvlSizes;
557-
lvlSizes.reserve(lvlRank);
558-
for (Level l = 0; l < lvlRank; l++)
559-
// FIXME: `toOrigDim` is deprecated.
560-
lvlSizes.push_back(dimSizes[toOrigDim(enc, l)]);
561-
storeAll(builder, loc, lvlCoords, lvlSizes);
562-
// The memref ReshapeOp requires the sizes buffer to have a static
563-
// shape.
564-
const auto iTp = builder.getIndexType();
565-
const SmallVector<Size, 1> lvlSizesShape{static_cast<Size>(lvlRank)};
566-
const auto lvlSizesTp = MemRefType::get(lvlSizesShape, iTp);
567-
lvlCoords = builder.create<memref::CastOp>(loc, lvlSizesTp, lvlCoords);
568-
// Finally, create the ReshapeOp.
569-
const SmallVector<Size> resShape(lvlRank, ShapedType::kDynamic);
570-
const Type elemTp = getMemRefType(valuesBuffer).getElementType();
571-
const auto resTp = MemRefType::get(resShape, elemTp);
572-
return builder.create<memref::ReshapeOp>(loc, resTp, valuesBuffer, lvlCoords);
573-
}
574-
575549
TypedValue<BaseMemRefType>
576550
sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
577551
auto tTp = llvm::cast<TensorType>(tensor.getType());

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,6 @@ SmallVector<Value> loadAll(OpBuilder &builder, Location loc, size_t size,
277277
void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
278278
size_t offsetIdx = 0, Value offsetVal = Value());
279279

280-
/// Reshapes the linear values buffer for an annotated all dense sparse tensor
281-
/// to match the shape of the corresponding dense tensor to support direct
282-
/// access of the buffer through `lvlCoords`.
283-
Value reshapeValuesToLevels(OpBuilder &builder, Location loc,
284-
SparseTensorEncodingAttr enc, ValueRange dimSizes,
285-
Value valuesBuffer, Value lvlCoords);
286-
287280
// Generates code to cast a tensor to a memref.
288281
TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
289282
Value tensor);

mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ enum class SortMask : unsigned {
2626
// The individual mask bits.
2727
kIncludeDenseOutput = 0x1, // b001
2828
kIncludeDenseInput = 0x2, // b010
29-
kIncludeUndef = 0x4, // b100
3029
// The subsets of mask bits.
3130
kIncludeAll = 0x7, // b111
3231
kIncludeDense = 0x3, // b011

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,13 @@ static constexpr unsigned kSliceIterWidth = 3;
6868
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
6969
Level lvl) {
7070
auto enc = getSparseTensorEncoding(tensor.getType());
71-
// FIXME: `toOrigDim` is deprecated
72-
return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl));
71+
return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
7372
}
7473

7574
static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
7675
Level lvl) {
7776
auto enc = getSparseTensorEncoding(tensor.getType());
78-
// FIXME: `toOrigDim` is deprecated
79-
return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl));
77+
return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
8078
}
8179

8280
/// Converts a coordinate relative to the slice to the coordinate relative

mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -422,10 +422,10 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
422422
// computation. Must be ordered from more strict to less strict.
423423
// Ideally (though might not be guaranteed), the earlier a constraint mask
424424
// can be satisfied, the faster the generated kernel will be.
425-
const auto allMasks = {
426-
SortMask::kIncludeAll, SortMask::kIncludeDense,
427-
SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
428-
SortMask::kIncludeUndef, SortMask::kSparseOnly};
425+
const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
426+
SortMask::kIncludeDenseInput,
427+
SortMask::kIncludeDenseOutput,
428+
SortMask::kSparseOnly};
429429
for (const SortMask mask : allMasks) {
430430
order = scheduler.sort(mask);
431431
if (order) {

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -660,8 +660,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
660660
SmallVector<Value> srcDcvs;
661661
srcDcvs.reserve(srcRank);
662662
for (Dimension d = 0; d < srcRank; d++) {
663-
// FIXME: `toStoredDim` is deprecated
664-
Level lvl = toStoredDim(encSrc, d);
663+
Level lvl = toLvl(encSrc, d);
665664
srcDcvs.push_back(srcLcvs[lvl]);
666665
}
667666

@@ -765,8 +764,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
765764
SmallVector<Value> srcDcvs;
766765
srcDcvs.reserve(dimRank);
767766
for (Dimension d = 0; d < dimRank; d++) {
768-
// FIXME: `toStoredDim` is deprecated
769-
Level lvl = toStoredDim(encSrc, d);
767+
Level lvl = toLvl(encSrc, d);
770768
srcDcvs.push_back(srcLcvs[lvl]);
771769
}
772770
SmallVector<Value> dstDcvs;
@@ -871,9 +869,8 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
871869
return failure();
872870

873871
if (stt.isPermutation()) {
874-
// FIXME: `toStoredDim` is deprecated
875872
rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
876-
toStoredDim(stt.getEncoding(), *dim));
873+
toLvl(stt.getEncoding(), *dim));
877874
return success();
878875
}
879876

0 commit comments

Comments
 (0)