Skip to content

Commit ee2f277

Browse files
[tests] use parent class for monkey patching to not break other tests (#4088)
* [tests] use parent class for monkey patching to not break other tests * fix
1 parent 692b7a9 commit ee2f277

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,13 @@ def test_stable_diffusion_two_xl_mixture_of_denoiser(self):
230230
pipe_2 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device)
231231
pipe_2.unet.set_default_attn_processor()
232232

233-
def assert_run_mixture(num_steps, split, scheduler_cls):
233+
def assert_run_mixture(num_steps, split, scheduler_cls_orig):
234234
inputs = self.get_dummy_inputs(torch_device)
235235
inputs["num_inference_steps"] = num_steps
236236

237+
class scheduler_cls(scheduler_cls_orig):
238+
pass
239+
237240
pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
238241
pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)
239242

@@ -287,10 +290,13 @@ def test_stable_diffusion_three_xl_mixture_of_denoiser(self):
287290
pipe_3 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device)
288291
pipe_3.unet.set_default_attn_processor()
289292

290-
def assert_run_mixture(num_steps, split_1, split_2, scheduler_cls):
293+
def assert_run_mixture(num_steps, split_1, split_2, scheduler_cls_orig):
291294
inputs = self.get_dummy_inputs(torch_device)
292295
inputs["num_inference_steps"] = num_steps
293296

297+
class scheduler_cls(scheduler_cls_orig):
298+
pass
299+
294300
pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
295301
pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)
296302
pipe_3.scheduler = scheduler_cls.from_config(pipe_3.scheduler.config)

0 commit comments

Comments
 (0)