Skip to content

Commit 4f46b75

Browse files
authored
[mlir][tosa] Add expected output shape check to argmax verifier (#129870)
Fixes some test cases which incorrectly declared the output shape and added a negative test case. Signed-off-by: Luke Hutton <[email protected]>
1 parent 4554b30 commit 4f46b75

File tree

4 files changed

+35
-10
lines changed

4 files changed

+35
-10
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -454,17 +454,34 @@ static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
454454
}
455455

456456
LogicalResult tosa::ArgMaxOp::verify() {
457+
const ShapedType resultType = llvm::cast<ShapedType>(getType());
458+
457459
// Ensure output is of 32-bit integer
458-
const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
459-
if (!resultETy.isIntOrIndex())
460+
if (const auto resultETy = resultType.getElementType();
461+
!resultETy.isIntOrIndex())
460462
return emitOpError("result tensor is not of integer type");
461463

462-
// Ensure axis is within the tensor rank
463464
const auto inputType = llvm::cast<ShapedType>(getInput().getType());
465+
if (!inputType.hasRank())
466+
return success();
467+
468+
// Ensure axis is within the tensor rank
464469
const int64_t axis = getAxisAttr().getInt();
465-
if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
470+
if (((axis < 0) || axis >= inputType.getRank()))
466471
return emitOpError("specified axis is outside the rank of the tensor");
467472

473+
if (!resultType.hasRank())
474+
return success();
475+
476+
const ArrayRef<int64_t> inputShape = inputType.getShape();
477+
const ArrayRef<int64_t> outputShape = resultType.getShape();
478+
llvm::SmallVector<int64_t> expectedOutputShape(inputShape.begin(),
479+
inputShape.end());
480+
expectedOutputShape.erase(expectedOutputShape.begin() + axis);
481+
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
482+
return emitOpError("expected output shape '")
483+
<< expectedOutputShape << "', got '" << outputShape << "'";
484+
468485
return success();
469486
}
470487

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s
22

33
// CHECK-LABEL: @argmax_nofold
4-
func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
4+
func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {
55
// CHECK: tosa.argmax
6-
%0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xi32>
7-
return %0 : tensor<?x1xi32>
6+
%0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<1xi32>
7+
return %0 : tensor<1xi32>
88
}
99

1010
// -----

mlir/test/Dialect/Tosa/constrained_shapes.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// -----
66
// Uses argmax as canonical example to validate constrained TOSA tensor shapes.
77
// CHECK-LABEL: argmax
8-
func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
9-
%0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<?xi32>
10-
return %0 : tensor<?xi32>
8+
func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<i32> {
9+
%0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<i32>
10+
return %0 : tensor<i32>
1111
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,3 +1423,11 @@ func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (te
14231423
%0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
14241424
return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
14251425
}
1426+
1427+
// -----
1428+
1429+
func.func @test_argmax_invalid_output_shape(%arg0: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
1430+
// expected-error@+1 {{'tosa.argmax' op expected output shape '2, 3', got '1, 2, 3'}}
1431+
%0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<1x2x3xf32>) -> tensor<1x2x3xi32>
1432+
return %0 : tensor<1x2x3xi32>
1433+
}

0 commit comments

Comments
 (0)