Skip to content

Commit 1b6299c

Browse files
committed
[mlir][tosa] Add error if checks Variable Operators
For VARIABLE, VARIABLE_WRITE & VARIABLE_READ
1 parent 9f2bcc7 commit 1b6299c

File tree

6 files changed

+195
-18
lines changed

6 files changed

+195
-18
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
106106
attr-dict
107107
custom<TypeOrAttr>($type, $initial_value)
108108
}];
109+
110+
let hasVerifier = 1;
109111
}
110112

111113
//===----------------------------------------------------------------------===//
@@ -131,6 +133,8 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
131133
let assemblyFormat = [{
132134
$name attr-dict `,` $input1 `:` type($input1)
133135
}];
136+
137+
let hasVerifier = 1;
134138
}
135139

136140
//===----------------------------------------------------------------------===//
@@ -159,6 +163,8 @@ def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
159163
let assemblyFormat = [{
160164
$name attr-dict `:` type($output1)
161165
}];
166+
167+
let hasVerifier = 1;
162168
}
163169

164170
#endif // TOSA_UTIL_OPS

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,82 @@ static LogicalResult verifyConvOpErrorIf(T op) {
562562
return success();
563563
}
564564

565+
// Verify whether same type and shape of the given two types.
566+
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
567+
StringRef name1, Type type2,
568+
StringRef name2) {
569+
auto shapeType1 = dyn_cast<ShapedType>(type1);
570+
auto shapeType2 = dyn_cast<ShapedType>(type2);
571+
if (!shapeType1 || !shapeType2)
572+
return failure();
573+
574+
auto elemType1 = shapeType1.getElementType();
575+
auto elemType2 = shapeType2.getElementType();
576+
if (elemType1 != elemType2)
577+
return op->emitOpError()
578+
<< "require same element type for " << name1 << " (" << elemType1
579+
<< ") and " << name2 << " (" << elemType2 << ")";
580+
581+
if (failed(verifyCompatibleShape(type1, type2)))
582+
return op->emitOpError()
583+
<< "require same shapes for " << name1 << " (" << type1 << ") and "
584+
<< name2 << " (" << type2 << ")";
585+
586+
return success();
587+
}
588+
589+
// Returns the first declaration point prior to this operation or failure if
590+
// not found.
591+
static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
592+
StringRef symName) {
593+
ModuleOp module = op->getParentOfType<ModuleOp>();
594+
tosa::VariableOp varOp = nullptr;
595+
596+
// TODO: Adopt SymbolTable trait to Varible ops.
597+
// Currently, the variable's definition point is searched via walk(),
598+
// starting from the top-level ModuleOp and stopping at the point of use. Once
599+
// TOSA control flow and variable extensions reach the complete state, may
600+
// leverage MLIR's Symbol Table functionality to look up symbol and enhance
601+
// the search to a TOSA specific graph traversal over the IR structure.
602+
module.walk([&](Operation *tempOp) {
603+
// Reach this op itself.
604+
if (tempOp == op) {
605+
return WalkResult::interrupt();
606+
}
607+
608+
if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
609+
if (symName == tosaOp.getName()) {
610+
varOp = tosaOp;
611+
return WalkResult::interrupt();
612+
}
613+
}
614+
615+
return WalkResult::advance();
616+
});
617+
618+
if (varOp)
619+
return varOp;
620+
621+
return failure();
622+
}
623+
624+
template <typename T>
625+
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
626+
StringRef symName = op.getName();
627+
FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
628+
if (failed(varOp))
629+
return op->emitOpError("'")
630+
<< symName << "' has not been declared by 'tosa.variable'";
631+
632+
// Verify type and shape
633+
Type varType = cast<tosa::VariableOp>(varOp.value()).getType();
634+
if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
635+
.failed())
636+
return failure();
637+
638+
return success();
639+
}
640+
565641
// verify that inType and outType have same element types
566642
template <typename T>
567643
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3531,6 +3607,32 @@ LogicalResult tosa::SelectOp::verify() {
35313607
return success();
35323608
}
35333609

