@@ -363,11 +363,13 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
363
363
364
364
// CHECK-LABEL: @conv2d_i8
365
365
func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
366
+ // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
367
+ // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
366
368
// CHECK: %[[M_IN:.+]] = tensor.empty()
367
369
// CHECK: %[[CST:.+]] = arith.constant 0
368
370
// CHECK: %[[FILL:.+]] = linalg.fill
369
371
// CHECK: %[[B_IN:.+]] = tensor.empty()
370
- // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 , %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8 >, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
372
+ // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] , %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8 >, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
371
373
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
372
374
// CHECK: arith.extsi
373
375
// CHECK: arith.addi
@@ -383,11 +385,13 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
383
385
384
386
// CHECK-LABEL: @conv2d_f32
385
387
func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
388
+ // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
389
+ // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
386
390
// CHECK: %[[M_IN:.+]] = tensor.empty()
387
391
// CHECK: %[[CST:.+]] = arith.constant 0
388
392
// CHECK: %[[FILL:.+]] = linalg.fill
389
393
// CHECK: %[[B_IN:.+]] = tensor.empty()
390
- // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32 >) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
394
+ // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32 >) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
391
395
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
392
396
// CHECK: arith.addf
393
397
// CHECK: linalg.yield
@@ -404,11 +408,13 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
404
408
func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
405
409
// CHECK: %[[C0:.+]] = arith.constant 0
406
410
// CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
411
+ // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
412
+ // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
407
413
// CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]])
408
414
// CHECK: %[[CST:.+]] = arith.constant 0
409
415
// CHECK: %[[FILL:.+]] = linalg.fill
410
416
// CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]])
411
- // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32 >) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
417
+ // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<?x49x42x27xf32>, tensor<3x3x27x28xf32 >) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
412
418
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>)
413
419
// CHECK: %[[ADD:.+]] = arith.addf
414
420
// CHECK: linalg.yield %[[ADD]] : f32
@@ -462,11 +468,13 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
462
468
// CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index
463
469
464
470
// Running convolution
471
+ // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
472
+ // CHECK: %[[WEIGHT:.+]] = tosa.transpose %arg1, %[[PERM]]
465
473
// CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
466
474
// CHECK: %[[CST:.+]] = arith.constant 0
467
475
// CHECK: %[[FILL:.+]] = linalg.fill
468
476
// CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
469
- // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32 >) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
477
+ // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]] : tensor<1x?x?x27xf32>, tensor<3x3x27x28xf32 >) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
470
478
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>)
471
479
// CHECK: %[[ADD:.+]] = arith.addf
472
480
// CHECK: linalg.yield %[[ADD]] : f32
@@ -481,7 +489,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28
481
489
// CHECK: %[[C0:.+]] = arith.constant 0
482
490
// CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
483
491
// CHECK: tensor.yield %[[C0]]
484
- // CHECK: linalg.conv_2d_nhwc_fhwc
492
+ // CHECK: linalg.conv_2d_nhwc_hwcf
485
493
%0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
486
494
return
487
495
}
@@ -493,7 +501,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
493
501
// CHECK: %[[C22:.+]] = arith.constant -22
494
502
// CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
495
503
// CHECK: tensor.yield %[[C22]]
496
- // CHECK: linalg.conv_2d_nhwc_fhwc_q
504
+ // CHECK: linalg.conv_2d_nhwc_hwcf_q
497
505
%0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
498
506
return
499
507
}
0 commit comments