Skip to content

Commit f65a5f8

Browse files
Martin Yuanfacebook-github-bot
Martin Yuan
authored andcommitted
Test mimi: remove redundant codes
Summary: Remove redundant code. Differential Revision: D71698626
1 parent 20abf34 commit f65a5f8

File tree

1 file changed

+0
-50
lines changed

1 file changed

+0
-50
lines changed

examples/models/moshi/mimi/test_mimi.py

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -132,56 +132,6 @@ def test_streaming_encoding_decoding(self):
132132
pcm_ref = self.mimi.decode(all_codes_th)
133133
self.assertTrue(torch.allclose(pcm_ref, all_pcms, atol=1e-5))
134134

135-
def test_exported_decoding(self):
136-
"""Ensure exported decoding model is consistent with reference output."""
137-
138-
class MimiDecode(nn.Module):
139-
def __init__(self, mimi: nn.Module):
140-
super().__init__()
141-
self.mimi_model = mimi
142-
143-
def forward(self, x):
144-
return self.mimi_model.decode(x)
145-
146-
sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None]
147-
pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
148-
chunk = sample_pcm[..., 0:pcm_chunk_size]
149-
input = self.mimi.encode(chunk)
150-
151-
mimi_decode = MimiDecode(self.mimi)
152-
ref_decode_output = mimi_decode(input)
153-
exported_decode: ExportedProgram = export(mimi_decode, (input,), strict=False)
154-
ep_decode_output = exported_decode.module()(input)
155-
self.assertTrue(torch.allclose(ep_decode_output, ref_decode_output, atol=1e-6))
156-
157-
# PT2E Quantization
158-
quantizer = XNNPACKQuantizer()
159-
# 8 bit by default
160-
quantization_config = get_symmetric_quantization_config(
161-
is_per_channel=True,
162-
is_dynamic=True,
163-
)
164-
quantizer.set_global(quantization_config)
165-
m = exported_decode.module()
166-
m = prepare_pt2e(m, quantizer)
167-
m(input)
168-
m = convert_pt2e(m)
169-
print("quantized graph:")
170-
print(m.graph)
171-
# Export quantized module
172-
exported_decode: ExportedProgram = export(m, (input,), strict=False)
173-
174-
# Lower
175-
edge_manager = to_edge_transform_and_lower(
176-
exported_decode,
177-
partitioner=[XnnpackPartitioner()],
178-
)
179-
180-
exec_prog = edge_manager.to_executorch()
181-
print("exec graph:")
182-
print(exec_prog.exported_program().graph)
183-
assert len(exec_prog.exported_program().graph.nodes) > 1
184-
185135
def test_exported_encoding(self):
186136
"""Ensure exported encoding model is consistent with reference output."""
187137

0 commit comments

Comments
 (0)