Skip to content

Commit ad04193

Browse files
Add missing __init__.py to passes
backends/arm/_passes missed an init file which caused pyre to not find imports from _passes. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: I80e0c2071844ec0fefeb9fd44f9471ed1be32228
1 parent 30d4cc8 commit ad04193

15 files changed

+90
-120
lines changed

backends/arm/_passes/__init__.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from . import arm_pass_utils # noqa
8+
from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
9+
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
10+
from .cast_int64_pass import CastInt64ToInt32Pass # noqa
11+
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
12+
from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass # noqa
13+
from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass # noqa
14+
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
15+
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
16+
from .convert_to_clamp import ConvertToClampPass # noqa
17+
from .decompose_batchnorm_pass import DecomposeBatchNormPass # noqa
18+
from .decompose_div_pass import DecomposeDivPass # noqa
19+
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
20+
from .decompose_linear_pass import DecomposeLinearPass # noqa
21+
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
22+
from .decompose_select import DecomposeSelectPass # noqa
23+
from .decompose_softmaxes_pass import DecomposeSoftmaxesPass # noqa
24+
from .decompose_var_pass import DecomposeVarPass # noqa
25+
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
26+
FoldAndAnnotateQParamsPass,
27+
get_input_qparams,
28+
get_output_qparams,
29+
QuantizeOperatorArguments,
30+
RetraceFoldedDtypesPass,
31+
)
32+
from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa
33+
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
34+
from .insert_rescales_pass import InsertRescalePass # noqa
35+
from .insert_table_ops import InsertTableOpsPass # noqa
36+
from .keep_dims_false_to_squeeze_pass import KeepDimsFalseToSqueezePass # noqa
37+
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
38+
from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass # noqa
39+
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
40+
from .remove_clone_pass import RemoveClonePass # noqa
41+
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
42+
from .size_adjust_conv2d_pass import SizeAdjustConv2DPass # noqa
43+
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa
44+
from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa
45+
from .arm_pass_manager import ArmPassManager # noqa # usort: skip

backends/arm/_passes/arm_pass_manager.py

Lines changed: 18 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -7,73 +7,39 @@
77

88
# pyre-unsafe
99

10-
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
10+
from executorch.backends.arm._passes import (
1111
AnnotateChannelsLastDimOrder,
12-
)
13-
from executorch.backends.arm._passes.annotate_decomposed_matmul import (
1412
AnnotateDecomposedMatmulPass,
15-
)
16-
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
17-
from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
18-
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
13+
CastInt64ToInt32Pass,
14+
Conv1dUnsqueezePass,
1915
ConvertExpandCopyToRepeatPass,
20-
)
21-
from executorch.backends.arm._passes.convert_full_like_to_full_pass import (
2216
ConvertFullLikeToFullPass,
23-
)
24-
from executorch.backends.arm._passes.convert_split_to_slice import (
17+
ConvertMeanDimToAveragePoolPass,
18+
ConvertMmToBmmPass,
2519
ConvertSplitToSlicePass,
26-
)
27-
from executorch.backends.arm._passes.convert_squeezes_to_view import ( # type: ignore[import-not-found]
2820
ConvertSqueezesToViewPass,
29-
)
30-
from executorch.backends.arm._passes.convert_to_clamp import ConvertToClampPass
31-
from executorch.backends.arm._passes.decompose_batchnorm_pass import (
21+
ConvertToClampPass,
3222
DecomposeBatchNormPass,
33-
)
34-
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
35-
from executorch.backends.arm._passes.decompose_layernorm_pass import (
23+
DecomposeDivPass,
3624
DecomposeLayerNormPass,
37-
)
38-
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
39-
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
40-
from executorch.backends.arm._passes.decompose_select import ( # type: ignore[import-not-found]
25+
DecomposeLinearPass,
26+
DecomposeMeanDimPass,
4127
DecomposeSelectPass,
42-
)
43-
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
4428
DecomposeSoftmaxesPass,
45-
)
46-
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
47-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
29+
DecomposeVarPass,
4830
FoldAndAnnotateQParamsPass,
49-
QuantizeOperatorArguments,
50-
RetraceFoldedDtypesPass,
51-
)
52-
from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
53-
from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( # type: ignore[import-not-found]
31+
FuseBatchnorm2DPass,
5432
FuseQuantizedActivationPass,
55-
)
56-
from executorch.backends.arm._passes.insert_rescales_pass import InsertRescalePass
57-
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
58-
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
33+
InsertRescalePass,
34+
InsertTableOpsPass,
5935
KeepDimsFalseToSqueezePass,
60-
)
61-
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
62-
from executorch.backends.arm._passes.meandim_to_averagepool_pass import ( # type: ignore[attr-defined]
63-
ConvertMeanDimToAveragePoolPass,
64-
)
65-
from executorch.backends.arm._passes.mm_to_bmm_pass import ( # type: ignore[import-not-found]
66-
ConvertMmToBmmPass,
67-
)
68-
from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
69-
from executorch.backends.arm._passes.scalars_to_attribute_pass import (
36+
MatchArgRanksPass,
37+
QuantizeOperatorArguments,
38+
RemoveClonePass,
39+
RetraceFoldedDtypesPass,
7040
ScalarsToAttributePass,
71-
)
72-
from executorch.backends.arm._passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
73-
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
41+
SizeAdjustConv2DPass,
7442
UnsqueezeBeforeRepeatPass,
75-
)
76-
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
7743
UnsqueezeScalarPlaceholdersPass,
7844
)
7945
from executorch.backends.arm.tosa_specification import TosaSpecification

