Skip to content

Commit 0a12e33

Browse files
authored
Run decompositions before the quantizer
Differential Revision: D66461406 Pull Request resolved: #7111
1 parent 2326fff commit 0a12e33

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

backends/cadence/aot/compiler.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
to_edge,
2929
)
3030
from executorch.exir.pass_base import PassResult
31+
from torch._inductor.decomposition import remove_decompositions
3132
from torch.ao.quantization.pt2e.export_utils import model_is_exported
3233
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
3334

@@ -58,16 +59,33 @@ def convert_pt2(
5859
Returns a GraphModule with the converted model.
5960
"""
6061

62+
# Get default decompositions
63+
decomp_table = torch.export.default_decompositions()
64+
# Select ops to keep
65+
ops_to_keep = [
66+
torch.ops.aten.conv1d.default,
67+
torch.ops.aten.conv2d.default,
68+
torch.ops.aten.layer_norm.default,
69+
torch.ops.aten.linear.default,
70+
torch.ops.aten.matmul.default,
71+
]
72+
# Remove decompositions for the ops we want to keep
73+
# pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
74+
remove_decompositions(decomp_table, ops_to_keep)
6175
# Export with dynamo
62-
model_gm = torch.export.export_for_training(model, inputs).module()
76+
model_gm = (
77+
torch.export.export_for_training(model, inputs)
78+
.run_decompositions(decomp_table)
79+
.module()
80+
)
6381

64-
if model_gm_has_SDPA(model_gm): # pyre-fixme[6]
82+
if model_gm_has_SDPA(model_gm):
6583
# Decompose SDPA
66-
DecomposeScaledDotProductAttention(False)(model_gm) # pyre-fixme[6]
84+
DecomposeScaledDotProductAttention(False)(model_gm)
6785

6886
# Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
6987
# for details).
70-
result = ReplaceSafeSoftmaxWithSoftmax()(model_gm) # pyre-fixme[6]
88+
result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
7189
assert result is not None
7290
model_gm = result.graph_module
7391

0 commit comments

Comments
 (0)