|
13 | 13 | get_symmetric_quantization_config,
|
14 | 14 | XNNPACKQuantizer,
|
15 | 15 | )
|
| 16 | +from executorch.devtools.backend_debug import print_delegation_info |
16 | 17 | from executorch.exir import to_edge_transform_and_lower
|
| 18 | +from executorch.runtime import Runtime |
17 | 19 |
|
18 | 20 | from huggingface_hub import hf_hub_download
|
19 | 21 | from moshi.models import loaders
|
20 |
| -from torch.ao.quantization.quantize_pt2e import ( |
21 |
| - convert_pt2e, |
22 |
| - prepare_pt2e, |
23 |
| -) |
| 22 | +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e |
24 | 23 | from torch.export import export, ExportedProgram
|
| 24 | +from torch.utils._pytree import tree_flatten |
| 25 | + |
| 26 | +os.environ["https_proxy"] = "http://fwdproxy:8080" |
| 27 | + |
| 28 | + |
| 29 | +def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float: |
| 30 | + assert x.shape == y.shape, "Tensor shapes do not match" |
| 31 | + x = x.float() |
| 32 | + y = y.float() |
| 33 | + error = x - y |
| 34 | + original_power = torch.mean(torch.pow(x, 2)) |
| 35 | + error_power = torch.mean(torch.pow(error, 2)) |
| 36 | + sqnr = 10 * torch.log10(original_power / error_power) |
| 37 | + return sqnr.item() |
25 | 38 |
|
26 | 39 |
|
27 | 40 | def read_mp3_from_url(url):
|
@@ -189,6 +202,59 @@ def forward(self, x):
|
189 | 202 | ep_encode_output = exported_encode.module()(chunk)
|
190 | 203 | self.assertTrue(torch.allclose(ep_encode_output, ref_encode_output, atol=1e-6))
|
191 | 204 |
|
| 205 | + def test_exported_decoder_xnnpack(self): |
| 206 | + class MimiDecode(nn.Module): |
| 207 | + def __init__(self, mimi: nn.Module): |
| 208 | + super().__init__() |
| 209 | + self.mimi_model = mimi |
| 210 | + |
| 211 | + def forward(self, x): |
| 212 | + return self.mimi_model.decode(x) |
| 213 | + |
| 214 | + sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None] |
| 215 | + pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate) |
| 216 | + chunk = sample_pcm[..., 0:pcm_chunk_size] |
| 217 | + input = self.mimi.encode(chunk) |
| 218 | + |
| 219 | + mimi_decode = MimiDecode(self.mimi) |
| 220 | + exported_decode: ExportedProgram = export(mimi_decode, (input,), strict=False) |
| 221 | + quantization_config = get_symmetric_quantization_config( |
| 222 | + is_per_channel=True, |
| 223 | + is_dynamic=True, |
| 224 | + ) |
| 225 | + quantizer = XNNPACKQuantizer() |
| 226 | + quantizer.set_global(quantization_config) |
| 227 | + m = exported_decode.module() |
| 228 | + m = prepare_pt2e(m, quantizer) |
| 229 | + m(input) |
| 230 | + m = convert_pt2e(m) |
| 231 | + print("quantized graph:") |
| 232 | + print(m.graph) |
| 233 | + # Export quantized module |
| 234 | + exported_decode: ExportedProgram = export(m, (input,), strict=False) |
| 235 | + # Lower |
| 236 | + edge_manager = to_edge_transform_and_lower( |
| 237 | + exported_decode, |
| 238 | + partitioner=[XnnpackPartitioner()], |
| 239 | + ) |
| 240 | + print("delegate graph:") |
| 241 | + print_delegation_info(edge_manager.exported_program().graph_module) |
| 242 | + exec_prog = edge_manager.to_executorch() |
| 243 | + output_file = "/tmp/mimi_decode.pte" |
| 244 | + with open(output_file, "wb") as file: |
| 245 | + exec_prog.write_to_file(file) |
| 246 | + |
| 247 | + eager_res = mimi_decode(input) |
| 248 | + runtime = Runtime.get() |
| 249 | + program = runtime.load_program(output_file) |
| 250 | + method = program.load_method("forward") |
| 251 | + flattened_x = tree_flatten(input)[0] |
| 252 | + res = method.execute(flattened_x) |
| 253 | + # Compare results |
| 254 | + sqnr = compute_sqnr(eager_res, res[0]) |
| 255 | + print(f"SQNR: {sqnr}") |
| 256 | + torch.testing.assert_close(eager_res, res[0], atol=1e-3, rtol=1e-3) |
| 257 | + |
192 | 258 |
|
193 | 259 | if __name__ == "__main__":
|
194 | 260 | unittest.main()
|
0 commit comments