Skip to content

Commit 498121e

Browse files
authored
[mlir][tosa] Allow unranked indices argument for gather/scatter (#140618)
This commit allows the indices argument for gather and scatter to be unranked. This can be computed during shape inference.
1 parent 8416bac commit 498121e

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2125,7 +2125,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
21252125

21262126
let arguments = (ins
21272127
Tosa_Tensor3D:$values,
2128-
TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
2128+
Tosa_Int32Tensor2D:$indices
21292129
);
21302130

21312131
let results = (outs
@@ -2159,7 +2159,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
21592159

21602160
let arguments = (ins
21612161
Tosa_Tensor3D:$values_in,
2162-
TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
2162+
Tosa_Int32Tensor2D:$indices,
21632163
Tosa_Tensor3D:$input
21642164
);
21652165

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ def Tosa_TensorUpto4D : AnyTypeOf<[
181181

182182
def Tosa_Int32TensorUpto4D : AnyTypeOf<[
183183
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
184+
def Tosa_Int32Tensor2D : AnyTypeOf<[
185+
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [2]>]>;
184186

185187
def Tosa_TensorAtLeast1D : AnyTypeOf<[
186188
Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,20 @@ func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %a
719719
return %0 : tensor<13x21x3xf32>
720720
}
721721

722+
// -----
723+
// CHECK-LABEL: gather_unranked_indices
724+
func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>) -> tensor<13x26x3xf32> {
725+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<*xi32>) -> tensor<13x26x3xf32>
726+
return %0 : tensor<13x26x3xf32>
727+
}
728+
729+
// -----
730+
// CHECK-LABEL: scatter_unranked_indices
731+
func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
732+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
733+
return %0 : tensor<13x21x3xf32>
734+
}
735+
722736
// -----
723737
// CHECK-LABEL: resize
724738
func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {

0 commit comments

Comments
 (0)