-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][tosa] Add error if checks Variable Operators #137291
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir-linalg Author: TatWai Chong (tatwaichong) ChangesFor VARIABLE, VARIABLE_WRITE & VARIABLE_READ Full diff: https://github.com/llvm/llvm-project/pull/137291.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 0ab0a62f1cf11..6e5f6317816f2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -131,6 +131,8 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
let assemblyFormat = [{
$name attr-dict `,` $input1 `:` type($input1)
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -159,6 +161,8 @@ def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
let assemblyFormat = [{
$name attr-dict `:` type($output1)
}];
+
+ let hasVerifier = 1;
}
#endif // TOSA_UTIL_OPS
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 751ae785bda6f..b1312afbbf6d4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -572,6 +572,74 @@ static LogicalResult verifyConvOpErrorIf(T op) {
return success();
}
+// Verify whether same type and shape of the given two types.
+static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
+ StringRef name1, Type type2,
+ StringRef name2) {
+ auto shapeType1 = dyn_cast<ShapedType>(type1);
+ auto shapeType2 = dyn_cast<ShapedType>(type2);
+ if (!shapeType1 || !shapeType2)
+ return failure();
+
+ auto elemType1 = shapeType1.getElementType();
+ auto elemType2 = shapeType2.getElementType();
+ if (elemType1 != elemType2)
+ return op->emitOpError()
+ << "require same element type for " << name1 << " (" << elemType1
+ << ") and " << name2 << " (" << elemType2 << ")";
+
+ if (failed(verifyCompatibleShape(type1, type2)))
+ return op->emitOpError()
+ << "require same shapes for " << name1 << " (" << type1 << ") and "
+ << name2 << " (" << type2 << ")";
+
+ return success();
+}
+
+template <typename T>
+static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
+ // Currently, the variable's definition point is searched via walk(),
+ // starting from the top-level ModuleOp and stopping at the point of use. Once
+ // TOSA control flow and variable extensions reach the complete state, may
+ // leverage MLIR's Symbol Table functionality to look up symbol and enhance
+ // the search to a TOSA specific graph traversal over the IR structure.
+ StringRef symName = op.getName();
+ tosa::VariableOp varOp = nullptr;
+ auto thisOp = op.getOperation();
+ ModuleOp module = thisOp->template getParentOfType<ModuleOp>();
+ bool found = false;
+
+ module.walk([&](Operation *tempOp) {
+ // Reach this op itself.
+ if (tempOp == thisOp)
+ return;
+
+ if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
+ if (symName == tosaOp.getName()) {
+ if (found == true) {
+ op->emitOpError("illegal to have multiple declaration of '")
+ << symName << "'";
+ return;
+ }
+ found = true;
+ varOp = tosaOp;
+ }
+ }
+ });
+
+ if (found == false)
+ return op->emitOpError("'")
+ << symName << "' has not been declared by 'tosa.variable'";
+
+ // Verify type and shape
+ Type varType = cast<tosa::VariableOp>(varOp).getType();
+ if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
+ .failed())
+ return failure();
+
+ return success();
+}
+
// verify that inType and outType have same element types
template <typename T>
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3455,6 +3523,22 @@ LogicalResult tosa::SelectOp::verify() {
return success();
}
+LogicalResult tosa::VariableReadOp::verify() {
+ if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
+ .failed())
+ return failure();
+
+ return success();
+}
+
+LogicalResult tosa::VariableWriteOp::verify() {
+ if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
+ .failed())
+ return failure();
+
+ return success();
+}
+
// parse and print of WhileOp refer to the implementation of SCF dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index 37ed5cec00a0d..74706c426ea9c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -1,16 +1,6 @@
// RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics
-// -----
-
-// check that -tosa-validate of stateful ops kick in
-func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
- tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
- tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
- return
-}
-
// -----
// check that -tosa-validate level checking kick in
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b147c94fde9b0..eba65eabe97fb 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -595,7 +595,7 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
+ // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
return
}
@@ -604,7 +604,7 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
+ // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
return
}
@@ -613,7 +613,7 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error@+1 {{'tosa.variable_write' op illegal: operand/result data types not supported}}
+ // expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
return
}
@@ -622,7 +622,7 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
+ // expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
return
}
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
index 6fa6b26155461..25f63331f39df 100644
--- a/mlir/test/Dialect/Tosa/variables.mlir
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --split-input-file | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --split-input-file --mlir-print-op-generic | mlir-opt | FileCheck %s
// -----
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 262e6d4265ea6..3d2505c27ee58 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -358,3 +358,93 @@ func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?x
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
return %0 : tensor<2x?xf32>
}
+
+// -----
+
+func.func @test_variable_shape_mismatch() -> () {
+ // expected-error@+1 {{inferred shape of elements literal ([2]) does not match type ([3])}}
+ tosa.variable @stored_var = dense<[3.14, 2.14]> : tensor<3xf32>
+ // expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
+ return
+}
+
+// -----
+
+func.func @test_variable_type_mismatch() -> () {
+ // expected-error@+1 {{expected integer elements, but parsed floating-point}}
+ tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xi32>
+ // expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
+ return
+}
+
+// -----
+
+func.func @test_variable_read_no_declaration() -> () {
+ // expected-error@+1 {{'tosa.variable_read' op 'stored_var' has not been declared by 'tosa.variable'}}
+ %0 = tosa.variable_read @stored_var : tensor<f32>
+ return
+}
+
+// -----
+
+func.func @test_variable_read_multiple_declaration() -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error@+1 {{'tosa.variable_read' op illegal to have multiple declaration of 'stored_var'}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_read_type_mismatch() -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
+ // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_read_shape_mismatch() -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
+ // expected-error@+1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xf32>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_no_declaration(%arg0: tensor<f32>) -> () {
+ // expected-error@+1 {{'tosa.variable_write' op 'stored_var' has not been declared by 'tosa.variable'}}
+ tosa.variable_write @stored_var, %arg0 : tensor<f32>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_multiple_declaration(%arg0: tensor<2x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error@+1 {{'tosa.variable_write' op illegal to have multiple declaration of 'stored_var'}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
+ // expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
+ // expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
+ return
+}
|
@llvm/pr-subscribers-mlir Author: TatWai Chong (tatwaichong) ChangesFor VARIABLE, VARIABLE_WRITE & VARIABLE_READ Full diff: https://github.com/llvm/llvm-project/pull/137291.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 0ab0a62f1cf11..6e5f6317816f2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -131,6 +131,8 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
let assemblyFormat = [{
$name attr-dict `,` $input1 `:` type($input1)
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -159,6 +161,8 @@ def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
let assemblyFormat = [{
$name attr-dict `:` type($output1)
}];
+
+ let hasVerifier = 1;
}
#endif // TOSA_UTIL_OPS
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 751ae785bda6f..b1312afbbf6d4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -572,6 +572,74 @@ static LogicalResult verifyConvOpErrorIf(T op) {
return success();
}
+// Verify whether same type and shape of the given two types.
+static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
+ StringRef name1, Type type2,
+ StringRef name2) {
+ auto shapeType1 = dyn_cast<ShapedType>(type1);
+ auto shapeType2 = dyn_cast<ShapedType>(type2);
+ if (!shapeType1 || !shapeType2)
+ return failure();
+
+ auto elemType1 = shapeType1.getElementType();
+ auto elemType2 = shapeType2.getElementType();
+ if (elemType1 != elemType2)
+ return op->emitOpError()
+ << "require same element type for " << name1 << " (" << elemType1
+ << ") and " << name2 << " (" << elemType2 << ")";
+
+ if (failed(verifyCompatibleShape(type1, type2)))
+ return op->emitOpError()
+ << "require same shapes for " << name1 << " (" << type1 << ") and "
+ << name2 << " (" << type2 << ")";
+
+ return success();
+}
+
+template <typename T>
+static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
+ // Currently, the variable's definition point is searched via walk(),
+ // starting from the top-level ModuleOp and stopping at the point of use. Once
+ // TOSA control flow and variable extensions reach the complete state, may
+ // leverage MLIR's Symbol Table functionality to look up symbol and enhance
+ // the search to a TOSA specific graph traversal over the IR structure.
+ StringRef symName = op.getName();
+ tosa::VariableOp varOp = nullptr;
+ auto thisOp = op.getOperation();
+ ModuleOp module = thisOp->template getParentOfType<ModuleOp>();
+ bool found = false;
+
+ module.walk([&](Operation *tempOp) {
+ // Reach this op itself.
+ if (tempOp == thisOp)
+ return;
+
+ if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
+ if (symName == tosaOp.getName()) {
+ if (found == true) {
+ op->emitOpError("illegal to have multiple declaration of '")
+ << symName << "'";
+ return;
+ }
+ found = true;
+ varOp = tosaOp;
+ }
+ }
+ });
+
+ if (found == false)
+ return op->emitOpError("'")
+ << symName << "' has not been declared by 'tosa.variable'";
+
+ // Verify type and shape
+ Type varType = cast<tosa::VariableOp>(varOp).getType();
+ if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
+ .failed())
+ return failure();
+
+ return success();
+}
+
// verify that inType and outType have same element types
template <typename T>
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3455,6 +3523,22 @@ LogicalResult tosa::SelectOp::verify() {
return success();
}
+LogicalResult tosa::VariableReadOp::verify() {
+ if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
+ .failed())
+ return failure();
+
+ return success();
+}
+
+LogicalResult tosa::VariableWriteOp::verify() {
+ if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
+ .failed())
+ return failure();
+
+ return success();
+}
+
// parse and print of WhileOp refer to the implementation of SCF dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index 37ed5cec00a0d..74706c426ea9c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -1,16 +1,6 @@
// RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics
-// -----
-
-// check that -tosa-validate of stateful ops kick in
-func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
- tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
- tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
- return
-}
-
// -----
// check that -tosa-validate level checking kick in
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b147c94fde9b0..eba65eabe97fb 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -595,7 +595,7 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
+ // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
return
}
@@ -604,7 +604,7 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
+ // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
return
}
@@ -613,7 +613,7 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error@+1 {{'tosa.variable_write' op illegal: operand/result data types not supported}}
+ // expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
return
}
@@ -622,7 +622,7 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
- // expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
+ // expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
return
}
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
index 6fa6b26155461..25f63331f39df 100644
--- a/mlir/test/Dialect/Tosa/variables.mlir
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --split-input-file | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --split-input-file --mlir-print-op-generic | mlir-opt | FileCheck %s
// -----
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 262e6d4265ea6..3d2505c27ee58 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -358,3 +358,93 @@ func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?x
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
return %0 : tensor<2x?xf32>
}
+
+// -----
+
+func.func @test_variable_shape_mismatch() -> () {
+ // expected-error@+1 {{inferred shape of elements literal ([2]) does not match type ([3])}}
+ tosa.variable @stored_var = dense<[3.14, 2.14]> : tensor<3xf32>
+ // expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
+ return
+}
+
+// -----
+
+func.func @test_variable_type_mismatch() -> () {
+ // expected-error@+1 {{expected integer elements, but parsed floating-point}}
+ tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xi32>
+ // expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
+ return
+}
+
+// -----
+
+func.func @test_variable_read_no_declaration() -> () {
+ // expected-error@+1 {{'tosa.variable_read' op 'stored_var' has not been declared by 'tosa.variable'}}
+ %0 = tosa.variable_read @stored_var : tensor<f32>
+ return
+}
+
+// -----
+
+func.func @test_variable_read_multiple_declaration() -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error@+1 {{'tosa.variable_read' op illegal to have multiple declaration of 'stored_var'}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_read_type_mismatch() -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
+ // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_read_shape_mismatch() -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
+ // expected-error@+1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
+ %0 = tosa.variable_read @stored_var : tensor<2x4x8xf32>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_no_declaration(%arg0: tensor<f32>) -> () {
+ // expected-error@+1 {{'tosa.variable_write' op 'stored_var' has not been declared by 'tosa.variable'}}
+ tosa.variable_write @stored_var, %arg0 : tensor<f32>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_multiple_declaration(%arg0: tensor<2x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+ // expected-error@+1 {{'tosa.variable_write' op illegal to have multiple declaration of 'stored_var'}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
+ // expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
+ return
+}
+
+// -----
+
+func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
+ tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
+ // expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
+ tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
+ return
+}
|
Please update to resolve merge conflicts. |
e3a3d4e
to
37d9d8f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Left a minor question regarding the removed LIT test.
Also left a comment for a possible future update.
37d9d8f
to
1b6299c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates @tatwaichong, just needs a rebase
For VARIABLE, VARIABLE_WRITE & VARIABLE_READ
1b6299c
to
eeb0434
Compare
For VARIABLE, VARIABLE_WRITE & VARIABLE_READ
For VARIABLE, VARIABLE_WRITE & VARIABLE_READ
For VARIABLE, VARIABLE_WRITE & VARIABLE_READ
For VARIABLE, VARIABLE_WRITE & VARIABLE_READ
For VARIABLE, VARIABLE_WRITE & VARIABLE_READ
For VARIABLE, VARIABLE_WRITE & VARIABLE_READ