Skip to content

Arm backend: update quantizer/__init__ #10408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
This pull request was merged with 1 commit on Apr 24, 2025.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions backends/arm/quantizer/TARGETS
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

# Exposed through __init__.py
python_library(
name = "quantization_config",
srcs = ["quantization_config.py"],
deps = [
"//caffe2:torch",
],
)

# Exposed through __init__.py
python_library(
name = "arm_quantizer",
srcs = ["arm_quantizer.py"],
Expand All @@ -22,17 +32,19 @@ python_library(
)

python_library(
name = "quantization_config",
srcs = ["quantization_config.py"],
name = "arm_quantizer_utils",
srcs = ["arm_quantizer_utils.py"],
deps = [
"//caffe2:torch",
":quantization_config",
],
)

python_library(
name = "arm_quantizer_utils",
srcs = ["arm_quantizer_utils.py"],
name = "lib",
srcs = ["__init__.py"],
deps = [
":arm_quantizer",
":quantization_config",
],
":arm_quantizer_utils",
]
)
13 changes: 12 additions & 1 deletion backends/arm/quantizer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from .quantization_config import QuantizationConfig # noqa # usort: skip
from .arm_quantizer import ( # noqa
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
)

# Used in tests
from .arm_quantizer_utils import is_annotated # noqa
15 changes: 5 additions & 10 deletions backends/arm/quantizer/arm_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,11 @@
import torch
from executorch.backends.arm._passes import ArmPassManager

from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.arm_quantizer_utils import ( # type: ignore[attr-defined]
mark_node_as_annotated,
)
from executorch.backends.arm.quantizer.quantization_annotator import ( # type: ignore[import-not-found]
annotate_graph,
)

from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from executorch.backends.arm.quantizer import QuantizationConfig
from executorch.backends.arm.tosa_specification import TosaSpecification

