File tree 5 files changed +24
-7
lines changed
include/mlir/Dialect/Tosa/IR
5 files changed +24
-7
lines changed Original file line number Diff line number Diff line change @@ -48,6 +48,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
48
48
let results = (outs
49
49
Tosa_Tensor: $output
50
50
);
51
+
52
+ let hasVerifier = 1;
51
53
}
52
54
53
55
//===----------------------------------------------------------------------===//
Original file line number Diff line number Diff line change @@ -211,6 +211,21 @@ template <typename T> static LogicalResult verifyConvOp(T op) {
211
211
return success ();
212
212
}
213
213
214
+ LogicalResult tosa::ArgMaxOp::verify () {
215
+ // Ensure output is of 32-bit integer
216
+ const auto resultETy = llvm::cast<ShapedType>(getType ()).getElementType ();
217
+ if (!resultETy.isIntOrIndex ())
218
+ return emitOpError (" result tensor is not of integer type" );
219
+
220
+ // Ensure axis is within the tensor rank
221
+ const auto inputType = llvm::cast<ShapedType>(getInput ().getType ());
222
+ const int64_t axis = getAxisAttr ().getInt ();
223
+ if (inputType.hasRank () && ((axis < 0 ) || axis >= inputType.getRank ()))
224
+ return emitOpError (" specified axis is outside the rank of the tensor" );
225
+
226
+ return success ();
227
+ }
228
+
214
229
LogicalResult tosa::AvgPool2dOp::verify () {
215
230
auto inputType = llvm::cast<ShapedType>(getInput ().getType ());
216
231
if (hasZeroDimension (inputType))
Original file line number Diff line number Diff line change 1
1
// RUN: mlir-opt -canonicalize="test-convergence" %s | FileCheck %s
2
2
3
3
// CHECK-LABEL: @argmax_nofold
4
- func.func @argmax_nofold (%arg0: tensor <?x1 xf32 >) -> tensor <?x 1 x f32 > {
4
+ func.func @argmax_nofold (%arg0: tensor <?x1 xf32 >) -> tensor <?x 1 x i32 > {
5
5
// CHECK: tosa.argmax
6
- %0 = tosa.argmax %arg0 {axis = 0 : i32 }: (tensor <?x1 xf32 >) -> tensor <?x 1 x f32 >
7
- return %0 : tensor <?x 1 x f32 >
6
+ %0 = tosa.argmax %arg0 {axis = 0 : i32 }: (tensor <?x1 xf32 >) -> tensor <?x 1 x i32 >
7
+ return %0 : tensor <?x 1 x i32 >
8
8
}
9
9
10
10
// CHECK-LABEL: @add_bcast_zero_int
Original file line number Diff line number Diff line change 6
6
// Uses argmax as canonical example to validate constrained TOSA tensor shapes.
7
7
// CHECK-LABEL: argmax
8
8
func.func @test_argmax (%arg0: tensor <?xf32 >) -> tensor <?xi32 > {
9
- %0 = " tosa.argmax" (%arg0 ) {axis = 1 : i32 } : (tensor <?xf32 >) -> tensor <?xi32 >
9
+ %0 = " tosa.argmax" (%arg0 ) {axis = 0 : i32 } : (tensor <?xf32 >) -> tensor <?xi32 >
10
10
return %0 : tensor <?xi32 >
11
11
}
Original file line number Diff line number Diff line change 1
1
// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate
2
2
3
3
4
- func.func @test_argmax (%arg0: tensor <1 x1 x1 x1 x29 x29 x4 xf32 >) -> tensor <1 x 1 x 1 x 1 x 29 x 4 x f32 > {
4
+ func.func @test_argmax (%arg0: tensor <1 x1 x1 x1 x29 x29 x4 xf32 >) -> tensor <1 x 1 x 1 x 1 x 29 x 4 x i32 > {
5
5
// expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
6
- %0 = " tosa.argmax" (%arg0 ) {axis = 4 : i32 } : (tensor <1 x1 x1 x1 x29 x29 x4 xf32 >) -> tensor <1 x 1 x 1 x 1 x 29 x 4 x f32 >
7
- return %0 : tensor <1 x 1 x 1 x 1 x 29 x 4 x f32 >
6
+ %0 = " tosa.argmax" (%arg0 ) {axis = 4 : i32 } : (tensor <1 x1 x1 x1 x29 x29 x4 xf32 >) -> tensor <1 x 1 x 1 x 1 x 29 x 4 x i32 >
7
+ return %0 : tensor <1 x 1 x 1 x 1 x 29 x 4 x i32 >
8
8
}
9
9
10
10
// -----
You can’t perform that action at this time.
0 commit comments