Skip to content

Commit 0d5d889

Browse files
move COOSegments
1 parent 35c1a47 commit 0d5d889

File tree

4 files changed

+17
-18
lines changed

4 files changed

+17
-18
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,19 @@ using Level = uint64_t;
4141
/// including the value `ShapedType::kDynamic` (for shapes).
4242
using Size = int64_t;
4343

44+
/// A simple structure that encodes a range of levels in the sparse tensors
45+
/// that forms a COO segment.
46+
struct COOSegment {
47+
std::pair<Level, Level> lvlRange; // [low, high)
48+
bool isSoA;
49+
50+
bool isAoS() const { return !isSoA; }
51+
bool isSegmentStart(Level l) const { return l == lvlRange.first; }
52+
bool inSegment(Level l) const {
53+
return l >= lvlRange.first && l < lvlRange.second;
54+
}
55+
};
56+
4457
} // namespace sparse_tensor
4558
} // namespace mlir
4659

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -507,22 +507,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
507507
ValueRange translateCrds(::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::ValueRange crds, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
508508

509509
//
510-
// COO struct and methods.
510+
// COO methods.
511511
//
512512

513-
/// A simple structure that encodes a range of levels in the sparse tensors
514-
/// that forms a COO segment.
515-
struct COOSegment {
516-
std::pair<Level, Level> lvlRange; // [low, high)
517-
bool isSoA;
518-
519-
bool isAoS() const { return !isSoA; }
520-
bool isSegmentStart(Level l) const { return l == lvlRange.first; }
521-
bool inSegment(Level l) const {
522-
return l >= lvlRange.first && l < lvlRange.second;
523-
}
524-
};
525-
526513
/// Returns the starting level of this sparse tensor type for a
527514
/// trailing COO region that spans **at least** two levels. If
528515
/// no such COO region is found, then returns the level-rank.

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ class SparseTensorType {
356356
RankedTensorType getCOOType(bool ordered) const;
357357

358358
/// Returns a list of COO segments in the sparse tensor types.
359-
SmallVector<SparseTensorEncodingAttr::COOSegment> getCOOSegments() const {
359+
SmallVector<COOSegment> getCOOSegments() const {
360360
return getEncoding().getCOOSegments();
361361
}
362362

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ void StorageLayout::foreachField(
104104
callback) const {
105105
const auto lvlTypes = enc.getLvlTypes();
106106
const Level lvlRank = enc.getLvlRank();
107-
SmallVector<SparseTensorEncodingAttr::COOSegment> cooSegs =
108-
enc.getCOOSegments();
107+
SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
109108
FieldIndex fieldIdx = kDataFieldStartingIdx;
110109

111110
ArrayRef cooSegsRef = cooSegs;
@@ -951,7 +950,7 @@ Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
951950
return getLvlRank();
952951
}
953952

954-
SmallVector<SparseTensorEncodingAttr::COOSegment>
953+
SmallVector<COOSegment>
955954
mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
956955
SmallVector<COOSegment> ret;
957956
if (getLvlRank() <= 1)

0 commit comments

Comments
 (0)