from torch._inductor.decomposition import remove_decompositions
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

-from torch.export import export
from torch.export.exported_program import ExportedProgram

from .passes import get_cadence_passes

from .utils import print_ops_info


-def prepare_and_convert_pt2(
+def trace(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
-    quantizer: CadenceQuantizer,
-    calibration_data: Optional[list[tuple[object, ...]]] = None,
    dump_graphs: bool = False,
-) -> torch.fx.GraphModule:
+) -> ExportedProgram:
    """
-    Prepare and convert a model using the given quantizer.
-    The quantizer must be supplied and be the same as the one used to
-    fuse the model later, if applicable. If you do not expect that behavior,
-    please use quantize_and_fuse_pt2 instead, which will instantiate a
-    default quantizer for you if needed.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the converted model.
+    Trace the model with export_for_training and return an ExportedProgram.
    """

+    # Make the model inference mode by calling model.eval()
+    model.eval()
+
+    # Prevent mkldnn decompositions
+    torch._C._set_mkldnn_enabled(False)
+
    # Get default decompositions
    decomp_table = torch.export.default_decompositions()
+
    # Select ops to keep
    ops_to_keep = [
        torch.ops.aten.conv1d.default,
@@ -77,19 +73,47 @@ def prepare_and_convert_pt2(
        torch.ops.aten.matmul.default,
        torch.ops.aten.rms_norm.default,
    ]
+
    # Remove decompositions for the ops we want to keep
    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
    remove_decompositions(decomp_table, ops_to_keep)
+
    # Export with dynamo
-    model_gm = (
+    ep = (
        torch.export.export_for_training(model, inputs, strict=True)
        .run_decompositions(decomp_table)
-        .module()
    )

    if dump_graphs:
        logging.info("Graph before quantization:")
-        logging.info(model_gm.graph.print_tabular())
+        logging.info(ep.module().graph.print_tabular())
+
+    return ep
+
+
+def prepare_and_convert_pt2(
+    ep: ExportedProgram,
+    inputs: tuple[object, ...],
+    quantizer: CadenceQuantizer,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
+    dump_graphs: bool = False,
+) -> torch.fx.GraphModule:
+    """
+    Prepare and convert a model using the given quantizer.
+    The quantizer must be supplied and be the same as the one used to
+    fuse the model later, if applicable. If you do not expect that behavior,
+    please use quantize_and_fuse_pt2 instead, which will instantiate a
+    default quantizer for you if needed.
+    If calibration data is provided, it will be used to calibrate the model. If
+    not, the inputs will be used for calibration instead, which is useful for
+    unit tests but should not be used for end-to-end use cases.
+    Returns a GraphModule with the converted model.
+    """
+
+    # Get the graph module from the ExportedProgram
+    model_gm = ep.module()
+
+    assert isinstance(model_gm, torch.fx.GraphModule)

    # Prepare
    prepared_model = prepare_pt2e(model_gm, quantizer)
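
To illustrate the refactor, here is a minimal sketch of how the two new entry points compose. The toy model, example inputs, and import paths are assumptions (executorch's usual module layout), not part of this commit:

    import torch

    from executorch.backends.cadence.aot.compiler import prepare_and_convert_pt2, trace
    from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

    # Hypothetical toy model, for illustration only.
    class SimpleModel(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.nn.functional.relu(x)

    model = SimpleModel()
    inputs = (torch.randn(1, 16),)

    # Step 1: trace() now owns eval(), the mkldnn guard, and
    # export_for_training, and returns an ExportedProgram.
    ep = trace(model, inputs)

    # Step 2: prepare/convert consumes the ExportedProgram instead of an nn.Module.
    converted_gm = prepare_and_convert_pt2(ep, inputs, CadenceDefaultQuantizer())
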
@@ -113,10 +137,10 @@ def prepare_and_convert_pt2(
# Note: this is not meant as a primary API since it can create inconsistencies
-# if the quantizer here is different from the quantizer used to convert. It is
-# however useful for unit tests to separate the converted model from the fused
-# model, to be able to get reference numerics.
-# If this does not apply, please use quantize_and_fuse_pt2 instead.
+# if the quantizer here is different from the quantizer used to prepare/convert.
+# It is however useful for unit tests to separate the converted model from the
+# fused model, to be able to get reference numerics.
+# If this does not apply, please use quantize_pt2 instead.
def fuse_pt2(
    converted_graph_module: torch.fx.GraphModule,
    quantizer: CadenceQuantizer,
@@ -151,16 +175,20 @@ def quantize_pt2(
    unit tests but should not be used for end-to-end use cases.
    Returns a GraphModule with the quantized model.
    """
-    # Make the model inference mode by calling model.eval()
-    model.eval()

    # Instantiate the quantizer to CadenceQuantizer if not supplied
    if not quantizer:
        quantizer = CadenceDefaultQuantizer()

+    ep = trace(model, inputs, dump_graphs=dump_graphs)
+
+    if dump_graphs:
+        logging.info("Graph after trace:")
+        logging.info(ep.graph.print_tabular())
+
    # Get converted graph module
    converted_gm = prepare_and_convert_pt2(
-        model, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
+        ep, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
    )

    # Get fused model
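
For reference, quantize_pt2 now composes the pieces above; a sketch of the equivalent manual flow, useful when a test needs the converted model's reference numerics before fusion (placeholder model and inputs as in the earlier example; the same quantizer instance must be used for convert and fuse):

    quantizer = CadenceDefaultQuantizer()

    # What quantize_pt2 does internally: trace, then prepare/convert...
    ep = trace(model, inputs)
    converted_gm = prepare_and_convert_pt2(ep, inputs, quantizer)
    ref_outputs = converted_gm(*inputs)  # reference numerics before fusion

    # ...then fuse with the same quantizer instance.
    fused_gm = fuse_pt2(converted_gm, quantizer)
    fused_outputs = fused_gm(*inputs)  # compare against ref_outputs
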
@@ -173,22 +201,6 @@ def quantize_pt2(
    return fused_gm


-# Export the model and lower it to an ExportedProgram (in aten IR)
-def export_program(
-    model: torch.nn.Module,
-    inputs: tuple[object, ...],
-) -> ExportedProgram:
-    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
-
-    # Prevent mkldnn decompositions
-    torch._C._set_mkldnn_enabled(False)
-
-    # Export the model and return it.
-    expo_program = export(model, inputs, strict=True)
-
-    return expo_program
-
-
def lower_ep_to_edge(
    expo_program: ExportedProgram,
    dump_graphs: bool = False,
@@ -237,7 +249,7 @@ def export_to_edge(
    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

    # Export the model into an ExportedProgram.
-    expo_program = export_program(model, inputs)
+    expo_program = trace(model, inputs)

    # Lower the model to edge IR.
    edge_prog_manager = lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
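
Finally, the lowering path: export_to_edge now reuses trace() in place of the deleted export_program(), so the two-step form below is equivalent (a sketch assuming dump_graphs and constant_methods keep their defaults, with the same placeholder model and inputs as above):

    # One call...
    edge_prog_manager = export_to_edge(model, inputs)

    # ...or the same thing spelled out:
    ep = trace(model, inputs)
    edge_prog_manager = lower_ep_to_edge(ep)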