Skip to content

Commit 69ebdcd

Browse files
committed
[mlir][tosa] Add verifier for ArgMax operator
Verifier ensures that operator is valid by checking: * Output type is of integer type * Axis is within the rank of the tensor Signed-off-by: Georgios Pinitas <[email protected]>
1 parent b935882 commit 69ebdcd

File tree

5 files changed

+24
-7
lines changed

5 files changed

+24
-7
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
4848
let results = (outs
4949
Tosa_Tensor: $output
5050
);
51+
52+
let hasVerifier = 1;
5153
}
5254

5355
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,21 @@ template <typename T> static LogicalResult verifyConvOp(T op) {
211211
return success();
212212
}
213213

214+
LogicalResult tosa::ArgMaxOp::verify() {
215+
// Ensure output is of 32-bit integer
216+
const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
217+
if (!resultETy.isIntOrIndex())
218+
return emitOpError("result tensor is not of integer type");
219+
220+
// Ensure axis is within the tensor rank
221+
const auto inputType = llvm::cast<ShapedType>(getInput().getType());
222+
const int64_t axis = getAxisAttr().getInt();
223+
if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
224+
return emitOpError("specified axis is outside the rank of the tensor");
225+
226+
return success();
227+
}
228+
214229
LogicalResult tosa::AvgPool2dOp::verify() {
215230
auto inputType = llvm::cast<ShapedType>(getInput().getType());
216231
if (hasZeroDimension(inputType))

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// RUN: mlir-opt -canonicalize="test-convergence" %s | FileCheck %s
22

33
// CHECK-LABEL: @argmax_nofold
4-
func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
4+
func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
55
// CHECK: tosa.argmax
6-
%0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
7-
return %0 : tensor<?x1xf32>
6+
%0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xi32>
7+
return %0 : tensor<?x1xi32>
88
}
99

1010
// CHECK-LABEL: @add_bcast_zero_int

mlir/test/Dialect/Tosa/constrained_shapes.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@
66
// Uses argmax as canonical example to validate constrained TOSA tensor shapes.
77
// CHECK-LABEL: argmax
88
func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
9-
%0 = "tosa.argmax"(%arg0) {axis = 1 : i32} : (tensor<?xf32>) -> tensor<?xi32>
9+
%0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<?xi32>
1010
return %0 : tensor<?xi32>
1111
}

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate
22

33

4-
func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32> {
4+
func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
55
// expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
6-
%0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32>
7-
return %0 : tensor<1x1x1x1x29x4xf32>
6+
%0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32>
7+
return %0 : tensor<1x1x1x1x29x4xi32>
88
}
99

1010
// -----

0 commit comments

Comments
 (0)