3610+
LogicalResult tosa::VariableOp::verify() {
3611+
StringRef symName = getName();
3612+
FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
3613+
if (succeeded(varOp))
3614+
return emitOpError("illegal to have multiple declaration of '")
3615+
<< symName << "'";
3616+
3617+
return success();
3618+
}
3619+
3620+
LogicalResult tosa::VariableReadOp::verify() {
3621+
if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
3622+
.failed())
3623+
return failure();
3624+
3625+
return success();
3626+
}
3627+
3628+
LogicalResult tosa::VariableWriteOp::verify() {
3629+
if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
3630+
.failed())
3631+
return failure();
3632+
3633+
return success();
3634+
}
3635+
35343636
// parse and print of WhileOp refer to the implementation of SCF dialect.
35353637
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
35363638
SmallVector<OpAsmParser::Argument, 4> regionArgs;

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,6 @@
11
// RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics
22

33

4-
// -----
5-
6-
// check that -tosa-validate of stateful ops kick in
7-
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
8-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
9-
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
10-
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
11-
return
12-
}
13-
144
// -----
155

166
// check that -tosa-validate level checking kick in

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
566566

567567
func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
568568
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
569-
// expected-error@+1 {{'tosa.variable' op name has already been declared}}
569+
// expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
570570
tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8>
571571
return
572572
}
@@ -575,7 +575,7 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
575575

576576
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
577577
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
578-
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
578+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
579579
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
580580
return
581581
}
@@ -584,7 +584,7 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
584584

585585
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
586586
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
587-
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
587+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
588588
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
589589
return
590590
}
@@ -593,7 +593,7 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
593593

594594
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
595595
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
596-
// expected-error@+1 {{'tosa.variable_write' op illegal: operand/result data types not supported}}
596+
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
597597
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
598598
return
599599
}
@@ -602,7 +602,7 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
602602

603603
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
604604
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
605-
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
605+
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
606606
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
607607
return
608608
}

mlir/test/Dialect/Tosa/variables.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
2-
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
1+
// RUN: mlir-opt %s --split-input-file | mlir-opt | FileCheck %s
2+
// RUN: mlir-opt %s --split-input-file --mlir-print-op-generic | mlir-opt | FileCheck %s
33

44

55
// -----

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,4 +436,83 @@ func.func @test_pad_invalid_padding_value(%arg0: tensor<10xi8>, %arg1: tensor<1x
436436
// expected-error@+1 {{invalid padding values at dimension 0: values must be non-negative or -1 for dynamic padding, got [-2, 2]}}
437437
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<10xi8>, !tosa.shape<2>, tensor<1xi8>) -> tensor<10xi8>
438438
return %1 : tensor<10xi8>
439-
}
439+
}
440+
441+
// -----
442+
443+
func.func @test_variable_multiple_declaration() -> () {
444+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
445+
// expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
446+
tosa.variable @stored_var = dense<-3> : tensor<2x4x8xi32>
447+
return
448+
}
449+
450+
// -----
451+
452+
func.func @test_variable_shape_mismatch() -> () {
453+
// expected-error@+1 {{inferred shape of elements literal ([2]) does not match type ([3])}}
454+
tosa.variable @stored_var = dense<[3.14, 2.14]> : tensor<3xf32>
455+
// expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
456+
return
457+
}
458+
459+
// -----
460+
461+
func.func @test_variable_type_mismatch() -> () {
462+
// expected-error@+1 {{expected integer elements, but parsed floating-point}}
463+
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xi32>
464+
// expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
465+
return
466+
}
467+
468+
// -----
469+
470+
func.func @test_variable_read_no_declaration() -> () {
471+
// expected-error@+1 {{'tosa.variable_read' op 'stored_var' has not been declared by 'tosa.variable'}}
472+
%0 = tosa.variable_read @stored_var : tensor<f32>
473+
return
474+
}
475+
476+
// -----
477+
478+
func.func @test_variable_read_type_mismatch() -> () {
479+
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
480+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}}
481+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
482+
return
483+
}
484+
485+
// -----
486+
487+
func.func @test_variable_read_shape_mismatch() -> () {
488+
tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
489+
// expected-error@+1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
490+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xf32>
491+
return
492+
}
493+
494+
// -----
495+
496+
func.func @test_variable_write_no_declaration(%arg0: tensor<f32>) -> () {
497+
// expected-error@+1 {{'tosa.variable_write' op 'stored_var' has not been declared by 'tosa.variable'}}
498+
tosa.variable_write @stored_var, %arg0 : tensor<f32>
499+
return
500+
}
501+
502+
// -----
503+
504+
func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () {
505+
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
506+
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}}
507+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
508+
return
509+
}
510+
511+
// -----
512+
513+
func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
514+
tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
515+
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
516+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
517+
return
518+
}

0 commit comments

Comments
 (0)