@@ -230,10 +230,13 @@ def test_stable_diffusion_two_xl_mixture_of_denoiser(self):
230
230
pipe_2 = StableDiffusionXLImg2ImgPipeline (** components ).to (torch_device )
231
231
pipe_2 .unet .set_default_attn_processor ()
232
232
233
- def assert_run_mixture (num_steps , split , scheduler_cls ):
233
+ def assert_run_mixture (num_steps , split , scheduler_cls_orig ):
234
234
inputs = self .get_dummy_inputs (torch_device )
235
235
inputs ["num_inference_steps" ] = num_steps
236
236
237
+ class scheduler_cls (scheduler_cls_orig ):
238
+ pass
239
+
237
240
pipe_1 .scheduler = scheduler_cls .from_config (pipe_1 .scheduler .config )
238
241
pipe_2 .scheduler = scheduler_cls .from_config (pipe_2 .scheduler .config )
239
242
@@ -287,10 +290,13 @@ def test_stable_diffusion_three_xl_mixture_of_denoiser(self):
287
290
pipe_3 = StableDiffusionXLImg2ImgPipeline (** components ).to (torch_device )
288
291
pipe_3 .unet .set_default_attn_processor ()
289
292
290
- def assert_run_mixture (num_steps , split_1 , split_2 , scheduler_cls ):
293
+ def assert_run_mixture (num_steps , split_1 , split_2 , scheduler_cls_orig ):
291
294
inputs = self .get_dummy_inputs (torch_device )
292
295
inputs ["num_inference_steps" ] = num_steps
293
296
297
+ class scheduler_cls (scheduler_cls_orig ):
298
+ pass
299
+
294
300
pipe_1 .scheduler = scheduler_cls .from_config (pipe_1 .scheduler .config )
295
301
pipe_2 .scheduler = scheduler_cls .from_config (pipe_2 .scheduler .config )
296
302
pipe_3 .scheduler = scheduler_cls .from_config (pipe_3 .scheduler .config )
0 commit comments