Skip to content

Commit f4f58e4

Browse files
new function
1 parent bedbe3e commit f4f58e4

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,11 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
512512
void printSymbols(AffineMap &map, AsmPrinter &printer) const;
513513
void printDimensions(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
514514
void printLevels(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::LevelType> lvlTypes) const;
515+
516+
//
517+
// Explicit/implicit value methods.
518+
//
519+
Type getMismatchedValueType(Type elementType, Attribute val) const;
515520
}];
516521

517522
let genVerifyDecl = 1;

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,19 @@ LogicalResult SparseTensorEncodingAttr::verify(
888888
return success();
889889
}
890890

891+
Type SparseTensorEncodingAttr::getMismatchedValueType(Type elementType,
892+
Attribute val) const {
893+
Type type;
894+
auto fVal = llvm::dyn_cast<FloatAttr>(val);
895+
auto intVal = llvm::dyn_cast<IntegerAttr>(val);
896+
if (fVal && fVal.getType() != elementType) {
897+
type = fVal.getType();
898+
} else if (intVal && intVal.getType() != elementType) {
899+
type = intVal.getType();
900+
}
901+
return type;
902+
}
903+
891904
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
892905
ArrayRef<Size> dimShape, Type elementType,
893906
function_ref<InFlightDiagnostic()> emitError) const {
@@ -907,36 +920,24 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
907920
return emitError()
908921
<< "dimension-rank mismatch between encoding and tensor shape: "
909922
<< getDimRank() << " != " << dimRank;
910-
Type expType, impType;
923+
Type type;
911924
if (getExplicitVal()) {
912-
auto fVal = llvm::dyn_cast<FloatAttr>(getExplicitVal());
913-
auto intVal = llvm::dyn_cast<IntegerAttr>(getExplicitVal());
914-
if (fVal && fVal.getType() != elementType) {
915-
expType = fVal.getType();
916-
} else if (intVal && intVal.getType() != elementType) {
917-
expType = intVal.getType();
918-
}
919-
if (expType) {
925+
if ((type = getMismatchedValueType(elementType, getExplicitVal()))) {
920926
return emitError() << "explicit value type mismatch between encoding and "
921-
<< "tensor element type: " << expType
927+
<< "tensor element type: " << type
922928
<< " != " << elementType;
923929
}
924930
}
925-
926931
if (getImplicitVal()) {
927-
auto impFVal = llvm::dyn_cast<FloatAttr>(getImplicitVal());
928-
auto impIntVal = llvm::dyn_cast<IntegerAttr>(getImplicitVal());
929-
if (impFVal && impFVal.getType() != elementType) {
930-
impType = impFVal.getType();
931-
} else if (impIntVal && impIntVal.getType() != elementType) {
932-
impType = impIntVal.getType();
933-
}
934-
if (impType) {
932+
auto impVal = getImplicitVal();
933+
if ((type = getMismatchedValueType(elementType, impVal))) {
935934
return emitError() << "implicit value type mismatch between encoding and "
936-
<< "tensor element type: " << impType
935+
<< "tensor element type: " << type
937936
<< " != " << elementType;
938937
}
939938
// Currently, we only support zero as the implicit value.
939+
auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
940+
auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
940941
if ((impFVal && impFVal.getValueAsDouble() != 0.0) ||
941942
(impIntVal && impIntVal.getInt() != 0)) {
942943
return emitError() << "implicit value must be zero";

0 commit comments

Comments
 (0)