Skip to content

Commit c048ea2

Browse files
Arm backend: Add missing __init__.py to passes (#9710)
backends/arm/_passes missed an init file which caused pyre to not find imports from _passes. Signed-off-by: Oscar Andersson <[email protected]>
1 parent 6939136 commit c048ea2

15 files changed

+86
-96
lines changed

backends/arm/_passes/__init__.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from . import arm_pass_utils # noqa
8+
from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
9+
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
10+
from .cast_int64_pass import CastInt64ToInt32Pass # noqa
11+
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
12+
from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass # noqa
13+
from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass # noqa
14+
from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass # noqa
15+
from .convert_minmax_pass import ConvertMinMaxPass # noqa
16+
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
17+
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
18+
from .convert_to_clamp import ConvertToClampPass # noqa
19+
from .decompose_batchnorm_pass import DecomposeBatchNormPass # noqa
20+
from .decompose_div_pass import DecomposeDivPass # noqa
21+
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
22+
from .decompose_linear_pass import DecomposeLinearPass # noqa
23+
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
24+
from .decompose_select import DecomposeSelectPass # noqa
25+
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
26+
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
27+
from .decompose_var_pass import DecomposeVarPass # noqa
28+
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
29+
FoldAndAnnotateQParamsPass,
30+
QuantizeOperatorArguments,
31+
RetraceFoldedDtypesPass,
32+
)
33+
from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa
34+
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa
35+
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
36+
from .insert_rescales_pass import InsertRescalePass # noqa
37+
from .insert_table_ops import InsertTableOpsPass # noqa
38+
from .keep_dims_false_to_squeeze_pass import KeepDimsFalseToSqueezePass # noqa
39+
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
40+
from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass # noqa
41+
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
42+
from .remove_clone_pass import RemoveClonePass # noqa
43+
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
44+
from .size_adjust_conv2d_pass import SizeAdjustConv2DPass # noqa
45+
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa
46+
from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa
47+
from .arm_pass_manager import ArmPassManager # noqa # usort: skip

backends/arm/_passes/arm_pass_manager.py

Lines changed: 21 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -7,82 +7,44 @@
77

88
# pyre-unsafe
99

10-
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
10+
from executorch.backends.arm._passes import (
1111
AnnotateChannelsLastDimOrder,
12-
)
13-
from executorch.backends.arm._passes.annotate_decomposed_matmul import (
1412
AnnotateDecomposedMatmulPass,
15-
)
16-
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
17-
from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
18-
from executorch.backends.arm._passes.convert_any_default_dim_dims_pass import (
13+
CastInt64ToInt32Pass,
14+
ComputeConstantOpsAOT,
15+
Conv1dUnsqueezePass,
1916
ConvertAnyDefaultDimDimsPass,
20-
)
21-
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
2217
ConvertExpandCopyToRepeatPass,
23-
)
24-
from executorch.backends.arm._passes.convert_full_like_to_full_pass import (
2518
ConvertFullLikeToFullPass,
26-
)
27-
from executorch.backends.arm._passes.convert_minmax_pass import ConvertMinMaxPass
28-
from executorch.backends.arm._passes.convert_split_to_slice import (
19+
ConvertMeanDimToAveragePoolPass,
20+
ConvertMinMaxPass,
21+
ConvertMmToBmmPass,
2922
ConvertSplitToSlicePass,
30-
)
31-
from executorch.backends.arm._passes.convert_squeezes_to_view import ( # type: ignore[import-not-found]
3223
ConvertSqueezesToViewPass,
33-
)
34-
from executorch.backends.arm._passes.convert_to_clamp import ConvertToClampPass
35-
from executorch.backends.arm._passes.decompose_batchnorm_pass import (
24+
ConvertToClampPass,
3625
DecomposeBatchNormPass,
37-
)
38-
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
39-
from executorch.backends.arm._passes.decompose_layernorm_pass import (
26+
DecomposeDivPass,
4027
DecomposeLayerNormPass,
41-
)
42-
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
43-
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
44-
from executorch.backends.arm._passes.decompose_select import ( # type: ignore[import-not-found]
28+
DecomposeLinearPass,
29+
DecomposeMeanDimPass,
4530
DecomposeSelectPass,
46-
)
47-
from executorch.backends.arm._passes.decompose_softmax_pass import DecomposeSoftmaxPass
48-
from executorch.backends.arm._passes.decompose_softmax_unstable_pass import (
31+
DecomposeSoftmaxPass,
4932
DecomposeSoftmaxUnstablePass,
50-
)
51-
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
52-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
33+
DecomposeVarPass,
5334
FoldAndAnnotateQParamsPass,
54-
QuantizeOperatorArguments,
55-
RetraceFoldedDtypesPass,
56-
)
57-
from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
58-
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
59-
ComputeConstantOpsAOT,
35+
FuseBatchnorm2DPass,
6036
FuseConstantArgsPass,
61-
)
62-
from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( # type: ignore[import-not-found]
6337
FuseQuantizedActivationPass,
64-
)
65-
from executorch.backends.arm._passes.insert_rescales_pass import InsertRescalePass
66-
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
67-
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
38+
InsertRescalePass,
39+
InsertTableOpsPass,
6840
KeepDimsFalseToSqueezePass,
69-
)
70-
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
71-
from executorch.backends.arm._passes.meandim_to_averagepool_pass import ( # type: ignore[attr-defined]
72-
ConvertMeanDimToAveragePoolPass,
73-
)
74-
from executorch.backends.arm._passes.mm_to_bmm_pass import ( # type: ignore[import-not-found]
75-
ConvertMmToBmmPass,
76-
)
77-
from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
78-
from executorch.backends.arm._passes.scalars_to_attribute_pass import (
41+
MatchArgRanksPass,
42+
QuantizeOperatorArguments,
43+
RemoveClonePass,
44+
RetraceFoldedDtypesPass,
7945
ScalarsToAttributePass,
80-
)
81-
from executorch.backends.arm._passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
82-
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
46+
SizeAdjustConv2DPass,
8347
UnsqueezeBeforeRepeatPass,
84-
)
85-
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
8648
UnsqueezeScalarPlaceholdersPass,
8749
)
8850
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification

