Skip to content

[mlir][tosa] Add error if checks Variable Operators #137291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
attr-dict
custom<TypeOrAttr>($type, $initial_value)
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand All @@ -131,6 +133,8 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
let assemblyFormat = [{
$name attr-dict `,` $input1 `:` type($input1)
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -159,6 +163,8 @@ def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
let assemblyFormat = [{
$name attr-dict `:` type($output1)
}];

let hasVerifier = 1;
}

#endif // TOSA_UTIL_OPS
78 changes: 78 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,58 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
return shapeAdaptor.getNumElements() == 1 ? success() : failure();
}

// Returns the first declaration point prior to this operation or failure if
// not found.
static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
StringRef symName) {
ModuleOp module = op->getParentOfType<ModuleOp>();
tosa::VariableOp varOp = nullptr;

// TODO: Adopt SymbolTable trait to Varible ops.
// Currently, the variable's definition point is searched via walk(),
// starting from the top-level ModuleOp and stopping at the point of use. Once
// TOSA control flow and variable extensions reach the complete state, may
// leverage MLIR's Symbol Table functionality to look up symbol and enhance
// the search to a TOSA specific graph traversal over the IR structure.
module.walk([&](Operation *tempOp) {
// Reach this op itself.
if (tempOp == op) {
return WalkResult::interrupt();
}

if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
if (symName == tosaOp.getName()) {
varOp = tosaOp;
return WalkResult::interrupt();
}
}

return WalkResult::advance();
});

if (varOp)
return varOp;

return failure();
}

template <typename T>
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
StringRef symName = op.getName();
FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
if (failed(varOp))
return op->emitOpError("'")
<< symName << "' has not been declared by 'tosa.variable'";

// Verify type and shape
Type varType = cast<tosa::VariableOp>(varOp.value()).getType();
if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
.failed())
return failure();

return success();
}

// verify that inType and outType have same element types
template <typename T>
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
Expand Down Expand Up @@ -3660,6 +3712,32 @@ LogicalResult tosa::SelectOp::verify() {
return success();
}

LogicalResult tosa::VariableOp::verify() {
StringRef symName = getName();
FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
if (succeeded(varOp))
return emitOpError("illegal to have multiple declaration of '")
<< symName << "'";

return success();
}

LogicalResult tosa::VariableReadOp::verify() {
if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
.failed())
return failure();

return success();
}

LogicalResult tosa::VariableWriteOp::verify() {
if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
.failed())
return failure();

return success();
}

// parse and print of WhileOp refer to the implementation of SCF dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::Argument, 4> regionArgs;
Expand Down
10 changes: 0 additions & 10 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
// RUN: mlir-opt %s --split-input-file --tosa-to-linalg-pipeline -verify-diagnostics


// -----

// check that -tosa-validate of stateful ops kick in
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
return
}

// -----

// check that -tosa-validate level checking kick in
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten

func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable' op name has already been declared}}
// expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8>
return
}
Expand All @@ -575,7 +575,7 @@ func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {

func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
return
}
Expand All @@ -584,7 +584,7 @@ func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {

func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable_read' op illegal: operand/result data types not supported}}
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
%0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
return
}
Expand All @@ -593,7 +593,7 @@ func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {

func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable_write' op illegal: operand/result data types not supported}}
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
return
}
Expand All @@ -602,7 +602,7 @@ func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {

func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
// expected-error@+1 {{'tosa.variable_write' op operand type does not equal variable type}}
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
return
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Tosa/variables.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
// RUN: mlir-opt %s --split-input-file | mlir-opt | FileCheck %s
// RUN: mlir-opt %s --split-input-file --mlir-print-op-generic | mlir-opt | FileCheck %s


// -----
Expand Down
79 changes: 79 additions & 0 deletions mlir/test/Dialect/Tosa/verifier.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -785,3 +785,82 @@ func.func @test_while_loop_cond_output_not_bool(%arg0: tensor<10xi32>, %arg1: te
}
return
}

// -----

func.func @test_variable_multiple_declaration() -> () {
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
// expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
tosa.variable @stored_var = dense<-3> : tensor<2x4x8xi32>
return
}

// -----

func.func @test_variable_shape_mismatch() -> () {
// expected-error@+1 {{inferred shape of elements literal ([2]) does not match type ([3])}}
tosa.variable @stored_var = dense<[3.14, 2.14]> : tensor<3xf32>
// expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
return
}

// -----

func.func @test_variable_type_mismatch() -> () {
// expected-error@+1 {{expected integer elements, but parsed floating-point}}
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xi32>
// expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
return
}

// -----

func.func @test_variable_read_no_declaration() -> () {
// expected-error@+1 {{'tosa.variable_read' op 'stored_var' has not been declared by 'tosa.variable'}}
%0 = tosa.variable_read @stored_var : tensor<f32>
return
}

// -----

func.func @test_variable_read_type_mismatch() -> () {
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
// expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}}
%0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
return
}

// -----

func.func @test_variable_read_shape_mismatch() -> () {
tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
// expected-error@+1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
%0 = tosa.variable_read @stored_var : tensor<2x4x8xf32>
return
}

// -----

func.func @test_variable_write_no_declaration(%arg0: tensor<f32>) -> () {
// expected-error@+1 {{'tosa.variable_write' op 'stored_var' has not been declared by 'tosa.variable'}}
tosa.variable_write @stored_var, %arg0 : tensor<f32>
return
}

// -----

func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () {
tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32>
// expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}}
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32>
return
}

// -----

func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32>
// expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}}
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
return
}
Loading