Skip to content

Commit a558126

Browse files
committed
Move fixed logic to specific test class
1 parent 7ef29f5 commit a558126

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import gc
1717
import random
18+
import tempfile
1819
import unittest
1920

2021
import numpy as np
@@ -141,6 +142,41 @@ def get_dummy_inversion_inputs(self, device, seed=0):
141142
}
142143
return inputs
143144

145+
def test_save_load_optional_components(self):
146+
if not hasattr(self.pipeline_class, "_optional_components"):
147+
return
148+
149+
components = self.get_dummy_components()
150+
pipe = self.pipeline_class(**components)
151+
pipe.to(torch_device)
152+
pipe.set_progress_bar_config(disable=None)
153+
154+
# set all optional components to None and update pipeline config accordingly
155+
for optional_component in pipe._optional_components:
156+
setattr(pipe, optional_component, None)
157+
pipe.register_modules(**{optional_component: None for optional_component in pipe._optional_components})
158+
159+
inputs = self.get_dummy_inputs(torch_device)
160+
output = pipe(**inputs)[0]
161+
162+
with tempfile.TemporaryDirectory() as tmpdir:
163+
pipe.save_pretrained(tmpdir)
164+
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
165+
pipe_loaded.to(torch_device)
166+
pipe_loaded.set_progress_bar_config(disable=None)
167+
168+
for optional_component in pipe._optional_components:
169+
self.assertTrue(
170+
getattr(pipe_loaded, optional_component) is None,
171+
f"`{optional_component}` did not stay set to None after loading.",
172+
)
173+
174+
inputs = self.get_dummy_inputs(torch_device)
175+
output_loaded = pipe_loaded(**inputs)[0]
176+
177+
max_diff = np.abs(output - output_loaded).max()
178+
self.assertLess(max_diff, 1e-4)
179+
144180
def test_stable_diffusion_pix2pix_zero_inversion(self):
145181
device = "cpu" # ensure determinism for the device-dependent torch.Generator
146182
components = self.get_dummy_components()

0 commit comments

Comments
 (0)