14
14
# limitations under the License.
15
15
16
16
import gc
17
+ import random
18
+ import tempfile
17
19
import unittest
18
20
19
21
import numpy as np
30
32
StableDiffusionPix2PixZeroPipeline ,
31
33
UNet2DConditionModel ,
32
34
)
33
- from diffusers .utils import load_numpy , slow , torch_device
35
+ from diffusers .utils import floats_tensor , load_numpy , slow , torch_device
34
36
from diffusers .utils .testing_utils import load_image , load_pt , require_torch_gpu , skip_mps
35
37
36
38
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS , TEXT_GUIDED_IMAGE_VARIATION_PARAMS
@@ -69,6 +71,7 @@ def get_dummy_components(self):
69
71
cross_attention_dim = 32 ,
70
72
)
71
73
scheduler = DDIMScheduler ()
74
+ inverse_scheduler = DDIMInverseScheduler ()
72
75
torch .manual_seed (0 )
73
76
vae = AutoencoderKL (
74
77
block_out_channels = [32 , 64 ],
@@ -101,7 +104,7 @@ def get_dummy_components(self):
101
104
"tokenizer" : tokenizer ,
102
105
"safety_checker" : None ,
103
106
"feature_extractor" : None ,
104
- "inverse_scheduler" : None ,
107
+ "inverse_scheduler" : inverse_scheduler ,
105
108
"caption_generator" : None ,
106
109
"caption_processor" : None ,
107
110
}
@@ -122,6 +125,90 @@ def get_dummy_inputs(self, device, seed=0):
122
125
}
123
126
return inputs
124
127
128
+ def get_dummy_inversion_inputs (self , device , seed = 0 ):
129
+ dummy_image = floats_tensor ((2 , 3 , 32 , 32 ), rng = random .Random (seed )).to (torch_device )
130
+ generator = torch .manual_seed (seed )
131
+
132
+ inputs = {
133
+ "prompt" : [
134
+ "A painting of a squirrel eating a burger" ,
135
+ "A painting of a burger eating a squirrel" ,
136
+ ],
137
+ "image" : dummy_image .cpu (),
138
+ "num_inference_steps" : 2 ,
139
+ "guidance_scale" : 6.0 ,
140
+ "generator" : generator ,
141
+ "output_type" : "numpy" ,
142
+ }
143
+ return inputs
144
+
145
+ def test_save_load_optional_components (self ):
146
+ if not hasattr (self .pipeline_class , "_optional_components" ):
147
+ return
148
+
149
+ components = self .get_dummy_components ()
150
+ pipe = self .pipeline_class (** components )
151
+ pipe .to (torch_device )
152
+ pipe .set_progress_bar_config (disable = None )
153
+
154
+ # set all optional components to None and update pipeline config accordingly
155
+ for optional_component in pipe ._optional_components :
156
+ setattr (pipe , optional_component , None )
157
+ pipe .register_modules (** {optional_component : None for optional_component in pipe ._optional_components })
158
+
159
+ inputs = self .get_dummy_inputs (torch_device )
160
+ output = pipe (** inputs )[0 ]
161
+
162
+ with tempfile .TemporaryDirectory () as tmpdir :
163
+ pipe .save_pretrained (tmpdir )
164
+ pipe_loaded = self .pipeline_class .from_pretrained (tmpdir )
165
+ pipe_loaded .to (torch_device )
166
+ pipe_loaded .set_progress_bar_config (disable = None )
167
+
168
+ for optional_component in pipe ._optional_components :
169
+ self .assertTrue (
170
+ getattr (pipe_loaded , optional_component ) is None ,
171
+ f"`{ optional_component } ` did not stay set to None after loading." ,
172
+ )
173
+
174
+ inputs = self .get_dummy_inputs (torch_device )
175
+ output_loaded = pipe_loaded (** inputs )[0 ]
176
+
177
+ max_diff = np .abs (output - output_loaded ).max ()
178
+ self .assertLess (max_diff , 1e-4 )
179
+
180
+ def test_stable_diffusion_pix2pix_zero_inversion (self ):
181
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
182
+ components = self .get_dummy_components ()
183
+ sd_pipe = StableDiffusionPix2PixZeroPipeline (** components )
184
+ sd_pipe = sd_pipe .to (device )
185
+ sd_pipe .set_progress_bar_config (disable = None )
186
+
187
+ inputs = self .get_dummy_inversion_inputs (device )
188
+ inputs ["image" ] = inputs ["image" ][:1 ]
189
+ inputs ["prompt" ] = inputs ["prompt" ][:1 ]
190
+ image = sd_pipe .invert (** inputs ).images
191
+ image_slice = image [0 , - 3 :, - 3 :, - 1 ]
192
+ assert image .shape == (1 , 32 , 32 , 3 )
193
+ expected_slice = np .array ([0.4833 , 0.4696 , 0.5574 , 0.5194 , 0.5248 , 0.5638 , 0.5040 , 0.5423 , 0.5072 ])
194
+
195
+ assert np .abs (image_slice .flatten () - expected_slice ).max () < 1e-3
196
+
197
+ def test_stable_diffusion_pix2pix_zero_inversion_batch (self ):
198
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
199
+ components = self .get_dummy_components ()
200
+ sd_pipe = StableDiffusionPix2PixZeroPipeline (** components )
201
+ sd_pipe = sd_pipe .to (device )
202
+ sd_pipe .set_progress_bar_config (disable = None )
203
+
204
+ inputs = self .get_dummy_inversion_inputs (device )
205
+ image = sd_pipe .invert (** inputs ).images
206
+ image_slice = image [1 , - 3 :, - 3 :, - 1 ]
207
+ assert image .shape == (2 , 32 , 32 , 3 )
208
+ expected_slice = np .array ([0.6672 , 0.5203 , 0.4908 , 0.4376 , 0.4517 , 0.5544 , 0.4605 , 0.4826 , 0.5007 ])
209
+
210
+ assert np .abs (image_slice .flatten () - expected_slice ).max () < 1e-3
211
+
125
212
def test_stable_diffusion_pix2pix_zero_default_case (self ):
126
213
device = "cpu" # ensure determinism for the device-dependent torch.Generator
127
214
components = self .get_dummy_components ()
0 commit comments