 import torch
 import torch.fx
-from executorch.backends.arm.quantizer import arm_quantizer_utils
-from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
+from executorch.backends.arm.quantizer import QuantizationConfig
 from executorch.backends.arm.tosa_utils import get_node_debug_info
 from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
 from torch.ao.quantization.quantizer.utils import (
 )
 from torch.fx import Node

+from .arm_quantizer_utils import (
+    is_annotated,
+    is_ok_for_quantization,
+    is_output_annotated,
+    mark_node_as_annotated,
+)
+
 logger = logging.getLogger(__name__)
@@ -69,7 +75,7 @@ def _is_ok_for_quantization(
     """
     # Check output
     if quant_properties.quant_output is not None:
-        if not arm_quantizer_utils.is_ok_for_quantization(node, gm):  # type: ignore[attr-defined]
+        if not is_ok_for_quantization(node, gm):  # type: ignore[attr-defined]
             logger.debug(
                 f"Could not quantize node due to output: "
                 f"{get_node_debug_info(node, gm)}"
@@ -87,7 +93,7 @@ def _is_ok_for_quantization(

         for n_arg in _as_list(node.args[quant_property.index]):
             assert isinstance(n_arg, Node)
-            if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm):  # type: ignore[attr-defined]
+            if not is_ok_for_quantization(n_arg, gm):  # type: ignore[attr-defined]
                 logger.debug(
                     f'could not quantize node due to input "{node}": '
                     f"{get_node_debug_info(node, gm)}"
@@ -99,7 +105,7 @@ def _is_ok_for_quantization(


 def _annotate_input(node: Node, quant_property: _QuantProperty):
-    assert not arm_quantizer_utils.is_annotated(node)
+    assert not is_annotated(node)
     if quant_property.optional and (
         quant_property.index >= len(node.args)
         or node.args[quant_property.index] is None
@@ -114,11 +120,11 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
         assert isinstance(n_arg, Node)
         _annotate_input_qspec_map(node, n_arg, qspec)
         if quant_property.mark_annotated:
-            arm_quantizer_utils.mark_node_as_annotated(n_arg)  # type: ignore[attr-defined]
+            mark_node_as_annotated(n_arg)  # type: ignore[attr-defined]


 def _annotate_output(node: Node, quant_property: _QuantProperty):
-    assert not arm_quantizer_utils.is_annotated(node)
+    assert not is_annotated(node)
     assert not quant_property.mark_annotated
     assert not quant_property.optional
     assert quant_property.index == 0, "Only one output annotation supported currently"
@@ -343,7 +349,7 @@ def any_or_hardtanh_min_zero(n: Node):
     elif node.target in _one_to_one_shared_input_or_input_act_qspec:
         input_qspec = (
             SharedQuantizationSpec(node.args[0])  # type: ignore[arg-type]
-            if arm_quantizer_utils.is_output_annotated(node.args[0])  # type: ignore
+            if is_output_annotated(node.args[0])  # type: ignore
             else input_act_qspec
         )
         quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)]  # type: ignore[arg-type]
@@ -396,7 +402,7 @@ def any_or_hardtanh_min_zero(n: Node):
         if not isinstance(node.args[0], Node):
             return None

-        if not arm_quantizer_utils.is_output_annotated(node.args[0]):  # type: ignore[attr-defined]
+        if not is_output_annotated(node.args[0]):  # type: ignore[attr-defined]
             return None

         shared_qspec = SharedQuantizationSpec(node.args[0])
@@ -426,7 +432,7 @@ def annotate_graph(  # type: ignore[return]
         if node.op != "call_function":
             continue

-        if arm_quantizer_utils.is_annotated(node):
+        if is_annotated(node):
             continue

         if filter_fn is not None and not filter_fn(node):
@@ -442,7 +448,7 @@ def annotate_graph(  # type: ignore[return]
         if quant_properties.quant_output is not None:
             _annotate_output(node, quant_properties.quant_output)

-        arm_quantizer_utils.mark_node_as_annotated(node)  # type: ignore[attr-defined]
+        mark_node_as_annotated(node)  # type: ignore[attr-defined]

         # Quantization does not allow kwargs for some reason.
         # Remove from ops we know have and where we know it does not break anything.