Skip to content

Commit dd71ca8

Browse files
authored
Arm backend: update quantizer/__init__
Differential Revision: D73530601
Pull Request resolved: #10408
1 parent 19fc7ef commit dd71ca8

22 files changed

+84
-57
lines changed

backends/arm/quantizer/TARGETS

+18-6
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,15 @@
11
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
22

3+
# Exposed through __init__.py
4+
python_library(
5+
name = "quantization_config",
6+
srcs = ["quantization_config.py"],
7+
deps = [
8+
"//caffe2:torch",
9+
],
10+
)
11+
12+
# Exposed through __init__.py
313
python_library(
414
name = "arm_quantizer",
515
srcs = ["arm_quantizer.py"],
@@ -22,17 +32,19 @@ python_library(
2232
)
2333

2434
python_library(
25-
name = "quantization_config",
26-
srcs = ["quantization_config.py"],
35+
name = "arm_quantizer_utils",
36+
srcs = ["arm_quantizer_utils.py"],
2737
deps = [
28-
"//caffe2:torch",
38+
":quantization_config",
2939
],
3040
)
3141

3242
python_library(
33-
name = "arm_quantizer_utils",
34-
srcs = ["arm_quantizer_utils.py"],
43+
name = "lib",
44+
srcs = ["__init__.py"],
3545
deps = [
46+
":arm_quantizer",
3647
":quantization_config",
37-
],
48+
":arm_quantizer_utils",
49+
]
3850
)

backends/arm/quantizer/__init__.py

+12-1
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,15 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from .quantization_config import QuantizationConfig # noqa # usort: skip
8+
from .arm_quantizer import ( # noqa
9+
EthosUQuantizer,
10+
get_symmetric_quantization_config,
11+
TOSAQuantizer,
12+
)
13+
14+
# Used in tests
15+
from .arm_quantizer_utils import is_annotated # noqa

backends/arm/quantizer/arm_quantizer.py

+5-10
Original file line number | Diff line number | Diff line change
@@ -19,16 +19,11 @@
1919
import torch
2020
from executorch.backends.arm._passes import ArmPassManager
2121

