Skip to content

Commit 37d9d8f

Browse files
committed
[mlir][tosa] Add error if checks Variable Operators
For VARIABLE, VARIABLE_WRITE & VARIABLE_READ
1 parent 6957699 commit 37d9d8f

File tree

6 files changed

+184
-16
lines changed

6 files changed

+184
-16
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
131131
let assemblyFormat = [{
132132
$name attr-dict `,` $input1 `:` type($input1)
133133
}];
134+
135+
let hasVerifier = 1;
134136
}
135137

136138
//===----------------------------------------------------------------------===//
@@ -159,6 +161,8 @@ def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
159161
let assemblyFormat = [{
160162
$name attr-dict `:` type($output1)
161163
}];
164+
165+
let hasVerifier = 1;
162166
}
163167

164168
#endif // TOSA_UTIL_OPS

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,74 @@ static LogicalResult verifyConvOpErrorIf(T op) {
562562
return success();
563563
}
564564

565+
// Verify whether same type and shape of the given two types.
566+
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
567+
StringRef name1, Type type2,
568+
StringRef name2) {
569+
auto shapeType1 = dyn_cast<ShapedType>(type1);
570+
auto shapeType2 = dyn_cast<ShapedType>(type2);
571+
if (!shapeType1 || !shapeType2)
572+
return failure();
573+
574+
auto elemType1 = shapeType1.getElementType();
575+
auto elemType2 = shapeType2.getElementType();
576+
if (elemType1 != elemType2)
577+
return op->emitOpError()
578+
<< "require same element type for " << name1 << " (" << elemType1
579+
<< ") and " << name2 << " (" << elemType2 << ")";
580+
581+
if (failed(verifyCompatibleShape(type1, type2)))
582+
return op->emitOpError()
583+
<< "require same shapes for " << name1 << " (" << type1 << ") and "
584+
<< name2 << " (" << type2 << ")";
585+
586+
return success();
587+
}
588+
589+
template <typename T>
590+
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
591+
// Currently, the variable's definition point is searched via walk(),
592+
// starting from the top-level ModuleOp and stopping at the point of use. Once
593+
// TOSA control flow and variable extensions reach the complete state, may
594+
// leverage MLIR's Symbol Table functionality to look up symbol and enhance
595+
// the search to a TOSA specific graph traversal over the IR structure.
596+
StringRef symName = op.getName();
597+
tosa::VariableOp varOp = nullptr;
598+
auto thisOp = op.getOperation();
599+
ModuleOp module = thisOp->template getParentOfType<ModuleOp>();
600+
bool found = false;
601+
602+
module.walk([&](Operation *tempOp) {
603+
// Reach this op itself.
604+
if (tempOp == thisOp)
605+
return;
606+
607+
if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
608+
if (symName == tosaOp.getName()) {
609+
if (found == true) {
610+
op->emitOpError("illegal to have multiple declaration of '")
611+
<< symName << "'";
612+
return;
613+
}
614+
found = true;
615+
varOp = tosaOp;
616+
}
617+
}
618+
});
619+
620+
if (found == false)
621+
return op->emitOpError("'")
622+
<< symName << "' has not been declared by 'tosa.variable'";
623+
624+
// Verify type and shape
625+
Type varType = cast<tosa::VariableOp>(varOp).getType();
626+
if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
627+
.failed())
628+
return failure();
629+
630+
return success();
631+
}
632+
565633
// verify that inType and outType have same element types
566634
template <typename T>
567635
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3495,6 +3563,22 @@ LogicalResult tosa::SelectOp::verify() {
34953563
return success();
34963564
}
34973565

3566+
LogicalResult tosa::VariableReadOp::verify() {
3567+
if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
3568+
.failed())
3569+
return failure();
3570+
3571+
return success();
3572+
}
3573+
3574+
LogicalResult tosa::VariableWriteOp::verify() {
3575+
if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
3576+
.failed())
3577+
return failure();
3578+
3579+
return success();
3580+
}
3581+
34983582
// parse and print of WhileOp refer to the implementation of SCF dialect.
34993583
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
35003584
SmallVector<OpAsmParser::Argument, 4> regionArgs;

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,6 @@
11
// RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics
22

