Skip to content

Commit 76ae537

Browse files
authored
support mimi model export
Differential Revision: D71634148 Pull Request resolved: #9522
1 parent de0f6f1 commit 76ae537

File tree

1 file changed

+70
-4
lines changed

1 file changed

+70
-4
lines changed

examples/models/moshi/mimi/test_mimi.py

Lines changed: 70 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,15 +13,28 @@
1313
get_symmetric_quantization_config,
1414
XNNPACKQuantizer,
1515
)
16+
from executorch.devtools.backend_debug import print_delegation_info
1617
from executorch.exir import to_edge_transform_and_lower
18+
from executorch.runtime import Runtime
1719

1820
from huggingface_hub import hf_hub_download
1921
from moshi.models import loaders
20-
from torch.ao.quantization.quantize_pt2e import (
21-
convert_pt2e,
22-
prepare_pt2e,
23-
)
22+
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
2423
from torch.export import export, ExportedProgram
24+
from torch.utils._pytree import tree_flatten
25+
26+
# Module-level side effect: every HTTPS request made by this process after
# import is routed through the proxy below.
# NOTE(review): the proxy host is hard-coded; presumably it is only reachable
# inside the original author's network and is needed for the Hugging Face
# download — confirm before running this test elsewhere.
os.environ["https_proxy"] = "http://fwdproxy:8080"
27+
28+
29+
def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
    """Return the signal-to-quantization-noise ratio of ``y`` against ``x`` in dB.

    ``x`` is treated as the reference signal and ``x - y`` as the noise. Both
    tensors are cast to float32 before any arithmetic.

    Args:
        x: Reference tensor.
        y: Tensor to compare against the reference; must have the same shape.

    Returns:
        ``10 * log10(mean(x**2) / mean((x - y)**2))`` as a Python float.

    Raises:
        AssertionError: If the two tensors have different shapes.
    """
    assert x.shape == y.shape, "Tensor shapes do not match"
    reference = x.float()
    noise = reference - y.float()
    signal_power = reference.pow(2).mean()
    noise_power = noise.pow(2).mean()
    # A zero noise power makes the ratio infinite; no special-casing is done.
    return (10 * torch.log10(signal_power / noise_power)).item()
2538

2639

2740
def read_mp3_from_url(url):
@@ -189,6 +202,59 @@ def forward(self, x):
189202
ep_encode_output = exported_encode.module()(chunk)
190203
self.assertTrue(torch.allclose(ep_encode_output, ref_encode_output, atol=1e-6))
191204

205+
def test_exported_decoder_xnnpack(self):
    """Quantize and lower the Mimi decoder, run it in ExecuTorch, compare to eager.

    Steps:
      1. Encode one PCM chunk from ``self.sample_pcm`` with ``self.mimi`` to
         obtain the decoder input.
      2. Export a thin wrapper around ``mimi.decode``, quantize it with
         ``XNNPACKQuantizer`` (symmetric, per-channel, dynamic config) and
         export the converted module again.
      3. Lower with ``XnnpackPartitioner`` and serialize to a ``.pte`` file.
      4. Load the file with the ExecuTorch ``Runtime``, execute ``forward``
         and compare against the eager decoder output.

    The SQNR is printed for information only; the pass/fail criterion is
    ``torch.testing.assert_close`` with ``atol=1e-3`` and ``rtol=1e-3``.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    class MimiDecode(nn.Module):
        """Expose ``mimi.decode`` as ``forward`` so ``export`` can trace it."""

        def __init__(self, mimi: nn.Module):
            super().__init__()
            self.mimi_model = mimi

        def forward(self, x):
            return self.mimi_model.decode(x)

    sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None]
    pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
    chunk = sample_pcm[..., 0:pcm_chunk_size]
    # Named ``codes`` rather than ``input`` to avoid shadowing the builtin.
    codes = self.mimi.encode(chunk)

    mimi_decode = MimiDecode(self.mimi)
    exported_decode: ExportedProgram = export(mimi_decode, (codes,), strict=False)
    quantization_config = get_symmetric_quantization_config(
        is_per_channel=True,
        is_dynamic=True,
    )
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(quantization_config)
    m = exported_decode.module()
    m = prepare_pt2e(m, quantizer)
    # Run the prepared module once on the sample input before converting
    # (presumably the calibration pass of the PT2E flow).
    m(codes)
    m = convert_pt2e(m)
    print("quantized graph:")
    print(m.graph)
    # Export quantized module
    exported_decode = export(m, (codes,), strict=False)
    # Lower
    edge_manager = to_edge_transform_and_lower(
        exported_decode,
        partitioner=[XnnpackPartitioner()],
    )
    print("delegate graph:")
    print_delegation_info(edge_manager.exported_program().graph_module)
    exec_prog = edge_manager.to_executorch()

    eager_res = mimi_decode(codes)

    # Serialize into a private temporary directory instead of a fixed /tmp
    # path: portable, safe under concurrent test runs, and cleaned up
    # automatically. Loading, execution and comparison stay inside the
    # ``with`` block so the file exists for as long as the program is used.
    with tempfile.TemporaryDirectory() as tmp_dir:
        output_file = os.path.join(tmp_dir, "mimi_decode.pte")
        with open(output_file, "wb") as file:
            exec_prog.write_to_file(file)

        runtime = Runtime.get()
        program = runtime.load_program(output_file)
        method = program.load_method("forward")
        flattened_x = tree_flatten(codes)[0]
        res = method.execute(flattened_x)
        # Compare results
        sqnr = compute_sqnr(eager_res, res[0])
        print(f"SQNR: {sqnr}")
        torch.testing.assert_close(eager_res, res[0], atol=1e-3, rtol=1e-3)
257+
192258

193259
if __name__ == "__main__":
194260
unittest.main()

0 commit comments

Comments
 (0)