1
+ import copy
1
2
from dataclasses import dataclass
2
3
from typing import Callable , List , Optional , Union
3
4
@@ -56,8 +57,8 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
56
57
is_cross_attention = encoder_hidden_states is not None
57
58
if encoder_hidden_states is None :
58
59
encoder_hidden_states = hidden_states
59
- elif attn .cross_attention_norm :
60
- encoder_hidden_states = attn .norm_cross (encoder_hidden_states )
60
+ elif attn .norm_cross :
61
+ encoder_hidden_states = attn .norm_encoder_hidden_states (encoder_hidden_states )
61
62
62
63
key = attn .to_k (encoder_hidden_states )
63
64
value = attn .to_v (encoder_hidden_states )
@@ -285,7 +286,8 @@ def backward_loop(
285
286
latents: latents of backward process output at time timesteps[-1]
286
287
"""
287
288
do_classifier_free_guidance = guidance_scale > 1.0
288
- with self .progress_bar (total = len (timesteps )) as progress_bar :
289
+ num_steps = (len (timesteps ) - num_warmup_steps ) // self .scheduler .order
290
+ with self .progress_bar (total = num_steps ) as progress_bar :
289
291
for i , t in enumerate (timesteps ):
290
292
# expand the latents if we are doing classifier free guidance
291
293
latent_model_input = torch .cat ([latents ] * 2 ) if do_classifier_free_guidance else latents
@@ -465,6 +467,7 @@ def __call__(
465
467
extra_step_kwargs = extra_step_kwargs ,
466
468
num_warmup_steps = num_warmup_steps ,
467
469
)
470
+ scheduler_copy = copy .deepcopy (self .scheduler )
468
471
469
472
# Perform the second backward process up to time T_0
470
473
x_1_t0 = self .backward_loop (
@@ -475,7 +478,7 @@ def __call__(
475
478
callback = callback ,
476
479
callback_steps = callback_steps ,
477
480
extra_step_kwargs = extra_step_kwargs ,
478
- num_warmup_steps = num_warmup_steps ,
481
+ num_warmup_steps = 0 ,
479
482
)
480
483
481
484
# Propagate first frame latents at time T_0 to remaining frames
@@ -502,7 +505,7 @@ def __call__(
502
505
b , l , d = prompt_embeds .size ()
503
506
prompt_embeds = prompt_embeds [:, None ].repeat (1 , video_length , 1 , 1 ).reshape (b * video_length , l , d )
504
507
505
- self .scheduler . set_timesteps ( num_inference_steps , device = device )
508
+ self .scheduler = scheduler_copy
506
509
x_1k_0 = self .backward_loop (
507
510
timesteps = timesteps [- t1 - 1 :],
508
511
prompt_embeds = prompt_embeds ,
@@ -511,7 +514,7 @@ def __call__(
511
514
callback = callback ,
512
515
callback_steps = callback_steps ,
513
516
extra_step_kwargs = extra_step_kwargs ,
514
- num_warmup_steps = num_warmup_steps ,
517
+ num_warmup_steps = 0 ,
515
518
)
516
519
latents = x_1k_0
517
520
0 commit comments