@@ -104,7 +104,8 @@ void StorageLayout::foreachField(
104
104
callback) const {
105
105
const auto lvlTypes = enc.getLvlTypes ();
106
106
const Level lvlRank = enc.getLvlRank ();
107
- SmallVector<COOSegment> cooSegs = SparseTensorType (enc).getCOOSegments ();
107
+ SmallVector<SparseTensorEncodingAttr::COOSegment> cooSegs =
108
+ enc.getCOOSegments ();
108
109
FieldIndex fieldIdx = kDataFieldStartingIdx ;
109
110
110
111
ArrayRef cooSegsRef = cooSegs;
@@ -211,7 +212,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
211
212
unsigned stride = 1 ;
212
213
if (kind == SparseTensorFieldKind::CrdMemRef) {
213
214
assert (lvl.has_value ());
214
- const Level cooStart = SparseTensorType ( enc) .getAoSCOOStart ();
215
+ const Level cooStart = enc.getAoSCOOStart ();
215
216
const Level lvlRank = enc.getLvlRank ();
216
217
if (lvl.value () >= cooStart && lvl.value () < lvlRank) {
217
218
lvl = cooStart;
@@ -912,78 +913,53 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
912
913
return emitError ()
913
914
<< " dimension-rank mismatch between encoding and tensor shape: "
914
915
<< getDimRank () << " != " << dimRank;
915
- if (getExplicitVal ()) {
916
- if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getExplicitVal ())) {
917
- Type attrType = typedAttr.getType ();
918
- if (attrType != elementType) {
919
- return emitError ()
920
- << " explicit value type mismatch between encoding and "
921
- << " tensor element type: " << attrType << " != " << elementType;
922
- }
923
- } else {
924
- return emitError () << " expected typed explicit value" ;
916
+ if (auto expVal = getExplicitVal ()) {
917
+ Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType ();
918
+ if (attrType != elementType) {
919
+ return emitError () << " explicit value type mismatch between encoding and "
920
+ << " tensor element type: " << attrType
921
+ << " != " << elementType;
925
922
}
926
923
}
927
- if (getImplicitVal ()) {
928
- auto impVal = getImplicitVal ();
929
- if (auto typedAttr = llvm::dyn_cast<TypedAttr>(impVal)) {
930
- Type attrType = typedAttr.getType ();
931
- if (attrType != elementType) {
932
- return emitError ()
933
- << " implicit value type mismatch between encoding and "
934
- << " tensor element type: " << attrType << " != " << elementType;
935
- }
936
- } else {
937
- return emitError () << " expected typed implicit value" ;
924
+ if (auto impVal = getImplicitVal ()) {
925
+ Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType ();
926
+ if (attrType != elementType) {
927
+ return emitError () << " implicit value type mismatch between encoding and "
928
+ << " tensor element type: " << attrType
929
+ << " != " << elementType;
938
930
}
939
931
// Currently, we only support zero as the implicit value.
940
932
auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
941
933
auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
942
- if ((impFVal && impFVal.getValueAsDouble () != 0.0 ) ||
943
- (impIntVal && impIntVal.getInt () != 0 )) {
934
+ auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
935
+ if ((impFVal && impFVal.getValue ().isNonZero ()) ||
936
+ (impIntVal && !impIntVal.getValue ().isZero ()) ||
937
+ (impComplexVal && (impComplexVal.getImag ().isNonZero () ||
938
+ impComplexVal.getReal ().isNonZero ()))) {
944
939
return emitError () << " implicit value must be zero" ;
945
940
}
946
941
}
947
942
return success ();
948
943
}
949
944
950
- // ===----------------------------------------------------------------------===//
951
- // SparseTensorType Methods.
952
- // ===----------------------------------------------------------------------===//
953
-
954
- bool mlir::sparse_tensor::SparseTensorType::isCOOType (Level startLvl,
955
- bool isUnique) const {
956
- if (!hasEncoding ())
957
- return false ;
958
- if (!isCompressedLvl (startLvl) && !isLooseCompressedLvl (startLvl))
959
- return false ;
960
- for (Level l = startLvl + 1 ; l < lvlRank; ++l)
961
- if (!isSingletonLvl (l))
962
- return false ;
963
- // If isUnique is true, then make sure that the last level is unique,
964
- // that is, when lvlRank == 1, the only compressed level is unique,
965
- // and when lvlRank > 1, the last singleton is unique.
966
- return !isUnique || isUniqueLvl (lvlRank - 1 );
967
- }
968
-
969
- Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart () const {
945
+ Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart () const {
970
946
SmallVector<COOSegment> coo = getCOOSegments ();
971
947
assert (coo.size () == 1 || coo.empty ());
972
948
if (!coo.empty () && coo.front ().isAoS ()) {
973
949
return coo.front ().lvlRange .first ;
974
950
}
975
- return lvlRank ;
951
+ return getLvlRank () ;
976
952
}
977
953
978
- SmallVector<COOSegment>
979
- mlir::sparse_tensor::SparseTensorType ::getCOOSegments () const {
954
+ SmallVector<SparseTensorEncodingAttr:: COOSegment>
955
+ mlir::sparse_tensor::SparseTensorEncodingAttr ::getCOOSegments () const {
980
956
SmallVector<COOSegment> ret;
981
- if (! hasEncoding () || lvlRank <= 1 )
957
+ if (getLvlRank () <= 1 )
982
958
return ret;
983
959
984
960
ArrayRef<LevelType> lts = getLvlTypes ();
985
961
Level l = 0 ;
986
- while (l < lvlRank ) {
962
+ while (l < getLvlRank () ) {
987
963
auto lt = lts[l];
988
964
if (lt.isa <LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
989
965
auto cur = lts.begin () + l;
@@ -1007,6 +983,25 @@ mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
1007
983
return ret;
1008
984
}
1009
985
986
+ // ===----------------------------------------------------------------------===//
987
+ // SparseTensorType Methods.
988
+ // ===----------------------------------------------------------------------===//
989
+
990
+ bool mlir::sparse_tensor::SparseTensorType::isCOOType (Level startLvl,
991
+ bool isUnique) const {
992
+ if (!hasEncoding ())
993
+ return false ;
994
+ if (!isCompressedLvl (startLvl) && !isLooseCompressedLvl (startLvl))
995
+ return false ;
996
+ for (Level l = startLvl + 1 ; l < lvlRank; ++l)
997
+ if (!isSingletonLvl (l))
998
+ return false ;
999
+ // If isUnique is true, then make sure that the last level is unique,
1000
+ // that is, when lvlRank == 1, the only compressed level is unique,
1001
+ // and when lvlRank > 1, the last singleton is unique.
1002
+ return !isUnique || isUniqueLvl (lvlRank - 1 );
1003
+ }
1004
+
1010
1005
RankedTensorType
1011
1006
mlir::sparse_tensor::SparseTensorType::getCOOType (bool ordered) const {
1012
1007
SmallVector<LevelType> lvlTypes;
0 commit comments