Skip to content

Commit 4d0c5c9

Browse files
fumchinDannyYuyang-quic
authored andcommitted
Arm backend: Refactor Ops Tests for AvgPool, Clamp, Clone, Conv1d, Sub (pytorch#9039)
Refactored multiple operator test files: - Replaced `@parameterized.expand` with `@common.parametrize`. - Removed `_test_*_pipeline()` functions and replaced them with `TosaPipelineMI` and `TosaPipelineBI`. - Fixed padding handling for max pooling and avg pooling. - Updated pipeline infrastructure to simplify replacing the default quantize stage. Signed-off-by: Fang-Ching <[email protected]>
1 parent 5993741 commit 4d0c5c9

File tree

10 files changed

+837
-704
lines changed

10 files changed

+837
-704
lines changed

backends/arm/operators/op_avg_pool2d.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,19 @@ def _build_generic_avgpool2d(
4343
output_zp: int,
4444
accumulator_type: ts.DType,
4545
) -> None:
46-
input_tensor = inputs[0]
4746

47+
input_tensor = inputs[0]
4848
kernel_size_list = inputs[1].special
4949
stride_size_list = inputs[2].special
50+
5051
try:
5152
pad_size_list = inputs[3].special
53+
pad_size_list = [
54+
pad_size_list[0],
55+
pad_size_list[0],
56+
pad_size_list[1],
57+
pad_size_list[1],
58+
]
5259
except IndexError:
5360
pad_size_list = [0, 0, 0, 0]
5461

backends/arm/operators/op_max_pool2d.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,15 @@ def define_node(
4242
stride = inputs[2].special
4343

4444
try:
45-
padding = [*inputs[3].special, *inputs[3].special]
45+
pad_size_list = inputs[3].special
46+
pad_size_list = [
47+
pad_size_list[0],
48+
pad_size_list[0],
49+
pad_size_list[1],
50+
pad_size_list[1],
51+
]
4652
except IndexError:
47-
padding = [0, 0, 0, 0]
53+
pad_size_list = [0, 0, 0, 0]
4854

4955
accumulator_type = output.dtype
5056

@@ -63,7 +69,7 @@ def define_node(
6369
attr.PoolAttribute(
6470
kernel=kernel_size,
6571
stride=stride,
66-
pad=padding,
72+
pad=pad_size_list,
6773
input_zp=input_zp,
6874
output_zp=output_zp,
6975
accum_dtype=accumulator_type,

backends/arm/test/ops/test_avg_pool.py

Lines changed: 0 additions & 210 deletions
This file was deleted.

0 commit comments

Comments
 (0)