@@ -888,6 +888,19 @@ LogicalResult SparseTensorEncodingAttr::verify(
888
888
return success ();
889
889
}
890
890
891
+ Type SparseTensorEncodingAttr::getMismatchedValueType (Type elementType,
892
+ Attribute val) const {
893
+ Type type;
894
+ auto fVal = llvm::dyn_cast<FloatAttr>(val);
895
+ auto intVal = llvm::dyn_cast<IntegerAttr>(val);
896
+ if (fVal && fVal .getType () != elementType) {
897
+ type = fVal .getType ();
898
+ } else if (intVal && intVal.getType () != elementType) {
899
+ type = intVal.getType ();
900
+ }
901
+ return type;
902
+ }
903
+
891
904
LogicalResult SparseTensorEncodingAttr::verifyEncoding (
892
905
ArrayRef<Size > dimShape, Type elementType,
893
906
function_ref<InFlightDiagnostic()> emitError) const {
@@ -907,36 +920,24 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
907
920
return emitError ()
908
921
<< " dimension-rank mismatch between encoding and tensor shape: "
909
922
<< getDimRank () << " != " << dimRank;
910
- Type expType, impType ;
923
+ Type type ;
911
924
if (getExplicitVal ()) {
912
- auto fVal = llvm::dyn_cast<FloatAttr>(getExplicitVal ());
913
- auto intVal = llvm::dyn_cast<IntegerAttr>(getExplicitVal ());
914
- if (fVal && fVal .getType () != elementType) {
915
- expType = fVal .getType ();
916
- } else if (intVal && intVal.getType () != elementType) {
917
- expType = intVal.getType ();
918
- }
919
- if (expType) {
925
+ if ((type = getMismatchedValueType (elementType, getExplicitVal ()))) {
920
926
return emitError () << " explicit value type mismatch between encoding and "
921
- << " tensor element type: " << expType
927
+ << " tensor element type: " << type
922
928
<< " != " << elementType;
923
929
}
924
930
}
925
-
926
931
if (getImplicitVal ()) {
927
- auto impFVal = llvm::dyn_cast<FloatAttr>(getImplicitVal ());
928
- auto impIntVal = llvm::dyn_cast<IntegerAttr>(getImplicitVal ());
929
- if (impFVal && impFVal.getType () != elementType) {
930
- impType = impFVal.getType ();
931
- } else if (impIntVal && impIntVal.getType () != elementType) {
932
- impType = impIntVal.getType ();
933
- }
934
- if (impType) {
932
+ auto impVal = getImplicitVal ();
933
+ if ((type = getMismatchedValueType (elementType, impVal))) {
935
934
return emitError () << " implicit value type mismatch between encoding and "
936
- << " tensor element type: " << impType
935
+ << " tensor element type: " << type
937
936
<< " != " << elementType;
938
937
}
939
938
// Currently, we only support zero as the implicit value.
939
+ auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
940
+ auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
940
941
if ((impFVal && impFVal.getValueAsDouble () != 0.0 ) ||
941
942
(impIntVal && impIntVal.getInt () != 0 )) {
942
943
return emitError () << " implicit value must be zero" ;
0 commit comments