@@ -132,56 +132,6 @@ def test_streaming_encoding_decoding(self):
132
132
pcm_ref = self .mimi .decode (all_codes_th )
133
133
self .assertTrue (torch .allclose (pcm_ref , all_pcms , atol = 1e-5 ))
134
134
135
- def test_exported_decoding (self ):
136
- """Ensure exported decoding model is consistent with reference output."""
137
-
138
- class MimiDecode (nn .Module ):
139
- def __init__ (self , mimi : nn .Module ):
140
- super ().__init__ ()
141
- self .mimi_model = mimi
142
-
143
- def forward (self , x ):
144
- return self .mimi_model .decode (x )
145
-
146
- sample_pcm = torch .tensor (self .sample_pcm , device = self .device )[None ]
147
- pcm_chunk_size = int (self .mimi .sample_rate / self .mimi .frame_rate )
148
- chunk = sample_pcm [..., 0 :pcm_chunk_size ]
149
- input = self .mimi .encode (chunk )
150
-
151
- mimi_decode = MimiDecode (self .mimi )
152
- ref_decode_output = mimi_decode (input )
153
- exported_decode : ExportedProgram = export (mimi_decode , (input ,), strict = False )
154
- ep_decode_output = exported_decode .module ()(input )
155
- self .assertTrue (torch .allclose (ep_decode_output , ref_decode_output , atol = 1e-6 ))
156
-
157
- # PT2E Quantization
158
- quantizer = XNNPACKQuantizer ()
159
- # 8 bit by default
160
- quantization_config = get_symmetric_quantization_config (
161
- is_per_channel = True ,
162
- is_dynamic = True ,
163
- )
164
- quantizer .set_global (quantization_config )
165
- m = exported_decode .module ()
166
- m = prepare_pt2e (m , quantizer )
167
- m (input )
168
- m = convert_pt2e (m )
169
- print ("quantized graph:" )
170
- print (m .graph )
171
- # Export quantized module
172
- exported_decode : ExportedProgram = export (m , (input ,), strict = False )
173
-
174
- # Lower
175
- edge_manager = to_edge_transform_and_lower (
176
- exported_decode ,
177
- partitioner = [XnnpackPartitioner ()],
178
- )
179
-
180
- exec_prog = edge_manager .to_executorch ()
181
- print ("exec graph:" )
182
- print (exec_prog .exported_program ().graph )
183
- assert len (exec_prog .exported_program ().graph .nodes ) > 1
184
-
185
135
def test_exported_encoding (self ):
186
136
"""Ensure exported encoding model is consistent with reference output."""
187
137
0 commit comments