// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
+ // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
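+ // The HWCF-prefixed RUN line exercises the prefer-conv2d-kernel-layout-hwcf
+ // option: conv2d weights are transposed from FHWC to HWCF layout and the
+ // convolution lowers to linalg.conv_2d_nhwc_hwcf(_q) instead of the default
+ // linalg.conv_2d_nhwc_fhwc(_q).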
// CHECK-LABEL: @matmul
func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -363,11 +364,14 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
// CHECK-LABEL: @conv2d_i8
func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
+ // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+ // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
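+ // The permutation [1, 2, 3, 0] rearranges the FHWC kernel (28x1x1x27) into
+ // HWCF layout (1x1x27x28) so it can feed linalg.conv_2d_nhwc_hwcf_q below.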
// CHECK: %[[M_IN:.+]] = tensor.empty()
// CHECK: %[[CST:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[B_IN:.+]] = tensor.empty()
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+ // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
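+ // The quantized variant still passes the input/weight zero points as extra
+ // i32 operands of linalg.conv_2d_nhwc_hwcf_q; only the kernel layout and the
+ // op name change relative to the FHWC lowering above.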
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
// CHECK: arith.extsi
// CHECK: arith.addi
@@ -383,11 +387,14 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
// CHECK-LABEL: @conv2d_f32
func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
+ // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+ // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
// CHECK: %[[M_IN:.+]] = tensor.empty()
// CHECK: %[[CST:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[B_IN:.+]] = tensor.empty()
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+ // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
// CHECK: arith.addf
// CHECK: linalg.yield