Skip to content

Commit e3a3d4e

Browse files
committed
[mlir][tosa] Add error if checks Variable Operators
For VARIABLE, VARIABLE_WRITE & VARIABLE_READ
1 parent 5eca2dd commit e3a3d4e

File tree

6 files changed

+184
-16
lines changed

6 files changed

+184
-16
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
131131
let assemblyFormat = [{
132132
$name attr-dict `,` $input1 `:` type($input1)
133133
}];
134+
135+
let hasVerifier = 1;
134136
}
135137

136138
//===----------------------------------------------------------------------===//
@@ -159,6 +161,8 @@ def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
159161
let assemblyFormat = [{
160162
$name attr-dict `:` type($output1)
161163
}];
164+
165+
let hasVerifier = 1;
162166
}
163167

164168
#endif // TOSA_UTIL_OPS

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,74 @@ static LogicalResult verifyConvOpErrorIf(T op) {
572572
return success();
573573
}
574574

575+
// Verify whether same type and shape of the given two types.
576+
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
577+
StringRef name1, Type type2,
578+
StringRef name2) {
579+
auto shapeType1 = dyn_cast<ShapedType>(type1);
580+
auto shapeType2 = dyn_cast<ShapedType>(type2);
581+
if (!shapeType1 || !shapeType2)
582+
return failure();
583+
584+
auto elemType1 = shapeType1.getElementType();
585+
auto elemType2 = shapeType2.getElementType();
586+
if (elemType1 != elemType2)
587+
return op->emitOpError()
588+
<< "require same element type for " << name1 << " (" << elemType1
589+
<< ") and " << name2 << " (" << elemType2 << ")";
590+
591+
if (failed(verifyCompatibleShape(type1, type2)))
592+
return op->emitOpError()
593+
<< "require same shapes for " << name1 << " (" << type1 << ") and "
594+
<< name2 << " (" << type2 << ")";
595+
596+
return success();
597+
}
598+
599+
template <typename T>
600+
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
601+
// Currently, the variable's definition point is searched via walk(),
602+
// starting from the top-level ModuleOp and stopping at the point of use. Once
603+
// TOSA control flow and variable extensions reach the complete state, may
604+
// leverage MLIR's Symbol Table functionality to look up symbol and enhance
605+
// the search to a TOSA specific graph traversal over the IR structure.
606+
StringRef symName = op.getName();
607+
tosa::VariableOp varOp = nullptr;
608+
auto thisOp = op.getOperation();
609+
ModuleOp module = thisOp->template getParentOfType<ModuleOp>();
610+
bool found = false;
611+
612+
module.walk([&](Operation *tempOp) {
613+
// Reach this op itself.
614+
if (tempOp == thisOp)
615+
return;
616+
617+
if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
618+
if (symName == tosaOp.getName()) {
619+
if (found == true) {
620+
op->emitOpError("illegal to have multiple declaration of '")
621+
<< symName << "'";
622+
return;
623+
}
624+
found = true;
625+
varOp = tosaOp;
626+
}
627+
}
628+
});
629+
630+
if (found == false)
631+
return op->emitOpError("'")
632+
<< symName << "' has not been declared by 'tosa.variable'";
633+
634+
// Verify type and shape
635+
Type varType = cast<tosa::VariableOp>(varOp).getType();
636+
if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
637+
.failed())
638+
return failure();
639+
640+
return success();
641+
}
642+
575643
// verify that inType and outType have same element types
576644
template <typename T>
577645
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3455,6 +3523,22 @@ LogicalResult tosa::SelectOp::verify() {
34553523
return success();
34563524
}
34573525

3526+
LogicalResult tosa::VariableReadOp::verify() {
3527+
if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
3528+
.failed())
3529+
return failure();
3530+
3531+
return success();
3532+
}
3533+
3534+
LogicalResult tosa::VariableWriteOp::verify() {
3535+
if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
3536+
.failed())
3537+
return failure();
3538+
3539+
return success();
3540+
}
3541+
34583542
// parse and print of WhileOp refer to the implementation of SCF dialect.
34593543
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
34603544
SmallVector<OpAsmParser::Argument, 4> regionArgs;

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,6 @@
11
// RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics
22

