Skip to content

Commit 6a4168f

Browse files
authored
Test mimi: remove redundant codes
Differential Revision: D71698626 Pull Request resolved: #9528
1 parent da7b003 commit 6a4168f

File tree

1 file changed

+0
-50
lines changed

1 file changed

+0
-50
lines changed

examples/models/moshi/mimi/test_mimi.py

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -147,56 +147,6 @@ def test_streaming_encoding_decoding(self):
147147
pcm_ref = self.mimi.decode(all_codes_th)
148148
self.assertTrue(torch.allclose(pcm_ref, all_pcms, atol=1e-5))
149149

150-
def test_exported_decoding(self):
151-
"""Ensure exported decoding model is consistent with reference output."""
152-
153-
class MimiDecode(nn.Module):
154-
def __init__(self, mimi: nn.Module):
155-
super().__init__()
156-
self.mimi_model = mimi
157-
158-
def forward(self, x):
159-
return self.mimi_model.decode(x)
160-
161-
sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None]
162-
pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
163-
chunk = sample_pcm[..., 0:pcm_chunk_size]
164-
input = self.mimi.encode(chunk)
165-
166-
mimi_decode = MimiDecode(self.mimi)
167-
ref_decode_output = mimi_decode(input)
168-
exported_decode: ExportedProgram = export(mimi_decode, (input,), strict=False)
169-
ep_decode_output = exported_decode.module()(input)
170-
self.assertTrue(torch.allclose(ep_decode_output, ref_decode_output, atol=1e-6))
171-
172-
# PT2E Quantization
173-
quantizer = XNNPACKQuantizer()
174-
# 8 bit by default
175-
quantization_config = get_symmetric_quantization_config(
176-
is_per_channel=True,
177-
is_dynamic=True,
178-
)
179-
quantizer.set_global(quantization_config)
180-
m = exported_decode.module()
181-
m = prepare_pt2e(m, quantizer)
182-
m(input)
183-
m = convert_pt2e(m)
184-
print("quantized graph:")
185-
print(m.graph)
186-
# Export quantized module
187-
exported_decode: ExportedProgram = export(m, (input,), strict=False)
188-
189-
# Lower
190-
edge_manager = to_edge_transform_and_lower(
191-
exported_decode,
192-
partitioner=[XnnpackPartitioner()],
193-
)
194-
195-
exec_prog = edge_manager.to_executorch()
196-
print("exec graph:")
197-
print(exec_prog.exported_program().graph)
198-
assert len(exec_prog.exported_program().graph.nodes) > 1
199-
200150
def test_exported_encoding(self):
201151
"""Ensure exported encoding model is consistent with reference output."""
202152

0 commit comments

Comments
 (0)