backends/arm/operators/op_avg_pool2d.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import serializer.tosa_serializer as ts # type: ignore
1010
import torch
1111

12-
# pyre-fixme[21]: ' Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`
1312
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1413
get_input_qparams,
1514
get_output_qparams,
@@ -88,10 +87,10 @@ def define_node(
8887

8988
accumulator_type = ts.DType.INT32
9089

91-
input_qargs = get_input_qparams(node) # pyre-ignore[16]
90+
input_qargs = get_input_qparams(node)
9291
input_zp = input_qargs[0].zp
9392

94-
output_qargs = get_output_qparams(node) # pyre-ignore[16]
93+
output_qargs = get_output_qparams(node)
9594
output_zp = output_qargs[0].zp
9695

9796
self._build_generic_avgpool2d(

backends/arm/operators/op_bmm.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import serializer.tosa_serializer as ts # type: ignore
1111
import torch
1212

13-
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
1413
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1514
get_input_qparams,
1615
get_output_qparams,
@@ -51,7 +50,7 @@ def define_node(
5150
# for a later rescale.
5251

5352
if inputs[0].dtype == ts.DType.INT8:
54-
input_qparams = get_input_qparams(node) # pyre-ignore[16]
53+
input_qparams = get_input_qparams(node)
5554
input0_zp = input_qparams[0].zp
5655
input1_zp = input_qparams[1].zp
5756
bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
@@ -73,7 +72,7 @@ def define_node(
7372

7473
# As INT8 accumulates into INT32, we need to rescale it back to INT8
7574
if output.dtype == ts.DType.INT8:
76-
output_qparams = get_output_qparams(node)[0] # pyre-ignore[16]
75+
output_qparams = get_output_qparams(node)[0]
7776
final_output_scale = (
7877
input_qparams[0].scale * input_qparams[1].scale # type: ignore[possibly-undefined] # pyre-ignore[61]
7978
) / output_qparams.scale

backends/arm/operators/op_conv2d.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import serializer.tosa_serializer as ts # type: ignore
1010
import torch
1111

12-
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
1312
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1413
get_input_qparams,
1514
get_output_qparams,
@@ -85,7 +84,7 @@ def define_node(
8584
input_zp = 0
8685
if inputs[0].dtype == ts.DType.INT8:
8786
# int8 input requires quantization information
88-
input_qparams = get_input_qparams(node) # pyre-ignore[16]
87+
input_qparams = get_input_qparams(node)
8988
input_zp = input_qparams[0].zp
9089

9190
attr.ConvAttribute(
@@ -169,7 +168,7 @@ def define_node(
169168
# Get scale_factor from input, weight, and output.
170169
input_scale = input_qparams[0].scale # type: ignore[possibly-undefined] # pyre-ignore [61]
171170
weight_scale = input_qparams[1].scale # pyre-ignore [61]
172-
output_qargs = get_output_qparams(node) # pyre-ignore [16]
171+
output_qargs = get_output_qparams(node)
173172
build_rescale_conv_output(
174173
tosa_graph,
175174
# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.

backends/arm/operators/op_max_pool2d.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import serializer.tosa_serializer as ts # type: ignore
1010
import torch
1111

12-
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
1312
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1413
get_input_qparams,
1514
get_output_qparams,
@@ -57,12 +56,12 @@ def define_node(
5756
# Initilize zero point to zero.
5857
input_zp = 0
5958
if inputs[0].dtype == ts.DType.INT8:
60-
input_qparams = get_input_qparams(node) # pyre-ignore[16]
59+
input_qparams = get_input_qparams(node)
6160
input_zp = input_qparams[0].zp
6261

6362
output_zp = 0
6463
if output.dtype == ts.DType.INT8:
65-
output_qparams = get_output_qparams(node) # pyre-ignore[16]
64+
output_qparams = get_output_qparams(node)
6665
output_zp = output_qparams[0].zp
6766

6867
attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_maximum.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import executorch.backends.arm.tosa_quant_utils as tqutils
1111
import serializer.tosa_serializer as ts # type: ignore
1212

13-
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
1413
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1514
get_input_qparams,
1615
)
@@ -44,9 +43,7 @@ def define_node(
4443
scale_back = 1.0
4544
max_output = output
4645
if inputs[0].dtype == ts.DType.INT8:
47-
input_qparams = get_input_qparams( # pyre-ignore[16]: 'Module `executorch.backends.arm` has no attribute `_passes`.'
48-
node
49-
)
46+
input_qparams = get_input_qparams(node)
5047
assert (
5148
len(input_qparams) == 2
5249
), f"Both inputs needs to have quantization information for {node}"

backends/arm/operators/op_minimum.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import serializer.tosa_serializer as ts # type: ignore
1313

14-
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
1514
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1615
get_input_qparams,
1716
)
@@ -45,9 +44,7 @@ def define_node(
4544
scale_back = 1.0
4645
min_output = output
4746
if inputs[0].dtype == ts.DType.INT8:
48-
input_qparams = get_input_qparams( # pyre-ignore[16]: 'Module `executorch.backends.arm` has no attribute `_passes`.'
49-
node
50-
)
47+
input_qparams = get_input_qparams(node)
5148
assert (
5249
len(input_qparams) == 2
5350
), f"Both inputs needs to have quantization information for {node}"

backends/arm/operators/op_mul.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import serializer.tosa_serializer as ts # type: ignore
1414
import torch
1515

16-
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
1716
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1817
get_input_qparams,
1918
)
@@ -52,7 +51,7 @@ def define_node(
5251
)
5352
input_A = inputs[0]
5453
input_B = inputs[1]
55-
input_qparams = get_input_qparams(node) # pyre-ignore[16]
54+
input_qparams = get_input_qparams(node)
5655
input_A_qargs = input_qparams[0]
5756
input_B_qargs = input_qparams[1]
5857
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Any, Callable, Dict, List, Optional
1818

1919
import torch
20-
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
20+
from executorch.backends.arm._passes import ArmPassManager
2121

2222
from executorch.backends.arm.quantizer import arm_quantizer_utils
2323
from executorch.backends.arm.quantizer.arm_quantizer_utils import ( # type: ignore[attr-defined]

backends/arm/test/passes/test_fold_qdq_pass.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
from typing import Tuple
77

88
import torch
9-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
10-
FoldAndAnnotateQParamsPass,
11-
)
9+
from executorch.backends.arm._passes import FoldAndAnnotateQParamsPass
1210
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1311

1412

backends/arm/test/passes/test_meandim_to_averagepool2d.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
from typing import Tuple
88

99
import torch
10-
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
11-
ConvertMeanDimToAveragePoolPass,
12-
)
10+
from executorch.backends.arm._passes import ConvertMeanDimToAveragePoolPass
1311
from executorch.backends.arm.test import common
1412
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1513

backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
from typing import Dict, Tuple
77

88
import torch
9-
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
10-
UnsqueezeBeforeRepeatPass,
11-
)
9+
from executorch.backends.arm._passes import UnsqueezeBeforeRepeatPass
1210
from executorch.backends.arm.test import common
1311
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1412

backends/arm/tosa_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from executorch.backends.arm.arm_backend import get_tosa_spec
2020
from executorch.backends.arm.operators.node_visitor import get_node_visitors
21-
from executorch.backends.arm._passes.arm_pass_manager import (
21+
from executorch.backends.arm._passes import (
2222
ArmPassManager,
2323
) # usort: skip
2424
from executorch.backends.arm.process_node import (

backends/arm/tosa_quant_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def insert_rescale_ops_to_int32(
4242
in the node meta dict.
4343
"""
4444

45-
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
4645
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
4746
get_input_qparams,
4847
)
@@ -54,7 +53,7 @@ def insert_rescale_ops_to_int32(
5453
dim_order = tensor.dim_order
5554
tensor.shape = [tensor.shape[i] for i in dim_order]
5655

57-
input_qparams = get_input_qparams(node) # pyre-ignore[16]
56+
input_qparams = get_input_qparams(node)
5857
qargs = input_qparams.values()
5958

6059
# Scale the int8 quantized input to a common scale in the integer
@@ -92,12 +91,11 @@ def insert_rescale_op_to_int8(
9291
handled by the DQ/D folding pass, which stores the quantization parameters
9392
in the node meta dict.
9493
"""
95-
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
9694
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
9795
get_output_qparams,
9896
)
9997

100-
output_qparams = get_output_qparams(node) # pyre-ignore[16]
98+
output_qparams = get_output_qparams(node)
10199
assert len(output_qparams) == 1, "More than one output not supported"
102100

103101
qargs_out = output_qparams[0]

0 commit comments

Comments
 (0)