Skip to content

Commit c35386f

Browse files
ArmQuantizer: Quantize AvgPool2d (#5873)
Summary: - AvgPool2d is quantized by propagating the parent nodes output quantization spec using a shared quantization spec. - Adds Conv2D + AvgPool2d unittest. - Enables AvgPool2d BI unittests. Pull Request resolved: #5873 Reviewed By: mergennachin Differential Revision: D64047355 Pulled By: digantdesai fbshipit-source-id: 811a18f828c46d7e6727565579d73677df817dcf
1 parent c44f334 commit c35386f

File tree

3 files changed

+71
-12
lines changed

3 files changed

+71
-12
lines changed

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
151151
torch.ops.aten.unsqueeze.default,
152152
# TODO: remove?
153153
torch.ops.aten.adaptive_avg_pool2d.default,
154+
torch.ops.aten.avg_pool2d.default,
154155
torch.ops.aten.view_copy.default,
155156
torch.ops.aten.view.default,
156157
torch.ops.aten.slice.Tensor,

backends/arm/test/ops/test_avg_pool.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,21 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
import logging
98
import unittest
109

1110
from typing import Tuple
1211

1312
import torch
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
ArmQuantizer,
15+
get_symmetric_quantization_config,
16+
)
1417
from executorch.backends.arm.test import common
1518
from executorch.backends.arm.test.tester.arm_tester import ArmTester
19+
from executorch.backends.xnnpack.test.tester.tester import Quantize
1620
from executorch.exir.backend.backend_details import CompileSpec
1721
from parameterized import parameterized
1822

19-
logger = logging.getLogger(__name__)
20-
logger.setLevel(logging.INFO)
2123

2224
test_data_suite = [
2325
# (test_name, test_data, [kernel_size, stride, padding])
@@ -69,13 +71,14 @@ def _test_avgpool2d_tosa_MI_pipeline(
6971
def _test_avgpool2d_tosa_BI_pipeline(
7072
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
7173
):
74+
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
7275
(
7376
ArmTester(
7477
module,
7578
example_inputs=test_data,
7679
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
7780
)
78-
.quantize()
81+
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
7982
.export()
8083
.check_count({"torch.ops.aten.avg_pool2d.default": 1})
8184
.check(["torch.ops.quantized_decomposed"])
@@ -93,13 +96,14 @@ def _test_avgpool2d_tosa_ethos_BI_pipeline(
9396
compile_spec: CompileSpec,
9497
test_data: Tuple[torch.tensor],
9598
):
99+
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
96100
(
97101
ArmTester(
98102
module,
99103
example_inputs=test_data,
100104
compile_spec=compile_spec,
101105
)
102-
.quantize()
106+
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
103107
.export()
104108
.check_count({"torch.ops.aten.avg_pool2d.default": 1})
105109
.check(["torch.ops.quantized_decomposed"])
@@ -121,10 +125,7 @@ def test_avgpool2d_tosa_MI(
121125
self.AvgPool2d(*model_params), (test_data,)
122126
)
123127

124-
# Expected to fail since ArmQuantizer cannot quantize a AvgPool2D layer
125-
# TODO(MLETORCH-93)
126128
@parameterized.expand(test_data_suite)
127-
@unittest.expectedFailure
128129
def test_avgpool2d_tosa_BI(
129130
self,
130131
test_name: str,
@@ -135,10 +136,7 @@ def test_avgpool2d_tosa_BI(
135136
self.AvgPool2d(*model_params), (test_data,)
136137
)
137138

138-
# Expected to fail since ArmQuantizer cannot quantize a AvgPool2D layer
139-
# TODO(MLETORCH-93)
140139
@parameterized.expand(test_data_suite)
141-
@unittest.expectedFailure
142140
def test_avgpool2d_tosa_u55_BI(
143141
self,
144142
test_name: str,
@@ -152,7 +150,6 @@ def test_avgpool2d_tosa_u55_BI(
152150
)
153151

154152
@parameterized.expand(test_data_suite)
155-
@unittest.expectedFailure
156153
def test_avgpool2d_tosa_u85_BI(
157154
self,
158155
test_name: str,

backends/arm/test/ops/test_conv_combos.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,32 @@ def forward(self, x):
156156
return x
157157

158158

159+
class ComboConvAvgPool2d(torch.nn.Module):
160+
edge_op_list = [
161+
"executorch_exir_dialects_edge__ops_aten_convolution_default",
162+
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default",
163+
]
164+
165+
test_data = [
166+
(20 * torch.randn(1, 3, 64, 32),),
167+
(torch.randn(1, 3, 100, 200),),
168+
(5 * torch.randn(1, 3, 256, 256),),
169+
(torch.rand(1, 3, 512, 128),),
170+
]
171+
172+
def __init__(self):
173+
super().__init__()
174+
self.conv2d = torch.nn.Conv2d(
175+
in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
176+
)
177+
self.avg_pool2d = torch.nn.AvgPool2d(kernel_size=(2, 2))
178+
179+
def forward(self, x):
180+
x = self.conv2d(x)
181+
x = self.avg_pool2d(x)
182+
return x
183+
184+
159185
class TestConvCombos(unittest.TestCase):
160186
"""Tests conv combined with other ops."""
161187

@@ -334,3 +360,38 @@ def test_block_bottleneck_residual_u85_BI(self):
334360
common.get_u85_compile_spec(permute_memory_to_nhwc=True),
335361
model.get_inputs(),
336362
)
363+
364+
######################
365+
## Conv + AvgPool2d ##
366+
######################
367+
@parameterized.expand(ComboConvAvgPool2d.test_data)
368+
def test_conv_avgpool2d_tosa_MI(self, test_data: torch.Tensor):
369+
model = ComboConvAvgPool2d()
370+
test_data = (test_data,)
371+
self._test_conv_combo_tosa_MI_pipeline(model, test_data)
372+
373+
@parameterized.expand(ComboConvAvgPool2d.test_data)
374+
def test_conv_avgpool2d_tosa_BI(self, test_data: torch.Tensor):
375+
model = ComboConvAvgPool2d()
376+
test_data = (test_data,)
377+
self._test_conv_combo_tosa_BI_pipeline(model, test_data)
378+
379+
@parameterized.expand(ComboConvAvgPool2d.test_data)
380+
def test_conv_avgpool2d_u55_BI(self, test_data: torch.Tensor):
381+
model = ComboConvAvgPool2d()
382+
test_data = (test_data,)
383+
self._test_conv_combo_ethos_BI_pipeline(
384+
model,
385+
common.get_u55_compile_spec(),
386+
test_data,
387+
)
388+
389+
@parameterized.expand(ComboConvAvgPool2d.test_data)
390+
def test_conv_avgpool2d_u85_BI(self, test_data: torch.Tensor):
391+
model = ComboConvAvgPool2d()
392+
test_data = (test_data,)
393+
self._test_conv_combo_ethos_BI_pipeline(
394+
model,
395+
common.get_u85_compile_spec(),
396+
test_data,
397+
)

0 commit comments

Comments
 (0)