
Commit 55f218c

mcremon-meta authored and facebook-github-bot committed
Extract trace from prepare_and_convert and remove export_program
Summary: As titled. Will be used in later changes to fix some inconsistencies.

Differential Revision: D73440517
1 parent 1bd7260 commit 55f218c

2 files changed (+58 -42 lines)

backends/cadence/aot/compiler.py

+53 -41
@@ -39,35 +39,31 @@
 from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
-from torch.export import export
 from torch.export.exported_program import ExportedProgram
 
 from .passes import get_cadence_passes
 
 from .utils import print_ops_info
 
 
-def prepare_and_convert_pt2(
+def trace(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
-    quantizer: CadenceQuantizer,
-    calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> torch.fx.GraphModule:
+) -> ExportedProgram:
     """
-    Prepare and convert a model using the given quantizer.
-    The quantizer must be supplied and be the same as the one used to
-    fuse the model later, if applicable. If you do not expect that behavior,
-    please use quantize_and_fuse_pt2 instead, which will instantiate a
-    default quantizer for you if needed.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the converted model.
+    Trace the model with export_for_training and return an ExportedProgram.
     """
 
+    # Make the model inference mode by calling model.eval()
+    model.eval()
+
+    # Prevent mkldnn decompositions
+    torch._C._set_mkldnn_enabled(False)
+
     # Get default decompositions
     decomp_table = torch.export.default_decompositions()
+
     # Select ops to keep
     ops_to_keep = [
         torch.ops.aten.conv1d.default,
@@ -77,19 +73,47 @@ def prepare_and_convert_pt2(
         torch.ops.aten.matmul.default,
         torch.ops.aten.rms_norm.default,
     ]
+
     # Remove decompositions for the ops we want to keep
     # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
     remove_decompositions(decomp_table, ops_to_keep)
+
     # Export with dynamo
-    model_gm = (
+    ep = (
         torch.export.export_for_training(model, inputs, strict=True)
         .run_decompositions(decomp_table)
-        .module()
     )
 
     if dump_graphs:
         logging.info("Graph before quantization:")
-        logging.info(model_gm.graph.print_tabular())
+        logging.info(ep.module().graph.print_tabular())
+
+    return ep
+
+
+def prepare_and_convert_pt2(
+    ep: ExportedProgram,
+    inputs: tuple[object, ...],
+    quantizer: CadenceQuantizer,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
+    dump_graphs: bool = False,
+) -> torch.fx.GraphModule:
+    """
+    Prepare and convert a model using the given quantizer.
+    The quantizer must be supplied and be the same as the one used to
+    fuse the model later, if applicable. If you do not expect that behavior,
+    please use quantize_and_fuse_pt2 instead, which will instantiate a
+    default quantizer for you if needed.
+    If calibration data is provided, it will be used to calibrate the model. If
+    not, the inputs will be used for calibration instead, which is useful for
+    unit tests but should not be used for end-to-end use cases.
+    Returns a GraphModule with the converted model.
+    """
+
+    # Get the graph module from the ExportedProgram
+    model_gm = ep.module()
+
+    assert isinstance(model_gm, torch.fx.GraphModule)
 
     # Prepare
     prepared_model = prepare_pt2e(model_gm, quantizer)
@@ -113,10 +137,10 @@ def prepare_and_convert_pt2(
 
 
 # Note: this is not meant as a primary API since it can create inconsistencies
-# if the quantizer here is different from the quantizer used to convert. It is
-# however useful for unit tests to separate the converted model from the fused
-# model, to be able to get reference numerics.
-# If this does not apply, please use quantize_and_fuse_pt2 instead.
+# if the quantizer here is different from the quantizer used to prepare/convert.
+# It is however useful for unit tests to separate the converted model from the
+# fused model, to be able to get reference numerics.
+# If this does not apply, please use quantize_pt2 instead.
 def fuse_pt2(
     converted_graph_module: torch.fx.GraphModule,
     quantizer: CadenceQuantizer,
@@ -151,16 +175,20 @@ def quantize_pt2(
     unit tests but should not be used for end-to-end use cases.
     Returns a GraphModule with the quantized model.
     """
-    # Make the model inference mode by calling model.eval()
-    model.eval()
 
     # Instantiate the quantizer to CadenceQuantizer if not supplied
     if not quantizer:
         quantizer = CadenceDefaultQuantizer()
 
+    ep = trace(model, inputs, dump_graphs=dump_graphs)
+
+    if dump_graphs:
+        logging.info("Graph after trace:")
+        logging.info(ep.graph.print_tabular())
+
     # Get converted graph module
     converted_gm = prepare_and_convert_pt2(
-        model, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
+        ep, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
    )
 
     # Get fused model
@@ -173,22 +201,6 @@ def quantize_pt2(
     return fused_gm
 
 
-# Export the model and lower it to an ExportedProgram (in aten IR)
-def export_program(
-    model: torch.nn.Module,
-    inputs: tuple[object, ...],
-) -> ExportedProgram:
-    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
-
-    # Prevent mkldnn decompositions
-    torch._C._set_mkldnn_enabled(False)
-
-    # Export the model and return it.
-    expo_program = export(model, inputs, strict=True)
-
-    return expo_program
-
-
 def lower_ep_to_edge(
     expo_program: ExportedProgram,
     dump_graphs: bool = False,
@@ -237,7 +249,7 @@ def export_to_edge(
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
     # Export the model into an ExportedProgram.
-    expo_program = export_program(model, inputs)
+    expo_program = trace(model, inputs)
 
     # Lower the model to edge IR.
     edge_prog_manager = lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
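
For orientation, the diff above splits the old prepare_and_convert_pt2 entry point into a trace step followed by prepare/convert, and the non-quantized path now reuses the same trace helper in place of the removed export_program. The sketch below strings the pieces together. It is illustrative only: TinyModel and its input shapes are made up, and it assumes the helpers are importable from executorch.backends.cadence.aot.compiler (mirroring the quantizer import used in export_example.py).

import torch

from executorch.backends.cadence.aot.compiler import (
    fuse_pt2,
    prepare_and_convert_pt2,
    trace,
)
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer


# Placeholder model and example inputs for the sketch.
class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)


model = TinyModel()
example_inputs = (torch.randn(1, 8),)

# Quantized path (roughly what quantize_pt2 does after this change):
quantizer = CadenceDefaultQuantizer()
ep = trace(model, example_inputs)                                      # ExportedProgram with selected aten ops kept
converted_gm = prepare_and_convert_pt2(ep, example_inputs, quantizer)  # prepare_pt2e/convert_pt2e on ep.module()
fused_gm = fuse_pt2(converted_gm, quantizer)                           # must reuse the same quantizer

# The non-quantized path is analogous: export_to_edge now calls
# trace(model, inputs) in place of the removed export_program and then
# lowers the resulting ExportedProgram with lower_ep_to_edge.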

backends/cadence/aot/export_example.py

+5 -1
@@ -18,6 +18,7 @@
     export_to_executorch_gen_etrecord,
     fuse_pt2,
     prepare_and_convert_pt2,
+    trace,
 )
 
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
@@ -48,8 +49,11 @@ def export_model(
     # Instantiate the quantizer
     quantizer = CadenceDefaultQuantizer()
 
+    # Trace the model
+    ep = trace(model, example_inputs)
+
     # Convert the model
-    converted_model = prepare_and_convert_pt2(model, example_inputs, quantizer)
+    converted_model = prepare_and_convert_pt2(ep, example_inputs, quantizer)
 
     # Get reference outputs from converted model
     ref_outputs = converted_model(*example_inputs)
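
The example above calibrates with example_inputs only, which the prepare_and_convert_pt2 docstring flags as suitable for unit tests rather than end-to-end use. A minimal variant that passes explicit calibration data, reusing the names from the example (the batch count and tensor shapes below are placeholders, not values from this commit):

# Hypothetical calibration set: one tuple of inputs per calibration batch,
# matching the model's forward signature.
calibration_data = [(torch.randn(1, 8),) for _ in range(16)]

ep = trace(model, example_inputs)
converted_model = prepare_and_convert_pt2(
    ep, example_inputs, quantizer, calibration_data=calibration_data
)
ref_outputs = converted_model(*example_inputs)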
