Skip to content

Commit 30bedb3

Browse files
authored
[mlir][tosa] Add error if checks Variable Operators (#137291)
For VARIABLE, VARIABLE_WRITE & VARIABLE_READ
1 parent 9d1f1c4 commit 30bedb3

File tree

6 files changed

+170
-17
lines changed

6 files changed

+170
-17
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
106106
attr-dict
107107
custom<TypeOrAttr>($type, $initial_value)
108108
}];
109+
110+
let hasVerifier = 1;
109111
}
110112

111113
//===----------------------------------------------------------------------===//
@@ -131,6 +133,8 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
131133
let assemblyFormat = [{
132134
$name attr-dict `,` $input1 `:` type($input1)
133135
}];
136+
137+
let hasVerifier = 1;
134138
}
135139

136140
//===----------------------------------------------------------------------===//
@@ -159,6 +163,8 @@ def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
159163
let assemblyFormat = [{
160164
$name attr-dict `:` type($output1)
161165
}];
166+
167+
let hasVerifier = 1;
162168
}
163169

164170
#endif // TOSA_UTIL_OPS

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,58 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
613613
return shapeAdaptor.getNumElements() == 1 ? success() : failure();
614614
}
615615

616+
// Returns the first declaration point prior to this operation or failure if
617+
// not found.
618+
static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
619+
StringRef symName) {
620+
ModuleOp module = op->getParentOfType<ModuleOp>();
621+
tosa::VariableOp varOp = nullptr;
622+
623+
// TODO: Adopt SymbolTable trait to Varible ops.
624+
// Currently, the variable's definition point is searched via walk(),
625+
// starting from the top-level ModuleOp and stopping at the point of use. Once
626+
// TOSA control flow and variable extensions reach the complete state, may
627+
// leverage MLIR's Symbol Table functionality to look up symbol and enhance
628+
// the search to a TOSA specific graph traversal over the IR structure.
629+
module.walk([&](Operation *tempOp) {
630+
// Reach this op itself.
631+
if (tempOp == op) {
632+
return WalkResult::interrupt();
633+
}
634+
635+
if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
636+
if (symName == tosaOp.getName()) {
637+
varOp = tosaOp;
638+
return WalkResult::interrupt();
639+
}
640+
}
641+
642+
return WalkResult::advance();
643+
});
644+
645+
if (varOp)
646+
return varOp;
647+
648+
return failure();
649+
}
650+
651+
template <typename T>
652+
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
653+
StringRef symName = op.getName();
654+
FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
655+
if (failed(varOp))
656+
return op->emitOpError("'")
657+
<< symName << "' has not been declared by 'tosa.variable'";
658+
659+
// Verify type and shape
660+
Type varType = cast<tosa::VariableOp>(varOp.value()).getType();
661+
if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
662+
.failed())
663+
return failure();
664+
665+
return success();
666+
}
667+
616668
// verify that inType and outType have same element types
617669
template <typename T>
618670
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3660,6 +3712,32 @@ LogicalResult tosa::SelectOp::verify() {
36603712
return success();
36613713
}
36623714

3715+
LogicalResult tosa::VariableOp::verify() {
3716+
StringRef symName = getName();
3717+
FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
3718+
if (succeeded(varOp))
3719+
return emitOpError("illegal to have multiple declaration of '")
3720+
<< symName << "'";
3721+
3722+
return success();
3723+
}
3724+
3725+
LogicalResult tosa::VariableReadOp::verify() {
3726+
if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
3727+
.failed())
3728+
return failure();
3729+
3730+
return success();
3731+
}
3732+
3733+
LogicalResult tosa::VariableWriteOp::verify() {
3734+
if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
3735+
.failed())
3736+
return failure();
3737+
3738+
return success();
3739+
}
3740+
36633741
// parse and print of WhileOp refer to the implementation of SCF dialect.
36643742
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
36653743
SmallVector<OpAsmParser::Argument, 4> regionArgs;

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,6 @@
11
// RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics
22

33

4-
// -----
5-
6-
// check that -tosa-validate of stateful ops kick in
7-
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
8-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
9-
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
10-
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
11-
return
12-
}
13-
144
// -----
155

166
// check that -tosa-validate level checking kick in

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
566566

567567
func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
568568
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
569-
// expected-error@+1 {{'tosa.variable' op name has already been declared}}
569+
// expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
570570
tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8>
571571
return
572572
}
@@ -575,7 +575,7 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
575575

576576
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
577577
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
578-
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
578+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
579579
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
580580
return
581581
}
@@ -584,7 +584,7 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
584584

585585
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
586586
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
587-
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
587+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
588588
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
589589
return
590590
}
@@ -593,7 +593,7 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
593593

594594
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
595595
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
596-
// expected-error@+1 {{'tosa.variable_write' op illegal: operand/result data types not supported}}
596+
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
597597
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
598598
return
599599
}
@@ -602,7 +602,7 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
602602

603603
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
604604
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
605-
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
605+
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
606606
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
607607
return
608608
}

mlir/test/Dialect/Tosa/variables.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
2-
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
1+
// RUN: mlir-opt %s --split-input-file | mlir-opt | FileCheck %s
2+
// RUN: mlir-opt %s --split-input-file --mlir-print-op-generic | mlir-opt | FileCheck %s
33

44

55
// -----

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,3 +785,82 @@ func.func @test_while_loop_cond_output_not_bool(%arg0: tensor<10xi32>, %arg1: te
785785
}
786786
return
787787
}
788+
789+
// -----
790+
791+
func.func @test_variable_multiple_declaration() -> () {
792+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
793+
// expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
794+
tosa.variable @stored_var = dense<-3> : tensor<2x4x8xi32>
795+
return
796+
}
797+
798+
// -----
799+
800+
func.func @test_variable_shape_mismatch() -> () {
801+
// expected-error@+1 {{inferred shape of elements literal ([2]) does not match type ([3])}}
802+
tosa.variable @stored_var = dense<[3.14, 2.14]> : tensor<3xf32>
803+
// expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
804+
return
805+
}
806+
807+
// -----
808+
809+
func.func @test_variable_type_mismatch() -> () {
810+
// expected-error@+1 {{expected integer elements, but parsed floating-point}}
811+
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xi32>
812+
// expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
813+
return
814+
}
815+
816+
// -----
817+
818+
func.func @test_variable_read_no_declaration() -> () {
819+
// expected-error@+1 {{'tosa.variable_read' op 'stored_var' has not been declared by 'tosa.variable'}}
820+
%0 = tosa.variable_read @stored_var : tensor<f32>
821+
return
822+
}
823+
824+
// -----
825+
826+
func.func @test_variable_read_type_mismatch() -> () {
827+
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
828+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}}
829+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
830+
return
831+
}
832+
833+
// -----
834+
835+
func.func @test_variable_read_shape_mismatch() -> () {
836+
tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
837+
// expected-error@+1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
838+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xf32>
839+
return
840+
}
841+
842+
// -----
843+
844+
func.func @test_variable_write_no_declaration(%arg0: tensor<f32>) -> () {
845+
// expected-error@+1 {{'tosa.variable_write' op 'stored_var' has not been declared by 'tosa.variable'}}
846+
tosa.variable_write @stored_var, %arg0 : tensor<f32>
847+
return
848+
}
849+
850+
// -----
851+
852+
func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () {
853+
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
854+
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}}
855+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
856+
return
857+
}
858+
859+
// -----
860+
861+
func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
862+
tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
863+
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
864+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
865+
return
866+
}

0 commit comments

Comments
 (0)