Commit 6e2a93d

[tests] fix tests for save load components (#10977)
fix tests
1 parent 37b8edf commit 6e2a93d

6 files changed, +270 -4 lines changed


tests/pipelines/hunyuandit/test_hunyuan_dit.py

Lines changed: 94 additions & 0 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import gc
+import tempfile
 import unittest
 
 import numpy as np
@@ -212,6 +213,99 @@ def test_fused_qkv_projections(self):
     def test_encode_prompt_works_in_isolation(self):
         pass
 
+    def test_save_load_optional_components(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        prompt = inputs["prompt"]
+        generator = inputs["generator"]
+        num_inference_steps = inputs["num_inference_steps"]
+        output_type = inputs["output_type"]
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_attention_mask,
+        ) = pipe.encode_prompt(prompt, device=torch_device, dtype=torch.float32, text_encoder_index=0)
+
+        (
+            prompt_embeds_2,
+            negative_prompt_embeds_2,
+            prompt_attention_mask_2,
+            negative_prompt_attention_mask_2,
+        ) = pipe.encode_prompt(
+            prompt,
+            device=torch_device,
+            dtype=torch.float32,
+            text_encoder_index=1,
+        )
+
+        # inputs with prompt converted to embeddings
+        inputs = {
+            "prompt_embeds": prompt_embeds,
+            "prompt_attention_mask": prompt_attention_mask,
+            "negative_prompt_embeds": negative_prompt_embeds,
+            "negative_prompt_attention_mask": negative_prompt_attention_mask,
+            "prompt_embeds_2": prompt_embeds_2,
+            "prompt_attention_mask_2": prompt_attention_mask_2,
+            "negative_prompt_embeds_2": negative_prompt_embeds_2,
+            "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
+            "generator": generator,
+            "num_inference_steps": num_inference_steps,
+            "output_type": output_type,
+            "use_resolution_binning": False,
+        }
+
+        # set all optional components to None
+        for optional_component in pipe._optional_components:
+            setattr(pipe, optional_component, None)
+
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        for optional_component in pipe._optional_components:
+            self.assertTrue(
+                getattr(pipe_loaded, optional_component) is None,
+                f"`{optional_component}` did not stay set to None after loading.",
+            )
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        generator = inputs["generator"]
+        num_inference_steps = inputs["num_inference_steps"]
+        output_type = inputs["output_type"]
+
+        # inputs with prompt converted to embeddings
+        inputs = {
+            "prompt_embeds": prompt_embeds,
+            "prompt_attention_mask": prompt_attention_mask,
+            "negative_prompt_embeds": negative_prompt_embeds,
+            "negative_prompt_attention_mask": negative_prompt_attention_mask,
+            "prompt_embeds_2": prompt_embeds_2,
+            "prompt_attention_mask_2": prompt_attention_mask_2,
+            "negative_prompt_embeds_2": negative_prompt_embeds_2,
+            "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
+            "generator": generator,
+            "num_inference_steps": num_inference_steps,
+            "output_type": output_type,
+            "use_resolution_binning": False,
+        }
+
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+        self.assertLess(max_diff, 1e-4)
+
 
 @slow
 @require_torch_accelerator

tests/pipelines/latte/test_latte.py

Lines changed: 69 additions & 1 deletion
@@ -15,6 +15,7 @@
 
 import gc
 import inspect
+import tempfile
 import unittest
 
 import numpy as np
@@ -39,7 +40,7 @@
 )
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin
+from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
 
 
 enable_full_determinism()
@@ -217,6 +218,73 @@ def test_xformers_attention_forwardGenerator_pass(self):
     def test_encode_prompt_works_in_isolation(self):
         pass
 
+    def test_save_load_optional_components(self):
+        if not hasattr(self.pipeline_class, "_optional_components"):
+            return
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        prompt = inputs["prompt"]
+        generator = inputs["generator"]
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+        ) = pipe.encode_prompt(prompt)
+
+        # inputs with prompt converted to embeddings
+        inputs = {
+            "prompt_embeds": prompt_embeds,
+            "negative_prompt": None,
+            "negative_prompt_embeds": negative_prompt_embeds,
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 5.0,
+            "height": 8,
+            "width": 8,
+            "video_length": 1,
+            "mask_feature": False,
+            "output_type": "pt",
+            "clean_caption": False,
+        }
+
+        # set all optional components to None
+        for optional_component in pipe._optional_components:
+            setattr(pipe, optional_component, None)
+
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir, safe_serialization=False)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            pipe_loaded.to(torch_device)
+
+        for component in pipe_loaded.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+
+        pipe_loaded.set_progress_bar_config(disable=None)
+
+        for optional_component in pipe._optional_components:
+            self.assertTrue(
+                getattr(pipe_loaded, optional_component) is None,
+                f"`{optional_component}` did not stay set to None after loading.",
+            )
+
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+        self.assertLess(max_diff, 1.0)
+
 
 @slow
 @require_torch_accelerator

tests/pipelines/pag/test_pag_hunyuan_dit.py

Lines changed: 95 additions & 3 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import inspect
+import tempfile
 import unittest
 
 import numpy as np
@@ -27,9 +28,7 @@
     HunyuanDiTPAGPipeline,
     HunyuanDiTPipeline,
 )
-from diffusers.utils.testing_utils import (
-    enable_full_determinism,
-)
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -269,3 +268,96 @@ def test_pag_applied_layers(self):
         )
     def test_encode_prompt_works_in_isolation(self):
         pass
+
+    def test_save_load_optional_components(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        prompt = inputs["prompt"]
+        generator = inputs["generator"]
+        num_inference_steps = inputs["num_inference_steps"]
+        output_type = inputs["output_type"]
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_attention_mask,
+        ) = pipe.encode_prompt(prompt, device=torch_device, dtype=torch.float32, text_encoder_index=0)
+
+        (
+            prompt_embeds_2,
+            negative_prompt_embeds_2,
+            prompt_attention_mask_2,
+            negative_prompt_attention_mask_2,
+        ) = pipe.encode_prompt(
+            prompt,
+            device=torch_device,
+            dtype=torch.float32,
+            text_encoder_index=1,
+        )
+
+        # inputs with prompt converted to embeddings
+        inputs = {
+            "prompt_embeds": prompt_embeds,
+            "prompt_attention_mask": prompt_attention_mask,
+            "negative_prompt_embeds": negative_prompt_embeds,
+            "negative_prompt_attention_mask": negative_prompt_attention_mask,
+            "prompt_embeds_2": prompt_embeds_2,
+            "prompt_attention_mask_2": prompt_attention_mask_2,
+            "negative_prompt_embeds_2": negative_prompt_embeds_2,
+            "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
+            "generator": generator,
+            "num_inference_steps": num_inference_steps,
+            "output_type": output_type,
+            "use_resolution_binning": False,
+        }
+
+        # set all optional components to None
+        for optional_component in pipe._optional_components:
+            setattr(pipe, optional_component, None)
+
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        for optional_component in pipe._optional_components:
+            self.assertTrue(
+                getattr(pipe_loaded, optional_component) is None,
+                f"`{optional_component}` did not stay set to None after loading.",
+            )
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        generator = inputs["generator"]
+        num_inference_steps = inputs["num_inference_steps"]
+        output_type = inputs["output_type"]
+
+        # inputs with prompt converted to embeddings
+        inputs = {
+            "prompt_embeds": prompt_embeds,
+            "prompt_attention_mask": prompt_attention_mask,
+            "negative_prompt_embeds": negative_prompt_embeds,
+            "negative_prompt_attention_mask": negative_prompt_attention_mask,
+            "prompt_embeds_2": prompt_embeds_2,
+            "prompt_attention_mask_2": prompt_attention_mask_2,
+            "negative_prompt_embeds_2": negative_prompt_embeds_2,
+            "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
+            "generator": generator,
+            "num_inference_steps": num_inference_steps,
+            "output_type": output_type,
+            "use_resolution_binning": False,
+        }
+
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+        self.assertLess(max_diff, 1e-4)

tests/pipelines/pag/test_pag_pixart_sigma.py

Lines changed: 4 additions & 0 deletions
@@ -343,3 +343,7 @@ def test_components_function(self):
 
         self.assertTrue(hasattr(pipe, "components"))
         self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
+    @unittest.skip("Test is already covered through encode_prompt isolation.")
+    def test_save_load_optional_components(self):
+        pass

tests/pipelines/pixart_alpha/test_pixart.py

Lines changed: 4 additions & 0 deletions
@@ -144,6 +144,10 @@ def test_inference_non_square_images(self):
         max_diff = np.abs(image_slice.flatten() - expected_slice).max()
         self.assertLessEqual(max_diff, 1e-3)
 
+    @unittest.skip("Test is already covered through encode_prompt isolation.")
+    def test_save_load_optional_components(self):
+        pass
+
     def test_inference_with_embeddings_and_multiple_images(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)

tests/pipelines/pixart_sigma/test_pixart.py

Lines changed: 4 additions & 0 deletions
@@ -239,6 +239,10 @@ def test_inference_with_multiple_images_per_prompt(self):
         max_diff = np.abs(image_slice.flatten() - expected_slice).max()
         self.assertLessEqual(max_diff, 1e-3)
 
+    @unittest.skip("Test is already covered through encode_prompt isolation.")
+    def test_save_load_optional_components(self):
+        pass
+
     def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(expected_max_diff=1e-3)
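
Taken together, the reinstated tests share one pattern: encode the prompt up front, set every optional component to None, run the pipeline, round-trip it through save_pretrained/from_pretrained, then assert the optional components stay None and the outputs still match. Below is a condensed, generic sketch of that pattern, not code from this commit: the helper name run_save_load_optional_components_check, the two-value encode_prompt return, and the 1e-4 tolerance are illustrative assumptions that each real test above adapts per pipeline.

# Generic sketch of the save/load-with-optional-components-removed check
# (assumes a test case exposing pipeline_class, get_dummy_components,
# and get_dummy_inputs, as in the diffusers pipeline test mixins above).
import tempfile

import numpy as np


def run_save_load_optional_components_check(test_case, torch_device="cpu"):
    components = test_case.get_dummy_components()
    pipe = test_case.pipeline_class(**components)
    pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    # Pre-compute prompt embeddings so the optional components (tokenizers,
    # text encoders) are not needed at call time. The real encode_prompt
    # signature and return values vary per pipeline; two values is assumed here.
    inputs = test_case.get_dummy_inputs(torch_device)
    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(inputs.pop("prompt"))
    inputs["prompt_embeds"] = prompt_embeds
    inputs["negative_prompt_embeds"] = negative_prompt_embeds

    # Null out every optional component and run once.
    for name in pipe._optional_components:
        setattr(pipe, name, None)
    output = pipe(**inputs)[0]

    # Round-trip through save_pretrained/from_pretrained.
    with tempfile.TemporaryDirectory() as tmpdir:
        pipe.save_pretrained(tmpdir)
        pipe_loaded = test_case.pipeline_class.from_pretrained(tmpdir)
    pipe_loaded.to(torch_device)
    pipe_loaded.set_progress_bar_config(disable=None)

    # The optional components must still be None, and outputs must agree
    # within a tolerance (1e-4 here; the Latte test above uses a looser 1.0).
    for name in pipe._optional_components:
        test_case.assertIsNone(getattr(pipe_loaded, name))
    output_loaded = pipe_loaded(**inputs)[0]
    max_diff = np.abs(np.asarray(output) - np.asarray(output_loaded)).max()
    test_case.assertLess(max_diff, 1e-4)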