33

4-
// -----
5-
6-
// check that -tosa-validate of stateful ops kick in
7-
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
8-
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
9-
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
10-
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
11-
return
12-
}
13-
144
// -----
155

166
// check that -tosa-validate level checking kick in

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
585585

586586
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
587587
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
588-
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
588+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
589589
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
590590
return
591591
}
@@ -594,7 +594,7 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
594594

595595
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
596596
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
597-
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
597+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
598598
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
599599
return
600600
}
@@ -603,7 +603,7 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
603603

604604
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
605605
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
606-
// expected-error@+1 {{'tosa.variable_write' op illegal: operand/result data types not supported}}
606+
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
607607
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
608608
return
609609
}
@@ -612,7 +612,7 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
612612

613613
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
614614
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
615-
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
615+
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
616616
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
617617
return
618618
}

mlir/test/Dialect/Tosa/variables.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
2-
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
1+
// RUN: mlir-opt %s --split-input-file | mlir-opt | FileCheck %s
2+
// RUN: mlir-opt %s --split-input-file --mlir-print-op-generic | mlir-opt | FileCheck %s
33

44

55
// -----

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,3 +403,93 @@ func.func @test_gather_invalid_out_C(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1
403403
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x8xf32>
404404
return %0 : tensor<13x26x8xf32>
405405
}
406+
407+
// -----
408+
409+
func.func @test_variable_shape_mismatch() -> () {
410+
// expected-error@+1 {{inferred shape of elements literal ([2]) does not match type ([3])}}
411+
tosa.variable @stored_var = dense<[3.14, 2.14]> : tensor<3xf32>
412+
// expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
413+
return
414+
}
415+
416+
// -----
417+
418+
func.func @test_variable_type_mismatch() -> () {
419+
// expected-error@+1 {{expected integer elements, but parsed floating-point}}
420+
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xi32>
421+
// expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
422+
return
423+
}
424+
425+
// -----
426+
427+
func.func @test_variable_read_no_declaration() -> () {
428+
// expected-error@+1 {{'tosa.variable_read' op 'stored_var' has not been declared by 'tosa.variable'}}
429+
%0 = tosa.variable_read @stored_var : tensor<f32>
430+
return
431+
}
432+
433+
// -----
434+
435+
func.func @test_variable_read_multiple_declaration() -> () {
436+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
437+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
438+
// expected-error@+1 {{'tosa.variable_read' op illegal to have multiple declaration of 'stored_var'}}
439+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
440+
return
441+
}
442+
443+
// -----
444+
445+
func.func @test_variable_read_type_mismatch() -> () {
446+
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
447+
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}}
448+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
449+
return
450+
}
451+
452+
// -----
453+
454+
func.func @test_variable_read_shape_mismatch() -> () {
455+
tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
456+
// expected-error@+1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
457+
%0 = tosa.variable_read @stored_var : tensor<2x4x8xf32>
458+
return
459+
}
460+
461+
// -----
462+
463+
func.func @test_variable_write_no_declaration(%arg0: tensor<f32>) -> () {
464+
// expected-error@+1 {{'tosa.variable_write' op 'stored_var' has not been declared by 'tosa.variable'}}
465+
tosa.variable_write @stored_var, %arg0 : tensor<f32>
466+
return
467+
}
468+
469+
// -----
470+
471+
func.func @test_variable_write_multiple_declaration(%arg0: tensor<2x4x8xi32>) -> () {
472+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
473+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
474+
// expected-error@+1 {{'tosa.variable_write' op illegal to have multiple declaration of 'stored_var'}}
475+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
476+
return
477+
}
478+
479+
// -----
480+
481+
func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () {
482+
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
483+
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}}
484+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
485+
return
486+
}
487+
488+
// -----
489+
490+
func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
491+
tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
492+
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
493+
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
494+
return
495+
}

0 commit comments

Comments
 (0)