Commit d5c4ba7

Qualcomm AI Engine Direct - Add rewrite function for observer

Differential Revision: D75420336
Pull Request resolved: #10093

1 parent 95a1db5

File tree

3 files changed: +90 −1 lines changed


backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
@@ -1269,6 +1269,14 @@ def forward(self, x):
         return x.repeat(1, 2, 3, 4)
 
 
+class ReWriteObs(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.nn.functional.relu(x).expand(3, 4)
+
+
 class Reshape(torch.nn.Module):
     def __init__(self):
         super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 31 additions & 0 deletions
@@ -49,6 +49,7 @@
     generate_qnn_executorch_compiler_spec,
     PyQnnManagerAdaptor,
     QnnPartitioner,
+    rewrite_prepared_observer,
     skip_annotation,
     to_edge_transform_and_lower_to_qnn,
     update_spill_fill_size,
@@ -3058,6 +3059,36 @@ def test_qnn_backend_dynamic_shape(self):
             check_io_shape=True,
         )
 
+    def test_qnn_backend_rewrite_prepared_observer(self):
+        from torchao.quantization.pt2e import FixedQParamsObserver
+
+        module = ReWriteObs()  # noqa: F405
+        sample_input = (torch.randn([3, 1]),)
+        module = torch.export.export(module, sample_input, strict=True).module()
+
+        quantizer = make_quantizer()
+
+        prepared = prepare_pt2e(module, quantizer)
+        prepared(*sample_input)
+
+        new_obs = FixedQParamsObserver(
+            scale=0.004,
+            zero_point=0,
+            dtype=torch.uint8,
+            quant_min=0,
+            quant_max=255,
+            qscheme=torch.per_tensor_affine,
+        )
+
+        rewrite_prepared_observer(prepared, {"activation_post_process_2": new_obs})
+        self.assertTrue(
+            prepared.activation_post_process_1
+            == prepared.activation_post_process_2
+            == new_obs
+        )
+        quantized_module = convert_pt2e(prepared)
+        self.lower_module_and_test_output(quantized_module, sample_input)
+
     def test_qnn_backend_skip_node_id_partitioner(self):
         module = SimpleModel()  # noqa: F405
         sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
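
A note on the assertion above: only activation_post_process_2 is passed to rewrite_prepared_observer, yet activation_post_process_1 is expected to change as well. That only holds if the quantizer had bound both attribute names to one shared observer instance (see the SharedQuantizationSpec note in the utility's docstring below). A minimal sketch of that behavior, reusing the prepared and new_obs names from the test, under that assumption:

    # Both names resolve to the same shared observer instance after calibration
    # (assumption: the QNN quantizer shares the relu output observer with expand).
    assert prepared.activation_post_process_1 is prepared.activation_post_process_2

    # Rewriting one name swaps the instance under every name bound to it.
    rewrite_prepared_observer(prepared, {"activation_post_process_2": new_obs})
    assert prepared.activation_post_process_1 is new_obs
    assert prepared.activation_post_process_2 is new_obs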

backends/qualcomm/utils/utils.py

Lines changed: 51 additions & 1 deletion
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import operator
 import warnings
-from collections import OrderedDict
+from collections import defaultdict, OrderedDict
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor
@@ -1038,3 +1038,53 @@ def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
     for node in gm.graph.nodes:
         if dtype := get_quant_io_dtype_fn(node):
             node.meta[QCOM_QUANTIZED_IO] = dtype
+
+
+def rewrite_prepared_observer(
+    graph_module: torch.fx.GraphModule, name_obs_dict: Dict[str, torch.nn.Module]
+):
+    """
+    Replace the observer registered under each given module name in graph_module.
+
+    Example:
+        Consider the following graph_module after prepare_pt2e:
+            gm = prepare_pt2e(gm)
+            print(gm)
+
+            GraphModule(
+                (activation_post_process_0): MinMaxObserver(min_val=inf, max_val=-inf)
+                (activation_post_process_1): MinMaxObserver(min_val=inf, max_val=-inf)
+                (activation_post_process_2): MinMaxObserver(min_val=inf, max_val=-inf)
+                (activation_post_process_3): MinMaxObserver(min_val=inf, max_val=-inf)
+            )
+
+        new_observer = observer.FixedQParamsObserver(
+            scale=0.125,
+            zero_point=42,
+            dtype=torch.uint8,
+            quant_min=0,
+            quant_max=255,
+            qscheme=torch.per_tensor_affine,
+        )
+
+        Calling rewrite_prepared_observer(gm, {"activation_post_process_0": new_observer})
+        is equivalent to:
+            gm.activation_post_process_0 = new_observer
+
+    Note:
+        If the rewritten observer is shared via SharedQuantizationSpec, all observers sharing it are rewritten as well.
+    """
+    module_name_list = defaultdict(list)
+    for name, module in graph_module.named_modules(remove_duplicate=False):
+        module_name_list[module].append(name)
+
+    for name, new_observer in name_obs_dict.items():
+        old_module = getattr(graph_module, name, None)
+
+        if not old_module:
+            print(
+                f"[WARNING] No observer named {name} found, please check the module name"
+            )
+            continue
+        for target_name in module_name_list[old_module]:
+            setattr(graph_module, target_name, new_observer)
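
Beyond the unit test, a minimal end-to-end sketch of the intended calibration workflow follows. The prepare_pt2e/convert_pt2e import path, the make_quantizer helper, and MyModel are assumptions not shown in this diff; adjust them to your ExecuTorch/torchao setup:

    import torch
    from executorch.backends.qualcomm.utils.utils import rewrite_prepared_observer
    from torchao.quantization.pt2e import FixedQParamsObserver

    # Assumed import path for the pt2e helpers; may differ across versions.
    from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

    model = MyModel().eval()                 # hypothetical eager model
    sample_input = (torch.randn(3, 1),)
    module = torch.export.export(model, sample_input, strict=True).module()

    quantizer = make_quantizer()             # QNN quantizer helper, as in the test above

    # Insert observers, then calibrate with representative data.
    prepared = prepare_pt2e(module, quantizer)
    prepared(*sample_input)

    # Pin one activation to fixed quantization parameters instead of the
    # calibrated min/max statistics.
    fixed_obs = FixedQParamsObserver(
        scale=1.0 / 255,
        zero_point=0,
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
    )
    rewrite_prepared_observer(prepared, {"activation_post_process_0": fixed_obs})

    quantized = convert_pt2e(prepared)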