backends/arm/operators/op_avg_pool2d.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@
99
import serializer.tosa_serializer as ts # type: ignore
1010
import torch
1111

12-
# pyre-fixme[21]: ' Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`
13-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
14-
get_input_qparams,
15-
get_output_qparams,
16-
)
12+
from executorch.backends.arm._passes import get_input_qparams, get_output_qparams
1713
from executorch.backends.arm.operators.node_visitor import (
1814
NodeVisitor,
1915
register_node_visitor,
@@ -81,10 +77,10 @@ def define_node(
8177

8278
accumulator_type = ts.DType.INT32
8379

84-
input_qargs = get_input_qparams(node) # pyre-ignore[16]
80+
input_qargs = get_input_qparams(node)
8581
input_zp = input_qargs[0].zp
8682

87-
output_qargs = get_output_qparams(node) # pyre-ignore[16]
83+
output_qargs = get_output_qparams(node)
8884
output_zp = output_qargs[0].zp
8985

9086
self._build_generic_avgpool2d(

backends/arm/operators/op_bmm.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,7 @@
1010
import serializer.tosa_serializer as ts # type: ignore
1111
import torch
1212

13-
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
14-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
15-
get_input_qparams,
16-
get_output_qparams,
17-
)
13+
from executorch.backends.arm._passes import get_input_qparams, get_output_qparams
1814
from executorch.backends.arm.operators.node_visitor import (
1915
NodeVisitor,
2016
register_node_visitor,
@@ -51,7 +47,7 @@ def define_node(
5147
# for a later rescale.
5248

5349
if inputs[0].dtype == ts.DType.INT8:
54-
input_qparams = get_input_qparams(node) # pyre-ignore[16]
50+
input_qparams = get_input_qparams(node)
5551
input0_zp = input_qparams[0].zp
5652
input1_zp = input_qparams[1].zp
5753
bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
@@ -73,7 +69,7 @@ def define_node(
7369

7470
# As INT8 accumulates into INT32, we need to rescale it back to INT8
7571
if output.dtype == ts.DType.INT8:
76-
output_qparams = get_output_qparams(node)[0] # pyre-ignore[16]
72+
output_qparams = get_output_qparams(node)[0]
7773
final_output_scale = (
7874
input_qparams[0].scale * input_qparams[1].scale # type: ignore[possibly-undefined] # pyre-ignore[61]
7975
) / output_qparams.scale

backends/arm/operators/op_conv2d.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@
99
import serializer.tosa_serializer as ts # type: ignore
1010
import torch
1111

12-
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
13-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
14-
get_input_qparams,
15-
get_output_qparams,
16-
)
12+
from executorch.backends.arm._passes import get_input_qparams, get_output_qparams
1713
from executorch.backends.arm.operators.node_visitor import (
1814
NodeVisitor,
1915
register_node_visitor,
@@ -85,7 +81,7 @@ def define_node(
8581
input_zp = 0
8682
if inputs[0].dtype == ts.DType.INT8:
8783
# int8 input requires quantization information
88-
input_qparams = get_input_qparams(node) # pyre-ignore[16]
84+
input_qparams = get_input_qparams(node)
8985
input_zp = input_qparams[0].zp
9086

9187
attr.ConvAttribute(
@@ -169,7 +165,7 @@ def define_node(
169165
# Get scale_factor from input, weight, and output.
170166
input_scale = input_qparams[0].scale # type: ignore[possibly-undefined] # pyre-ignore [61]
171167
weight_scale = input_qparams[1].scale # pyre-ignore [61]
172-
output_qargs = get_output_qparams(node) # pyre-ignore [16]
168+
output_qargs = get_output_qparams(node)
173169
build_rescale_conv_output(
174170
tosa_graph,
175171
# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.

backends/arm/operators/op_max.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,7 @@
1010
import executorch.backends.arm.tosa_quant_utils as tqutils
1111
import serializer.tosa_serializer as ts # type: ignore
1212

13-
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
14-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
15-
get_input_qparams,
16-
)
13+
from executorch.backends.arm._passes import get_input_qparams
1714
from executorch.backends.arm.operators.node_visitor import (
1815
NodeVisitor,
1916
register_node_visitor,
@@ -44,9 +41,7 @@ def define_node(
4441
scale_back = 1.0
4542
max_output = output
4643
if inputs[0].dtype == ts.DType.INT8:
47-
input_qparams = get_input_qparams( # pyre-ignore[16]: 'Module `executorch.backends.arm` has no attribute `_passes`.'
48-
node
49-
)
44+
input_qparams = get_input_qparams(node)
5045
assert (
5146
len(input_qparams) == 2
5247
), f"Both inputs needs to have quantization information for {node}"

backends/arm/operators/op_max_pool2d.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@
99
import serializer.tosa_serializer as ts # type: ignore
1010
import torch
1111

12-
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
13-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
14-
get_input_qparams,
15-
get_output_qparams,
16-
)
12+
from executorch.backends.arm._passes import get_input_qparams, get_output_qparams
1713
from executorch.backends.arm.operators.node_visitor import (
1814
NodeVisitor,
1915
register_node_visitor,
@@ -51,12 +47,12 @@ def define_node(
5147
# Initilize zero point to zero.
5248
input_zp = 0
5349
if inputs[0].dtype == ts.DType.INT8:
54-
input_qparams = get_input_qparams(node) # pyre-ignore[16]
50+
input_qparams = get_input_qparams(node)
5551
input_zp = input_qparams[0].zp
5652

5753
output_zp = 0
5854
if output.dtype == ts.DType.INT8:
59-
output_qparams = get_output_qparams(node) # pyre-ignore[16]
55+
output_qparams = get_output_qparams(node)
6056
output_zp = output_qparams[0].zp
6157

6258
attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_min.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@
1111

1212
import serializer.tosa_serializer as ts # type: ignore
1313

14-
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
15-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
16-
get_input_qparams,
17-
)
14+
from executorch.backends.arm._passes import get_input_qparams
1815
from executorch.backends.arm.operators.node_visitor import (
1916
NodeVisitor,
2017
register_node_visitor,
@@ -45,9 +42,7 @@ def define_node(
4542
scale_back = 1.0
4643
min_output = output
4744
if inputs[0].dtype == ts.DType.INT8:
48-
input_qparams = get_input_qparams( # pyre-ignore[16]: 'Module `executorch.backends.arm` has no attribute `_passes`.'
49-
node
50-
)
45+
input_qparams = get_input_qparams(node)
5146
assert (
5247
len(input_qparams) == 2
5348
), f"Both inputs needs to have quantization information for {node}"

backends/arm/operators/op_mul.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
import serializer.tosa_serializer as ts # type: ignore
1414
import torch
1515

16-
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
17-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
18-
get_input_qparams,
19-
)
16+
from executorch.backends.arm._passes import get_input_qparams
2017

2118
from executorch.backends.arm.operators.node_visitor import (
2219
NodeVisitor,
@@ -45,7 +42,7 @@ def define_node(
4542
assert inputs[0].dtype == inputs[1].dtype == output.dtype == ts.DType.INT8
4643
input_A = inputs[0]
4744
input_B = inputs[1]
48-
input_qparams = get_input_qparams(node) # pyre-ignore[16]
45+
input_qparams = get_input_qparams(node)
4946
input_A_qargs = input_qparams[0]
5047
input_B_qargs = input_qparams[1]
5148
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Any, Callable, Dict, List, Optional
1818

1919
import torch
20-
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
20+
from executorch.backends.arm._passes import ArmPassManager
2121

2222
from executorch.backends.arm.quantizer import arm_quantizer_utils
2323
from executorch.backends.arm.quantizer.arm_quantizer_utils import ( # type: ignore[attr-defined]

backends/arm/test/passes/test_fold_qdq_pass.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
from typing import Tuple
77

88
import torch
9-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
10-
FoldAndAnnotateQParamsPass,
11-
)
9+
from executorch.backends.arm._passes import FoldAndAnnotateQParamsPass
1210
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1311

1412

backends/arm/test/passes/test_meandim_to_averagepool2d.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
from typing import Tuple
88

99
import torch
10-
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
11-
ConvertMeanDimToAveragePoolPass,
12-
)
10+
from executorch.backends.arm._passes import ConvertMeanDimToAveragePoolPass
1311
from executorch.backends.arm.test import common
1412
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1513

backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
from typing import Dict, Tuple
77

88
import torch
9-
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
10-
UnsqueezeBeforeRepeatPass,
11-
)
9+
from executorch.backends.arm._passes import UnsqueezeBeforeRepeatPass
1210
from executorch.backends.arm.test import common
1311
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1412

backends/arm/tosa_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from executorch.backends.arm.arm_backend import get_tosa_spec
2020
from executorch.backends.arm.operators.node_visitor import get_node_visitors
21-
from executorch.backends.arm._passes.arm_pass_manager import (
21+
from executorch.backends.arm._passes import (
2222
ArmPassManager,
2323
) # usort: skip
2424
from executorch.backends.arm.process_node import (

0 commit comments

Comments
 (0)