@@ -28,6 +28,7 @@
     to_edge,
 )
 from executorch.exir.pass_base import PassResult
+from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
@@ -58,16 +59,33 @@ def convert_pt2(
     Returns a GraphModule with the converted model.
     """
 
+    # Get default decompositions
+    decomp_table = torch.export.default_decompositions()
+    # Select ops to keep
+    ops_to_keep = [
+        torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv2d.default,
+        torch.ops.aten.layer_norm.default,
+        torch.ops.aten.linear.default,
+        torch.ops.aten.matmul.default,
+    ]
+    # Remove decompositions for the ops we want to keep
+    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
+    remove_decompositions(decomp_table, ops_to_keep)
     # Export with dynamo
-    model_gm = torch.export.export_for_training(model, inputs).module()
+    model_gm = (
+        torch.export.export_for_training(model, inputs)
+        .run_decompositions(decomp_table)
+        .module()
+    )
 
-    if model_gm_has_SDPA(model_gm):  # pyre-fixme[6]
+    if model_gm_has_SDPA(model_gm):
         # Decompose SDPA
-        DecomposeScaledDotProductAttention(False)(model_gm)  # pyre-fixme[6]
+        DecomposeScaledDotProductAttention(False)(model_gm)
 
         # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
         # for details).
-        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)  # pyre-fixme[6]
+        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
         assert result is not None
         model_gm = result.graph_module
 