|
15 | 15 |
|
16 | 16 | import gc
|
17 | 17 | import random
|
| 18 | +import tempfile |
18 | 19 | import unittest
|
19 | 20 |
|
20 | 21 | import numpy as np
|
@@ -141,6 +142,41 @@ def get_dummy_inversion_inputs(self, device, seed=0):
|
141 | 142 | }
|
142 | 143 | return inputs
|
143 | 144 |
|
| 145 | + def test_save_load_optional_components(self): |
| 146 | + if not hasattr(self.pipeline_class, "_optional_components"): |
| 147 | + return |
| 148 | + |
| 149 | + components = self.get_dummy_components() |
| 150 | + pipe = self.pipeline_class(**components) |
| 151 | + pipe.to(torch_device) |
| 152 | + pipe.set_progress_bar_config(disable=None) |
| 153 | + |
| 154 | + # set all optional components to None and update pipeline config accordingly |
| 155 | + for optional_component in pipe._optional_components: |
| 156 | + setattr(pipe, optional_component, None) |
| 157 | + pipe.register_modules(**{optional_component: None for optional_component in pipe._optional_components}) |
| 158 | + |
| 159 | + inputs = self.get_dummy_inputs(torch_device) |
| 160 | + output = pipe(**inputs)[0] |
| 161 | + |
| 162 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 163 | + pipe.save_pretrained(tmpdir) |
| 164 | + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) |
| 165 | + pipe_loaded.to(torch_device) |
| 166 | + pipe_loaded.set_progress_bar_config(disable=None) |
| 167 | + |
| 168 | + for optional_component in pipe._optional_components: |
| 169 | + self.assertTrue( |
| 170 | + getattr(pipe_loaded, optional_component) is None, |
| 171 | + f"`{optional_component}` did not stay set to None after loading.", |
| 172 | + ) |
| 173 | + |
| 174 | + inputs = self.get_dummy_inputs(torch_device) |
| 175 | + output_loaded = pipe_loaded(**inputs)[0] |
| 176 | + |
| 177 | + max_diff = np.abs(output - output_loaded).max() |
| 178 | + self.assertLess(max_diff, 1e-4) |
| 179 | + |
144 | 180 | def test_stable_diffusion_pix2pix_zero_inversion(self):
|
145 | 181 | device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
146 | 182 | components = self.get_dummy_components()
|
|
0 commit comments