Skip to content

Commit bf8d54a

Browse files
add complex type verification
1 parent c96df1e commit bf8d54a

File tree

4 files changed

+90
-69
lines changed

4 files changed

+90
-69
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,9 +502,37 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
502502
//
503503
// Helper function to translate between level/dimension space.
504504
//
505+
505506
SmallVector<int64_t> translateShape(::mlir::ArrayRef<int64_t> srcShape, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
506507
ValueRange translateCrds(::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::ValueRange crds, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
507508

509+
//
510+
// COO struct and methods.
511+
//
512+
513+
/// A simple structure that encodes a range of levels in the sparse tensors
514+
/// that forms a COO segment.
515+
struct COOSegment {
516+
std::pair<Level, Level> lvlRange; // [low, high)
517+
bool isSoA;
518+
519+
bool isAoS() const { return !isSoA; }
520+
bool isSegmentStart(Level l) const { return l == lvlRange.first; }
521+
bool inSegment(Level l) const {
522+
return l >= lvlRange.first && l < lvlRange.second;
523+
}
524+
};
525+
526+
/// Returns the starting level of this sparse tensor type for a
527+
/// trailing COO region that spans **at least** two levels. If
528+
/// no such COO region is found, then returns the level-rank.
529+
///
530+
/// DEPRECATED: use getCOOSegment instead;
531+
Level getAoSCOOStart() const;
532+
533+
/// Returns a list of COO segments in the sparse tensor types.
534+
SmallVector<COOSegment> getCOOSegments() const;
535+
508536
//
509537
// Printing methods.
510538
//

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,6 @@
1818
namespace mlir {
1919
namespace sparse_tensor {
2020

21-
/// A simple structure that encodes a range of levels in the sparse tensors that
22-
/// forms a COO segment.
23-
struct COOSegment {
24-
std::pair<Level, Level> lvlRange; // [low, high)
25-
bool isSoA;
26-
27-
bool isAoS() const { return !isSoA; }
28-
bool isSegmentStart(Level l) const { return l == lvlRange.first; }
29-
bool inSegment(Level l) const {
30-
return l >= lvlRange.first && l < lvlRange.second;
31-
}
32-
};
3321

3422
//===----------------------------------------------------------------------===//
3523
/// A wrapper around `RankedTensorType`, which has three goals:
@@ -73,11 +61,6 @@ class SparseTensorType {
7361
: SparseTensorType(
7462
RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {}
7563

76-
// TODO: remove?
77-
SparseTensorType(SparseTensorEncodingAttr enc)
78-
: SparseTensorType(RankedTensorType::get(
79-
SmallVector<Size>(enc.getDimRank(), ShapedType::kDynamic),
80-
Float32Type::get(enc.getContext()), enc)) {}
8164

8265
SparseTensorType &operator=(const SparseTensorType &) = delete;
8366
SparseTensorType(const SparseTensorType &) = default;
@@ -369,13 +352,15 @@ class SparseTensorType {
369352
/// no such COO region is found, then returns the level-rank.
370353
///
371354
/// DEPRECATED: use getCOOSegment instead;
372-
Level getAoSCOOStart() const;
355+
Level getAoSCOOStart() const { return getEncoding().getAoSCOOStart(); };
373356

374357
/// Returns [un]ordered COO type for this sparse tensor type.
375358
RankedTensorType getCOOType(bool ordered) const;
376359

377360
/// Returns a list of COO segments in the sparse tensor types.
378-
SmallVector<COOSegment> getCOOSegments() const;
361+
SmallVector<SparseTensorEncodingAttr::COOSegment> getCOOSegments() const {
362+
return getEncoding().getCOOSegments();
363+
}
379364

380365
private:
381366
// These two must be const, to ensure coherence of the memoized fields.

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 45 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ void StorageLayout::foreachField(
104104
callback) const {
105105
const auto lvlTypes = enc.getLvlTypes();
106106
const Level lvlRank = enc.getLvlRank();
107-
SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
107+
SmallVector<SparseTensorEncodingAttr::COOSegment> cooSegs =
108+
enc.getCOOSegments();
108109
FieldIndex fieldIdx = kDataFieldStartingIdx;
109110

110111
ArrayRef cooSegsRef = cooSegs;
@@ -211,7 +212,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
211212
unsigned stride = 1;
212213
if (kind == SparseTensorFieldKind::CrdMemRef) {
213214
assert(lvl.has_value());
214-
const Level cooStart = SparseTensorType(enc).getAoSCOOStart();
215+
const Level cooStart = enc.getAoSCOOStart();
215216
const Level lvlRank = enc.getLvlRank();
216217
if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
217218
lvl = cooStart;
@@ -912,78 +913,53 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
912913
return emitError()
913914
<< "dimension-rank mismatch between encoding and tensor shape: "
914915
<< getDimRank() << " != " << dimRank;
915-
if (getExplicitVal()) {
916-
if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getExplicitVal())) {
917-
Type attrType = typedAttr.getType();
918-
if (attrType != elementType) {
919-
return emitError()
920-
<< "explicit value type mismatch between encoding and "
921-
<< "tensor element type: " << attrType << " != " << elementType;
922-
}
923-
} else {
924-
return emitError() << "expected typed explicit value";
916+
if (auto expVal = getExplicitVal()) {
917+
Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
918+
if (attrType != elementType) {
919+
return emitError() << "explicit value type mismatch between encoding and "
920+
<< "tensor element type: " << attrType
921+
<< " != " << elementType;
925922
}
926923
}
927-
if (getImplicitVal()) {
928-
auto impVal = getImplicitVal();
929-
if (auto typedAttr = llvm::dyn_cast<TypedAttr>(impVal)) {
930-
Type attrType = typedAttr.getType();
931-
if (attrType != elementType) {
932-
return emitError()
933-
<< "implicit value type mismatch between encoding and "
934-
<< "tensor element type: " << attrType << " != " << elementType;
935-
}
936-
} else {
937-
return emitError() << "expected typed implicit value";
924+
if (auto impVal = getImplicitVal()) {
925+
Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
926+
if (attrType != elementType) {
927+
return emitError() << "implicit value type mismatch between encoding and "
928+
<< "tensor element type: " << attrType
929+
<< " != " << elementType;
938930
}
939931
// Currently, we only support zero as the implicit value.
940932
auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
941933
auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
942-
if ((impFVal && impFVal.getValueAsDouble() != 0.0) ||
943-
(impIntVal && impIntVal.getInt() != 0)) {
934+
auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
935+
if ((impFVal && impFVal.getValue().isNonZero()) ||
936+
(impIntVal && !impIntVal.getValue().isZero()) ||
937+
(impComplexVal && (impComplexVal.getImag().isNonZero() ||
938+
impComplexVal.getReal().isNonZero()))) {
944939
return emitError() << "implicit value must be zero";
945940
}
946941
}
947942
return success();
948943
}
949944

950-
//===----------------------------------------------------------------------===//
951-
// SparseTensorType Methods.
952-
//===----------------------------------------------------------------------===//
953-
954-
bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
955-
bool isUnique) const {
956-
if (!hasEncoding())
957-
return false;
958-
if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
959-
return false;
960-
for (Level l = startLvl + 1; l < lvlRank; ++l)
961-
if (!isSingletonLvl(l))
962-
return false;
963-
// If isUnique is true, then make sure that the last level is unique,
964-
// that is, when lvlRank == 1, the only compressed level is unique,
965-
// and when lvlRank > 1, the last singleton is unique.
966-
return !isUnique || isUniqueLvl(lvlRank - 1);
967-
}
968-
969-
Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const {
945+
Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
970946
SmallVector<COOSegment> coo = getCOOSegments();
971947
assert(coo.size() == 1 || coo.empty());
972948
if (!coo.empty() && coo.front().isAoS()) {
973949
return coo.front().lvlRange.first;
974950
}
975-
return lvlRank;
951+
return getLvlRank();
976952
}
977953

978-
SmallVector<COOSegment>
979-
mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
954+
SmallVector<SparseTensorEncodingAttr::COOSegment>
955+
mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
980956
SmallVector<COOSegment> ret;
981-
if (!hasEncoding() || lvlRank <= 1)
957+
if (getLvlRank() <= 1)
982958
return ret;
983959

984960
ArrayRef<LevelType> lts = getLvlTypes();
985961
Level l = 0;
986-
while (l < lvlRank) {
962+
while (l < getLvlRank()) {
987963
auto lt = lts[l];
988964
if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
989965
auto cur = lts.begin() + l;
@@ -1007,6 +983,25 @@ mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
1007983
return ret;
1008984
}
1009985

986+
//===----------------------------------------------------------------------===//
987+
// SparseTensorType Methods.
988+
//===----------------------------------------------------------------------===//
989+
990+
bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
991+
bool isUnique) const {
992+
if (!hasEncoding())
993+
return false;
994+
if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
995+
return false;
996+
for (Level l = startLvl + 1; l < lvlRank; ++l)
997+
if (!isSingletonLvl(l))
998+
return false;
999+
// If isUnique is true, then make sure that the last level is unique,
1000+
// that is, when lvlRank == 1, the only compressed level is unique,
1001+
// and when lvlRank > 1, the last singleton is unique.
1002+
return !isUnique || isUniqueLvl(lvlRank - 1);
1003+
}
1004+
10101005
RankedTensorType
10111006
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
10121007
SmallVector<LevelType> lvlTypes;

mlir/test/Dialect/SparseTensor/invalid_encoding.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,3 +515,16 @@ func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpVal>)
515515

516516
// expected-error@+1 {{implicit value must be zero}}
517517
func.func private @sparse_csr(tensor<?x?xf32, #CSR_ImpVal>)
518+
519+
// -----
520+
521+
#CSR_OnlyOnes = #sparse_tensor.encoding<{
522+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
523+
posWidth = 64,
524+
crdWidth = 64,
525+
explicitVal = #complex.number<:f32 1.0, 0.0>,
526+
implicitVal = #complex.number<:f32 1.0, 0.0>
527+
}>
528+
529+
// expected-error@+1 {{implicit value must be zero}}
530+
func.func private @sparse_csr(tensor<?x?xcomplex<f32>, #CSR_OnlyOnes>)

0 commit comments

Comments
 (0)