Skip to content

Commit bedbe3e

Browse files
add verification for explicit/implicit values
1 parent b9f2c16 commit bedbe3e

File tree

2 files changed

+107
-0
lines changed

2 files changed

+107
-0
lines changed

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,41 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
907907
return emitError()
908908
<< "dimension-rank mismatch between encoding and tensor shape: "
909909
<< getDimRank() << " != " << dimRank;
910+
Type expType, impType;
911+
if (getExplicitVal()) {
912+
auto fVal = llvm::dyn_cast<FloatAttr>(getExplicitVal());
913+
auto intVal = llvm::dyn_cast<IntegerAttr>(getExplicitVal());
914+
if (fVal && fVal.getType() != elementType) {
915+
expType = fVal.getType();
916+
} else if (intVal && intVal.getType() != elementType) {
917+
expType = intVal.getType();
918+
}
919+
if (expType) {
920+
return emitError() << "explicit value type mismatch between encoding and "
921+
<< "tensor element type: " << expType
922+
<< " != " << elementType;
923+
}
924+
}
925+
926+
if (getImplicitVal()) {
927+
auto impFVal = llvm::dyn_cast<FloatAttr>(getImplicitVal());
928+
auto impIntVal = llvm::dyn_cast<IntegerAttr>(getImplicitVal());
929+
if (impFVal && impFVal.getType() != elementType) {
930+
impType = impFVal.getType();
931+
} else if (impIntVal && impIntVal.getType() != elementType) {
932+
impType = impIntVal.getType();
933+
}
934+
if (impType) {
935+
return emitError() << "implicit value type mismatch between encoding and "
936+
<< "tensor element type: " << impType
937+
<< " != " << elementType;
938+
}
939+
// Currently, we only support zero as the implicit value.
940+
if ((impFVal && impFVal.getValueAsDouble() != 0.0) ||
941+
(impIntVal && impIntVal.getInt() != 0)) {
942+
return emitError() << "implicit value must be zero";
943+
}
944+
}
910945
return success();
911946
}
912947

mlir/test/Dialect/SparseTensor/invalid_encoding.mlir

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,3 +443,75 @@ func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
443443
func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
444444
return
445445
}
446+
447+
// -----
448+
449+
#CSR_ExpType = #sparse_tensor.encoding<{
450+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
451+
posWidth = 32,
452+
crdWidth = 32,
453+
explicitVal = 1 : i32,
454+
implicitVal = 0.0 : f32
455+
}>
456+
457+
// expected-error@+1 {{explicit value type mismatch between encoding and tensor element type: 'i32' != 'f32'}}
458+
func.func private @sparse_csr(tensor<?x?xf32, #CSR_ExpType>)
459+
460+
// -----
461+
462+
#CSR_ImpType = #sparse_tensor.encoding<{
463+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
464+
posWidth = 32,
465+
crdWidth = 32,
466+
explicitVal = 1 : i32,
467+
implicitVal = 0.0 : f32
468+
}>
469+
470+
// expected-error@+1 {{implicit value type mismatch between encoding and tensor element type: 'f32' != 'i32'}}
471+
func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpType>)
472+
473+
// -----
474+
475+
// expected-error@+1 {{expected a numeric value for explicitVal}}
476+
#CSR_ExpType = #sparse_tensor.encoding<{
477+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
478+
posWidth = 32,
479+
crdWidth = 32,
480+
explicitVal = "str"
481+
}>
482+
func.func private @sparse_csr(tensor<?x?xi32, #CSR_ExpType>)
483+
484+
// -----
485+
486+
// expected-error@+1 {{expected a numeric value for implicitVal}}
487+
#CSR_ImpType = #sparse_tensor.encoding<{
488+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
489+
posWidth = 32,
490+
crdWidth = 32,
491+
implicitVal = "str"
492+
}>
493+
func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpType>)
494+
495+
// -----
496+
497+
#CSR_ImpVal = #sparse_tensor.encoding<{
498+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
499+
posWidth = 32,
500+
crdWidth = 32,
501+
implicitVal = 1 : i32
502+
}>
503+
504+
// expected-error@+1 {{implicit value must be zero}}
505+
func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpVal>)
506+
507+
// -----
508+
509+
#CSR_ImpVal = #sparse_tensor.encoding<{
510+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
511+
posWidth = 32,
512+
crdWidth = 32,
513+
implicitVal = 1.0 : f32
514+
}>
515+
516+
// expected-error@+1 {{implicit value must be zero}}
517+
func.func private @sparse_csr(tensor<?x?xf32, #CSR_ImpVal>)

0 commit comments

Comments
 (0)