Skip to content

Commit ec3e471

Browse files
billmguo authored and facebook-github-bot committed
support mimi model export
Summary: Compare the XNNPACK and eager model outputs. Reviewed By: iseeyuan. Differential Revision: D71634148
1 parent 0dd7e4e commit ec3e471

File tree

1 file changed

+68
-4
lines changed

1 file changed

+68
-4
lines changed

examples/models/moshi/mimi/test_mimi.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,26 @@
1313
get_symmetric_quantization_config,
1414
XNNPACKQuantizer,
1515
)
16+
from executorch.devtools.backend_debug import print_delegation_info
1617
from executorch.exir import to_edge_transform_and_lower
18+
from executorch.runtime import Runtime
1719

1820
from huggingface_hub import hf_hub_download
1921
from moshi.models import loaders
20-
from torch.ao.quantization.quantize_pt2e import (
21-
convert_pt2e,
22-
prepare_pt2e,
23-
)
22+
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
2423
from torch.export import export, ExportedProgram
24+
from torch.utils._pytree import tree_flatten
25+
26+
27+
def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
    """Return the signal-to-quantization-noise ratio of ``y`` against ``x`` in dB.

    ``x`` is treated as the reference signal and ``x - y`` as the noise. The
    result is ``10 * log10(mean(x**2) / mean((x - y)**2))``, with both inputs
    converted to float32 first.

    Args:
        x: Reference tensor.
        y: Tensor to compare against the reference; must have the same shape.

    Returns:
        The SQNR in decibels as a Python float. ``inf`` is returned whenever
        the error power is exactly zero (the tensors are identical), including
        the case where the reference itself is all zeros.

    Raises:
        AssertionError: If ``x`` and ``y`` do not have the same shape.
    """
    assert x.shape == y.shape, "Tensor shapes do not match"
    x = x.float()
    y = y.float()
    error = x - y
    original_power = torch.mean(torch.pow(x, 2))
    error_power = torch.mean(torch.pow(error, 2))
    if error_power.item() == 0.0:
        # No noise at all. Without this guard an all-zero reference would
        # evaluate 0 / 0 and report NaN instead of "infinitely good".
        return float("inf")
    sqnr = 10 * torch.log10(original_power / error_power)
    return sqnr.item()
2536

2637

2738
def read_mp3_from_url(url):
@@ -189,6 +200,59 @@ def forward(self, x):
189200
ep_encode_output = exported_encode.module()(chunk)
190201
self.assertTrue(torch.allclose(ep_encode_output, ref_encode_output, atol=1e-6))
191202

203+
def test_exported_decoder_xnnpack(self):
    """Quantize and lower the Mimi decoder to XNNPACK, then compare with eager.

    Steps:
      1. Encode one frame of the sample PCM to get the decoder input.
      2. Export the decoder, run the PT2E prepare/convert flow with a
         per-channel, dynamic, symmetric XNNPACK quantization config.
      3. Re-export the quantized module, lower it with the XNNPACK
         partitioner, and serialize it to a ``.pte`` file.
      4. Load the ``.pte`` through the ExecuTorch Python runtime, execute
         ``forward`` and compare against the eager decoder output.
    """
    # Local imports keep this method self-contained; the artifact goes to
    # a private temporary directory instead of a fixed path under /tmp so
    # concurrent runs cannot collide and nothing is left behind.
    import os
    import tempfile

    class MimiDecode(nn.Module):
        """Thin wrapper exposing ``mimi.decode`` as ``forward`` for export."""

        def __init__(self, mimi: nn.Module):
            super().__init__()
            self.mimi_model = mimi

        def forward(self, x):
            return self.mimi_model.decode(x)

    sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None]
    pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
    chunk = sample_pcm[..., 0:pcm_chunk_size]
    # Encoder output for a single chunk is used as the decoder's example input.
    codes = self.mimi.encode(chunk)

    mimi_decode = MimiDecode(self.mimi)
    exported_decode: ExportedProgram = export(mimi_decode, (codes,), strict=False)
    quantization_config = get_symmetric_quantization_config(
        is_per_channel=True,
        is_dynamic=True,
    )
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(quantization_config)
    m = exported_decode.module()
    m = prepare_pt2e(m, quantizer)
    # Run the prepared module once on the example input before converting.
    m(codes)
    m = convert_pt2e(m)
    print("quantized graph:")
    print(m.graph)
    # Export quantized module
    exported_decode: ExportedProgram = export(m, (codes,), strict=False)
    # Lower
    edge_manager = to_edge_transform_and_lower(
        exported_decode,
        partitioner=[XnnpackPartitioner()],
    )
    print("delegate graph:")
    print_delegation_info(edge_manager.exported_program().graph_module)
    exec_prog = edge_manager.to_executorch()

    with tempfile.TemporaryDirectory() as tmp_dir:
        output_file = os.path.join(tmp_dir, "mimi_decode.pte")
        with open(output_file, "wb") as file:
            exec_prog.write_to_file(file)

        eager_res = mimi_decode(codes)
        runtime = Runtime.get()
        program = runtime.load_program(output_file)
        method = program.load_method("forward")
        flattened_x = tree_flatten(codes)[0]
        res = method.execute(flattened_x)
        # Compare results while the serialized program is still on disk,
        # in case the runtime keeps referring to the file.
        sqnr = compute_sqnr(eager_res, res[0])
        print(f"SQNR: {sqnr}")
        torch.testing.assert_close(eager_res, res[0], atol=1e-3, rtol=1e-3)
255+
192256

193257
# Run the unittest suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)