22-
from executorch.backends.arm.quantizer import arm_quantizer_utils
23-
from executorch.backends.arm.quantizer.arm_quantizer_utils import ( # type: ignore[attr-defined]
24-
mark_node_as_annotated,
25-
)
26-
from executorch.backends.arm.quantizer.quantization_annotator import ( # type: ignore[import-not-found]
27-
annotate_graph,
28-
)
29-
30-
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
22+
from executorch.backends.arm.quantizer import QuantizationConfig
3123
from executorch.backends.arm.tosa_specification import TosaSpecification
24+
25+
from .arm_quantizer_utils import is_annotated, mark_node_as_annotated
26+
from .quantization_annotator import annotate_graph
3227
from executorch.backends.arm.arm_backend import (
3328
get_tosa_spec,
3429
is_ethosu,
@@ -337,7 +332,7 @@ def _annotate_io(
337332
quantization_config: QuantizationConfig,
338333
):
339334
for node in model.graph.nodes:
340-
if arm_quantizer_utils.is_annotated(node):
335+
if is_annotated(node):
341336
continue
342337
if node.op == "placeholder" and len(node.users) > 0:
343338
_annotate_output_qspec(

backends/arm/quantizer/quantization_annotator.py

+17-11
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,7 @@
1010

1111
import torch
1212
import torch.fx
13-
from executorch.backends.arm.quantizer import arm_quantizer_utils
14-
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
13+
from executorch.backends.arm.quantizer import QuantizationConfig
1514
from executorch.backends.arm.tosa_utils import get_node_debug_info
1615
from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
1716
from torch.ao.quantization.quantizer.utils import (
@@ -20,6 +19,13 @@
2019
)
2120
from torch.fx import Node
2221

22+
from .arm_quantizer_utils import (
23+
is_annotated,
24+
is_ok_for_quantization,
25+
is_output_annotated,
26+
mark_node_as_annotated,
27+
)
28+
2329
logger = logging.getLogger(__name__)
2430

2531

@@ -69,7 +75,7 @@ def _is_ok_for_quantization(
6975
"""
7076
# Check output
7177
if quant_properties.quant_output is not None:
72-
if not arm_quantizer_utils.is_ok_for_quantization(node, gm): # type: ignore[attr-defined]
78+
if not is_ok_for_quantization(node, gm): # type: ignore[attr-defined]
7379
logger.debug(
7480
f"Could not quantize node due to output: "
7581
f"{get_node_debug_info(node, gm)}"
@@ -87,7 +93,7 @@ def _is_ok_for_quantization(
8793

8894
for n_arg in _as_list(node.args[quant_property.index]):
8995
assert isinstance(n_arg, Node)
90-
if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm): # type: ignore[attr-defined]
96+
if not is_ok_for_quantization(n_arg, gm): # type: ignore[attr-defined]
9197
logger.debug(
9298
f'could not quantize node due to input "{node}": '
9399
f"{get_node_debug_info(node, gm)}"
@@ -99,7 +105,7 @@ def _is_ok_for_quantization(
99105

100106

101107
def _annotate_input(node: Node, quant_property: _QuantProperty):
102-
assert not arm_quantizer_utils.is_annotated(node)
108+
assert not is_annotated(node)
103109
if quant_property.optional and (
104110
quant_property.index >= len(node.args)
105111
or node.args[quant_property.index] is None
@@ -114,11 +120,11 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
114120
assert isinstance(n_arg, Node)
115121
_annotate_input_qspec_map(node, n_arg, qspec)
116122
if quant_property.mark_annotated:
117-
arm_quantizer_utils.mark_node_as_annotated(n_arg) # type: ignore[attr-defined]
123+
mark_node_as_annotated(n_arg) # type: ignore[attr-defined]
118124

119125

120126
def _annotate_output(node: Node, quant_property: _QuantProperty):
121-
assert not arm_quantizer_utils.is_annotated(node)
127+
assert not is_annotated(node)
122128
assert not quant_property.mark_annotated
123129
assert not quant_property.optional
124130
assert quant_property.index == 0, "Only one output annotation supported currently"
@@ -343,7 +349,7 @@ def any_or_hardtanh_min_zero(n: Node):
343349
elif node.target in _one_to_one_shared_input_or_input_act_qspec:
344350
input_qspec = (
345351
SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
346-
if arm_quantizer_utils.is_output_annotated(node.args[0]) # type: ignore
352+
if is_output_annotated(node.args[0]) # type: ignore
347353
else input_act_qspec
348354
)
349355
quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)] # type: ignore[arg-type]
@@ -396,7 +402,7 @@ def any_or_hardtanh_min_zero(n: Node):
396402
if not isinstance(node.args[0], Node):
397403
return None
398404

399-
if not arm_quantizer_utils.is_output_annotated(node.args[0]): # type: ignore[attr-defined]
405+
if not is_output_annotated(node.args[0]): # type: ignore[attr-defined]
400406
return None
401407

402408
shared_qspec = SharedQuantizationSpec(node.args[0])
@@ -426,7 +432,7 @@ def annotate_graph( # type: ignore[return]
426432
if node.op != "call_function":
427433
continue
428434

429-
if arm_quantizer_utils.is_annotated(node):
435+
if is_annotated(node):
430436
continue
431437

432438
if filter_fn is not None and not filter_fn(node):
@@ -442,7 +448,7 @@ def annotate_graph( # type: ignore[return]
442448
if quant_properties.quant_output is not None:
443449
_annotate_output(node, quant_properties.quant_output)
444450

445-
arm_quantizer_utils.mark_node_as_annotated(node) # type: ignore[attr-defined]
451+
mark_node_as_annotated(node) # type: ignore[attr-defined]
446452

447453
# Quantization does not allow kwargs for some reason.
448454
# Remove them from ops that we know have them and where removal is known not to break anything.

backends/arm/test/TARGETS

+1-1
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ python_library(
4242
":common",
4343
"//executorch/backends/xnnpack/test/tester:tester",
4444
"//executorch/backends/arm:arm_partitioner",
45-
"//executorch/backends/arm/quantizer:arm_quantizer",
45+
"//executorch/backends/arm/quantizer:lib",
4646
"//executorch/backends/arm:tosa_mapping",
4747
"//executorch/devtools/backend_debug:delegation_info",
4848
"fbsource//third-party/pypi/tabulate:tabulate",

backends/arm/test/ops/test_expand.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515

1616
import torch
1717

18-
from executorch.backends.arm.quantizer.arm_quantizer import (
18+
from executorch.backends.arm.quantizer import (
1919
EthosUQuantizer,
2020
get_symmetric_quantization_config,
2121
TOSAQuantizer,

backends/arm/test/ops/test_hardtanh.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313

1414
import torch
1515

16-
from executorch.backends.arm.quantizer.arm_quantizer import (
16+
from executorch.backends.arm.quantizer import (
1717
EthosUQuantizer,
1818
get_symmetric_quantization_config,
1919
TOSAQuantizer,

backends/arm/test/ops/test_max_pool.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
import pytest
1313

1414
import torch
15-
from executorch.backends.arm.quantizer.arm_quantizer import (
15+
from executorch.backends.arm.quantizer import (
1616
EthosUQuantizer,
1717
get_symmetric_quantization_config,
1818
TOSAQuantizer,

backends/arm/test/ops/test_permute.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313

1414
import torch
1515

16-
from executorch.backends.arm.quantizer.arm_quantizer import (
16+
from executorch.backends.arm.quantizer import (
1717
EthosUQuantizer,
1818
get_symmetric_quantization_config,
1919
TOSAQuantizer,

backends/arm/test/ops/test_relu.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from typing import Tuple
1111

1212
import torch
13-
from executorch.backends.arm.quantizer.arm_quantizer import (
13+
from executorch.backends.arm.quantizer import (
1414
EthosUQuantizer,
1515
get_symmetric_quantization_config,
1616
TOSAQuantizer,

backends/arm/test/ops/test_repeat.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313

1414
import torch
1515

16-
from executorch.backends.arm.quantizer.arm_quantizer import (
16+
from executorch.backends.arm.quantizer import (
1717
EthosUQuantizer,
1818
get_symmetric_quantization_config,
1919
TOSAQuantizer,

backends/arm/test/ops/test_sigmoid_16bit.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
import pytest
77

88
import torch
9-
from executorch.backends.arm.quantizer.arm_quantizer import (
9+
from executorch.backends.arm.quantizer import (
1010
get_symmetric_quantization_config,
1111
TOSAQuantizer,
1212
)

backends/arm/test/ops/test_sigmoid_32bit.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77
import torch
8-
from executorch.backends.arm.quantizer.arm_quantizer import TOSAQuantizer
8+
from executorch.backends.arm.quantizer import TOSAQuantizer
99
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
1010
from executorch.backends.arm.test import common
1111
from executorch.backends.arm.test.tester.test_pipeline import (

backends/arm/test/ops/test_var.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
import unittest
1212

1313
import torch
14-
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
from executorch.backends.arm.quantizer import (
1515
EthosUQuantizer,
1616
get_symmetric_quantization_config,
1717
TOSAQuantizer,

backends/arm/test/ops/test_where.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99

1010
import torch
1111

12-
from executorch.backends.arm.quantizer.arm_quantizer import (
12+
from executorch.backends.arm.quantizer import (
1313
EthosUQuantizer,
1414
get_symmetric_quantization_config,
1515
TOSAQuantizer,

backends/arm/test/quantizer/test_generic_annotater.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import unittest
88

99
import torch
10-
from executorch.backends.arm.quantizer.arm_quantizer_utils import is_annotated
10+
from executorch.backends.arm.quantizer import is_annotated
1111
from executorch.backends.arm.test import common
1212
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1313
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

backends/arm/test/targets.bzl

+5
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,11 @@ def define_arm_tests():
1919
"ops/test_tanh.py",
2020
]
2121

22+
# Quantization
23+
test_files += [
24+
"quantizer/test_generic_annotater.py",
25+
]
26+
2227
TESTS = {}
2328

2429
for test_file in test_files:

backends/arm/test/tester/arm_tester.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
is_tosa,
2828
)
2929
from executorch.backends.arm.ethosu_partitioner import EthosUPartitioner
30-
from executorch.backends.arm.quantizer.arm_quantizer import (
30+
from executorch.backends.arm.quantizer import (
3131
EthosUQuantizer,
3232
get_symmetric_quantization_config,
3333
TOSAQuantizer,

backends/arm/test/tester/test_pipeline.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88

99
import torch
1010

11-
from executorch.backends.arm.quantizer.arm_quantizer import (
11+
from executorch.backends.arm.quantizer import (
1212
EthosUQuantizer,
1313
get_symmetric_quantization_config,
1414
TOSAQuantizer,

backends/arm/tosa_quant_utils.py

+1-3
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,6 @@
1010
import math
1111
from typing import cast, List, NamedTuple, Tuple
1212

13-
import executorch.backends.arm.tosa_mapping
14-
1513
import torch.fx
1614
import torch.fx.node
1715

@@ -234,7 +232,7 @@ def build_rescale(
234232

235233
def build_rescale_to_int32(
236234
tosa_fb: ts.TosaSerializer,
237-
input_arg: executorch.backends.arm.tosa_mapping.TosaArg,
235+
input_arg: TosaArg,
238236
input_zp: int,
239237
rescale_scale: list[float],
240238
is_scale32: bool = True,

examples/arm/aot_arm_compiler.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
is_tosa,
2525
)
2626
from executorch.backends.arm.ethosu_partitioner import EthosUPartitioner
27-
from executorch.backends.arm.quantizer.arm_quantizer import (
27+
from executorch.backends.arm.quantizer import (
2828
EthosUQuantizer,
2929
get_symmetric_quantization_config,
3030
TOSAQuantizer,

0 commit comments

Comments
 (0)