Commit d72e217
[TOSA] Add SameOperandsAndResultRank to TOSA Ops
This patch adds the SameOperandsAndResultRank trait to TOSA operators that carry the ResultsBroadcastableShape trait. SameOperandsAndResultRank requires that all operands and results have matching ranks unless an operand or result is unranked. This renders the TosaMakeBroadcastable pass unnecessary, but the pass is left in for now in case it is still used in some flows; its lit test, broadcast.mlir, is removed. The patch also adds verification of the SameOperandsAndResultRank trait in the TosaInferShapes pass to validate inferred shapes.

Signed-off-by: Tai Ly <[email protected]>
Change-Id: I27bf16b31f15aa92d42ad5376b8791cf74e4f6ac
1 parent 2b6b7f6 commit d72e217
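As an illustration of what the new trait enforces, here is a minimal sketch (not part of the committed diff; the function and value names are made up, and the diagnostic text is the one used in the updated tosa-to-linalg test below):

func.func @rank_rules(%a: tensor<3x4xf32>, %b: tensor<2x3x4xf32>,
                      %c: tensor<4xi32>, %d: tensor<4xi32>) {
  // Rejected: ranked operands with ranks 2 and 3 no longer verify.
  // expected-error@+1 {{'tosa.add' op operands don't have matching ranks}}
  %0 = tosa.add %a, %b : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>

  // Still accepted: the unranked result is exempt from the rank check.
  %1 = tosa.equal %c, %d : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi1>
  return
}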

7 files changed: 97 additions, 365 deletions

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 1 addition & 0 deletions
@@ -228,6 +228,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
               DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                                         ["inferReturnTypeComponents"]>,
               ResultsBroadcastableShape,
+              SameOperandsAndResultRank,
               TosaElementwiseOperator,
               Pure])> {
   let assemblyFormat =

mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp

Lines changed: 28 additions & 0 deletions
@@ -303,6 +303,32 @@ void propagateShapesInRegion(Region &region, TypeModificationState &state) {
   }
 }
 
+/// recursively validate tosa ops with SameOperandsAndResultRank trait in region
+/// and all nested regions
+void validateSameOperandsAndResultRankTrait(Region &region) {
+  int errs = 0;
+  for (auto &block : region) {
+    for (auto &op : block) {
+      if (!op.getDialect() ||
+          op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+        continue;
+      if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) {
+        if (OpTrait::impl::verifySameOperandsAndResultRank(&op).failed()) {
+          errs++;
+        }
+      }
+      WhileOp whileOp = dyn_cast<WhileOp>(op);
+      IfOp ifOp = dyn_cast<IfOp>(op);
+      if (whileOp || ifOp) {
+        // recurse into whileOp's regions
+        for (auto &next : op.getRegions()) {
+          validateSameOperandsAndResultRankTrait(next);
+        }
+      }
+    }
+  }
+}
+
 /// Pass that performs shape propagation across TOSA operations. This includes
 /// migrating to within the regions of if/while operations.
 struct TosaInferShapes
@@ -313,6 +339,8 @@ struct TosaInferShapes
     TypeModificationState state;
     propagateShapesInRegion(func.getBody(), state);
     state.commit();
+
+    validateSameOperandsAndResultRankTrait(func.getBody());
   }
 };
 } // namespace

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 3 additions & 16 deletions
@@ -94,6 +94,7 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
   // CHECK: } -> tensor<f32>
   %0 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
 
+
   // CHECK: return [[RESULT]] : tensor<f32>
   return %0 : tensor<f32>
 }
@@ -341,23 +342,9 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
 
 // -----
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: @test_add_2d_different_ranks
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
 func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
-
-  // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [1, 3, 4] : tensor<3x4xf32> into tensor<1x3x4xf32>
-  // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32>
-  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) {
-  // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
-  // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
-  // CHECK: linalg.yield %[[VAL_4]] : f32
-  // CHECK: } -> tensor<2x3x4xf32>
-  %0 = tosa.add %arg0, %arg1 : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
-
-  // CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
+  // expected-error@+1 {{'tosa.add' op operands don't have matching ranks}}
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
   return %0 : tensor<2x3x4xf32>
 }
mlir/test/Dialect/Tosa/broadcast.mlir

Lines changed: 0 additions & 285 deletions
This file was deleted.

mlir/test/Dialect/Tosa/constant_folding.mlir

Lines changed: 2 additions & 2 deletions
@@ -15,9 +15,9 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
 }
 
 // CHECK-LABEL: func @try_fold_equal_with_unranked_tensor
-func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<i32>) {
+func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) {
   // CHECK: tosa.equal
   // CHECK-NEXT: return
-  %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+  %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
   return
 }

mlir/test/Dialect/Tosa/inlining.mlir

Lines changed: 2 additions & 1 deletion
@@ -47,7 +47,8 @@ func.func @inlined_while_fn(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tenso
 }
 func.func private @while_body_50(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) {
   %1 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  %2 = "tosa.add"(%arg3, %1) : (tensor<10xi32>, tensor<i32>) -> tensor<10xi32>
+  %3 = "tosa.reshape"(%1) {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
+  %2 = "tosa.add"(%arg3, %3) : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
   return %1, %arg1, %arg2, %2: tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>
 }
 func.func private @while_cond_40(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> tensor<i1> {
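For reference, a hedged sketch of the same rewrite at a higher rank (not part of this commit; @broadcast_rank3 is an illustrative name): the lower-rank operand is reshaped so that both tosa.add operands have rank 3, with leading 1 dimensions standing in for the broadcast dimensions. This is essentially what the TosaMakeBroadcastable pass used to insert automatically and what producers of TOSA now need to emit explicitly.

func.func @broadcast_rank3(%arg0: tensor<2x3x4xf32>, %arg1: tensor<4xf32>) -> tensor<2x3x4xf32> {
  // Reshape the rank-1 operand to rank 3 so the operand ranks match.
  %0 = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 4>} : (tensor<4xf32>) -> tensor<1x1x4xf32>
  // The broadcast over the size-1 dimensions is still handled by the elementwise op.
  %1 = tosa.add %arg0, %0 : (tensor<2x3x4xf32>, tensor<1x1x4xf32>) -> tensor<2x3x4xf32>
  return %1 : tensor<2x3x4xf32>
}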