from .arm_quantizer_utils import is_annotated, mark_node_as_annotated
from .quantization_annotator import annotate_graph
from executorch.backends.arm.arm_backend import (
get_tosa_spec,
is_ethosu,
Expand Down Expand Up @@ -337,7 +332,7 @@ def _annotate_io(
quantization_config: QuantizationConfig,
):
for node in model.graph.nodes:
if arm_quantizer_utils.is_annotated(node):
if is_annotated(node):
continue
if node.op == "placeholder" and len(node.users) > 0:
_annotate_output_qspec(
Expand Down
28 changes: 17 additions & 11 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

import torch
import torch.fx
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from executorch.backends.arm.quantizer import QuantizationConfig
from executorch.backends.arm.tosa_utils import get_node_debug_info
from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
from torch.ao.quantization.quantizer.utils import (
Expand All @@ -20,6 +19,13 @@
)
from torch.fx import Node

from .arm_quantizer_utils import (
is_annotated,
is_ok_for_quantization,
is_output_annotated,
mark_node_as_annotated,
)

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -69,7 +75,7 @@ def _is_ok_for_quantization(
"""
# Check output
if quant_properties.quant_output is not None:
if not arm_quantizer_utils.is_ok_for_quantization(node, gm): # type: ignore[attr-defined]
if not is_ok_for_quantization(node, gm): # type: ignore[attr-defined]
logger.debug(
f"Could not quantize node due to output: "
f"{get_node_debug_info(node, gm)}"
Expand All @@ -87,7 +93,7 @@ def _is_ok_for_quantization(

for n_arg in _as_list(node.args[quant_property.index]):
assert isinstance(n_arg, Node)
if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm): # type: ignore[attr-defined]
if not is_ok_for_quantization(n_arg, gm): # type: ignore[attr-defined]
logger.debug(
f'could not quantize node due to input "{node}": '
f"{get_node_debug_info(node, gm)}"
Expand All @@ -99,7 +105,7 @@ def _is_ok_for_quantization(


def _annotate_input(node: Node, quant_property: _QuantProperty):
assert not arm_quantizer_utils.is_annotated(node)
assert not is_annotated(node)
if quant_property.optional and (
quant_property.index >= len(node.args)
or node.args[quant_property.index] is None
Expand All @@ -114,11 +120,11 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
assert isinstance(n_arg, Node)
_annotate_input_qspec_map(node, n_arg, qspec)
if quant_property.mark_annotated:
arm_quantizer_utils.mark_node_as_annotated(n_arg) # type: ignore[attr-defined]
mark_node_as_annotated(n_arg) # type: ignore[attr-defined]


def _annotate_output(node: Node, quant_property: _QuantProperty):
assert not arm_quantizer_utils.is_annotated(node)
assert not is_annotated(node)
assert not quant_property.mark_annotated
assert not quant_property.optional
assert quant_property.index == 0, "Only one output annotation supported currently"
Expand Down Expand Up @@ -343,7 +349,7 @@ def any_or_hardtanh_min_zero(n: Node):
elif node.target in _one_to_one_shared_input_or_input_act_qspec:
input_qspec = (
SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
if arm_quantizer_utils.is_output_annotated(node.args[0]) # type: ignore
if is_output_annotated(node.args[0]) # type: ignore
else input_act_qspec
)
quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)] # type: ignore[arg-type]
Expand Down Expand Up @@ -396,7 +402,7 @@ def any_or_hardtanh_min_zero(n: Node):
if not isinstance(node.args[0], Node):
return None

if not arm_quantizer_utils.is_output_annotated(node.args[0]): # type: ignore[attr-defined]
if not is_output_annotated(node.args[0]): # type: ignore[attr-defined]
return None

shared_qspec = SharedQuantizationSpec(node.args[0])
Expand Down Expand Up @@ -426,7 +432,7 @@ def annotate_graph( # type: ignore[return]
if node.op != "call_function":
continue

if arm_quantizer_utils.is_annotated(node):
if is_annotated(node):
continue

if filter_fn is not None and not filter_fn(node):
Expand All @@ -442,7 +448,7 @@ def annotate_graph( # type: ignore[return]
if quant_properties.quant_output is not None:
_annotate_output(node, quant_properties.quant_output)

arm_quantizer_utils.mark_node_as_annotated(node) # type: ignore[attr-defined]
mark_node_as_annotated(node) # type: ignore[attr-defined]

# Quantization does not allow kwargs for some reason.
# Remove from ops we know have and where we know it does not break anything.
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ python_library(
":common",
"//executorch/backends/xnnpack/test/tester:tester",
"//executorch/backends/arm:arm_partitioner",
"//executorch/backends/arm/quantizer:arm_quantizer",
"//executorch/backends/arm/quantizer:lib",
"//executorch/backends/arm:tosa_mapping",
"//executorch/devtools/backend_debug:delegation_info",
"fbsource//third-party/pypi/tabulate:tabulate",
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_hardtanh.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_max_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_permute.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Tuple

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_sigmoid_16bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
get_symmetric_quantization_config,
TOSAQuantizer,
)
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_sigmoid_32bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest
import torch
from executorch.backends.arm.quantizer.arm_quantizer import TOSAQuantizer
from executorch.backends.arm.quantizer import TOSAQuantizer
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import unittest

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/quantizer/test_generic_annotater.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import unittest

import torch
from executorch.backends.arm.quantizer.arm_quantizer_utils import is_annotated
from executorch.backends.arm.quantizer import is_annotated
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
Expand Down
5 changes: 5 additions & 0 deletions backends/arm/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ def define_arm_tests():
"ops/test_tanh.py",
]

# Quantization
test_files += [
"quantizer/test_generic_annotater.py",
]

TESTS = {}

for test_file in test_files:
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
is_tosa,
)
from executorch.backends.arm.ethosu_partitioner import EthosUPartitioner
from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/tester/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
4 changes: 1 addition & 3 deletions backends/arm/tosa_quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import math
from typing import cast, List, NamedTuple, Tuple

import executorch.backends.arm.tosa_mapping

import torch.fx
import torch.fx.node

Expand Down Expand Up @@ -234,7 +232,7 @@ def build_rescale(

def build_rescale_to_int32(
tosa_fb: ts.TosaSerializer,
input_arg: executorch.backends.arm.tosa_mapping.TosaArg,
input_arg: TosaArg,
input_zp: int,
rescale_scale: list[float],
is_scale32: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion examples/arm/aot_arm_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
is_tosa,
)
from executorch.backends.arm.ethosu_partitioner import EthosUPartitioner
from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
Loading
Loading