Skip to content

Commit 10c54cb

Browse files
19and99dg845
authored andcommitted
Text2video zero refinements (huggingface#3070)
* fix progress bar issue in pipeline_text_to_video_zero.py. Copy scheduler after first backward * fix tensor loading in test_text_to_video_zero.py * make style && make quality
1 parent 115e382 commit 10c54cb

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
from dataclasses import dataclass
23
from typing import Callable, List, Optional, Union
34

@@ -56,8 +57,8 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
5657
is_cross_attention = encoder_hidden_states is not None
5758
if encoder_hidden_states is None:
5859
encoder_hidden_states = hidden_states
59-
elif attn.cross_attention_norm:
60-
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
60+
elif attn.norm_cross:
61+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
6162

6263
key = attn.to_k(encoder_hidden_states)
6364
value = attn.to_v(encoder_hidden_states)
@@ -285,7 +286,8 @@ def backward_loop(
285286
latents: latents of backward process output at time timesteps[-1]
286287
"""
287288
do_classifier_free_guidance = guidance_scale > 1.0
288-
with self.progress_bar(total=len(timesteps)) as progress_bar:
289+
num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order
290+
with self.progress_bar(total=num_steps) as progress_bar:
289291
for i, t in enumerate(timesteps):
290292
# expand the latents if we are doing classifier free guidance
291293
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -465,6 +467,7 @@ def __call__(
465467
extra_step_kwargs=extra_step_kwargs,
466468
num_warmup_steps=num_warmup_steps,
467469
)
470+
scheduler_copy = copy.deepcopy(self.scheduler)
468471

469472
# Perform the second backward process up to time T_0
470473
x_1_t0 = self.backward_loop(
@@ -475,7 +478,7 @@ def __call__(
475478
callback=callback,
476479
callback_steps=callback_steps,
477480
extra_step_kwargs=extra_step_kwargs,
478-
num_warmup_steps=num_warmup_steps,
481+
num_warmup_steps=0,
479482
)
480483

481484
# Propagate first frame latents at time T_0 to remaining frames
@@ -502,7 +505,7 @@ def __call__(
502505
b, l, d = prompt_embeds.size()
503506
prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d)
504507

505-
self.scheduler.set_timesteps(num_inference_steps, device=device)
508+
self.scheduler = scheduler_copy
506509
x_1k_0 = self.backward_loop(
507510
timesteps=timesteps[-t1 - 1 :],
508511
prompt_embeds=prompt_embeds,
@@ -511,7 +514,7 @@ def __call__(
511514
callback=callback,
512515
callback_steps=callback_steps,
513516
extra_step_kwargs=extra_step_kwargs,
514-
num_warmup_steps=num_warmup_steps,
517+
num_warmup_steps=0,
515518
)
516519
latents = x_1k_0
517520

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
load_hf_numpy,
8787
load_image,
8888
load_numpy,
89+
load_pt,
8990
nightly,
9091
parse_flag_from_env,
9192
print_tensor_test,

tests/pipelines/text_to_video/test_text_to_video_zero.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919

2020
from diffusers import DDIMScheduler, TextToVideoZeroPipeline
21-
from diffusers.utils import require_torch_gpu, slow
21+
from diffusers.utils import load_pt, require_torch_gpu, slow
2222

2323
from ...test_pipelines_common import assert_mean_pixel_difference
2424

@@ -35,8 +35,8 @@ def test_full_model(self):
3535
prompt = "A bear is playing a guitar on Times Square"
3636
result = pipe(prompt=prompt, generator=generator).images
3737

38-
expected_result = torch.load(
39-
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/tree/main/text-to-video/A bear is playing a guitar on Times Square.pt"
38+
expected_result = load_pt(
39+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/A bear is playing a guitar on Times Square.pt"
4040
)
4141

4242
assert_mean_pixel_difference(result, expected_result)

0 commit comments

Comments
 (0)