33

4-
// -----
5-
6-
// check that -tosa-validate of stateful ops kick in
7-
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
8-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
9-
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
10-
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
11-
return
12-
}
13-
144
// -----
155

166
// check that -tosa-validate level checking kick in

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
595595

596596
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
597597
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
598-
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
598+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
599599
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
600600
return
601601
}
@@ -604,7 +604,7 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
604604

605605
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
606606
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
607-
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
607+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
608608
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
609609
return
610610
}
@@ -613,7 +613,7 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
613613

614614
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
615615
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
616-
// expected-error@+1 {{'tosa.variable_write' op illegal: operand/result data types not supported}}
616+
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
617617
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
618618
return
619619
}
@@ -622,7 +622,7 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
622622

623623
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
624624
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
625-
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
625+
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
626626
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
627627
return
628628
}

mlir/test/Dialect/Tosa/variables.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
2-
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
1+
// RUN: mlir-opt %s --split-input-file | mlir-opt | FileCheck %s
2+
// RUN: mlir-opt %s --split-input-file --mlir-print-op-generic | mlir-opt | FileCheck %s
33

44

55
// -----

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,3 +358,93 @@ func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?x
358358
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
359359
return %0 : tensor<2x?xf32>
360360
}
361+
362+
// -----
363+
364+
func.func @test_variable_shape_mismatch() -> () {
365+
// expected-error@+1 {{inferred shape of elements literal ([2]) does not match type ([3])}}
366+
tosa.variable @stored_var = dense<[3.14, 2.14]> : tensor<3xf32>
367+
// expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
368+
return
369+
}
370+
371+
// -----
372+
373+
func.func @test_variable_type_mismatch() -> () {
374+
// expected-error@+1 {{expected integer elements, but parsed floating-point}}
375+
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xi32>
376+
// expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
377+
return
378+
}
379+
380+
// -----
381+
382+
func.func @test_variable_read_no_declaration() -> () {
383+
// expected-error@+1 {{'tosa.variable_read' op 'stored_var' has not been declared by 'tosa.variable'}}
384+
%0 = tosa.variable_read @stored_var : tensor<f32>
385+
return
386+
}
387+
388+
// -----
389+
390+
func.func @test_variable_read_multiple_declaration() -> () {
391+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
392+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
393+
// expected-error@+1 {{'tosa.variable_read' op illegal to have multiple declaration of 'stored_var'}}
394+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
395+
return
396+
}
397+
398+
// -----
399+
400+
func.func @test_variable_read_type_mismatch() -> () {
401+
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
402+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}}
403+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
404+
return
405+
}
406+
407+
// -----
408+
409+
func.func @test_variable_read_shape_mismatch() -> () {
410+
tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
411+
// expected-error@+1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
412+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xf32>
413+
return
414+
}
415+
416+
// -----
417+
418+
func.func @test_variable_write_no_declaration(%arg0: tensor<f32>) -> () {
419+
// expected-error@+1 {{'tosa.variable_write' op 'stored_var' has not been declared by 'tosa.variable'}}
420+
tosa.variable_write @stored_var, %arg0 : tensor<f32>
421+
return
422+
}
423+
424+
// -----
425+
426+
func.func @test_variable_write_multiple_declaration(%arg0: tensor<2x4x8xi32>) -> () {
427+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
428+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
429+
// expected-error@+1 {{'tosa.variable_write' op illegal to have multiple declaration of 'stored_var'}}
430+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
431+
return
432+
}
433+
434+
// -----
435+
436+
func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () {
437+
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
438+
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}}
439+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
440+
return
441+
}
442+
443+
// -----
444+
445+
func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
446+
tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
447+
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
448+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
449+
return
450+
}

0 commit comments

Comments
 (0)