@@ -147,56 +147,6 @@ def test_streaming_encoding_decoding(self):
147
147
pcm_ref = self .mimi .decode (all_codes_th )
148
148
self .assertTrue (torch .allclose (pcm_ref , all_pcms , atol = 1e-5 ))
149
149
150
- def test_exported_decoding (self ):
151
- """Ensure exported decoding model is consistent with reference output."""
152
-
153
- class MimiDecode (nn .Module ):
154
- def __init__ (self , mimi : nn .Module ):
155
- super ().__init__ ()
156
- self .mimi_model = mimi
157
-
158
- def forward (self , x ):
159
- return self .mimi_model .decode (x )
160
-
161
- sample_pcm = torch .tensor (self .sample_pcm , device = self .device )[None ]
162
- pcm_chunk_size = int (self .mimi .sample_rate / self .mimi .frame_rate )
163
- chunk = sample_pcm [..., 0 :pcm_chunk_size ]
164
- input = self .mimi .encode (chunk )
165
-
166
- mimi_decode = MimiDecode (self .mimi )
167
- ref_decode_output = mimi_decode (input )
168
- exported_decode : ExportedProgram = export (mimi_decode , (input ,), strict = False )
169
- ep_decode_output = exported_decode .module ()(input )
170
- self .assertTrue (torch .allclose (ep_decode_output , ref_decode_output , atol = 1e-6 ))
171
-
172
- # PT2E Quantization
173
- quantizer = XNNPACKQuantizer ()
174
- # 8 bit by default
175
- quantization_config = get_symmetric_quantization_config (
176
- is_per_channel = True ,
177
- is_dynamic = True ,
178
- )
179
- quantizer .set_global (quantization_config )
180
- m = exported_decode .module ()
181
- m = prepare_pt2e (m , quantizer )
182
- m (input )
183
- m = convert_pt2e (m )
184
- print ("quantized graph:" )
185
- print (m .graph )
186
- # Export quantized module
187
- exported_decode : ExportedProgram = export (m , (input ,), strict = False )
188
-
189
- # Lower
190
- edge_manager = to_edge_transform_and_lower (
191
- exported_decode ,
192
- partitioner = [XnnpackPartitioner ()],
193
- )
194
-
195
- exec_prog = edge_manager .to_executorch ()
196
- print ("exec graph:" )
197
- print (exec_prog .exported_program ().graph )
198
- assert len (exec_prog .exported_program ().graph .nodes ) > 1
199
-
200
150
def test_exported_encoding (self ):
201
151
"""Ensure exported encoding model is consistent with reference output."""
202
152
0 commit comments