@@ -19,7 +19,7 @@
 import PIL
 import regex as re
 import torch
-from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
@@ -49,11 +49,11 @@
         >>> import numpy as np
         >>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
         >>> from diffusers.utils import export_to_video, load_image
-        >>> from transformers import CLIPVisionModel
+        >>> from transformers import CLIPVisionModelWithProjection
 
         >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
         >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
-        >>> image_encoder = CLIPVisionModel.from_pretrained(
+        >>> image_encoder = CLIPVisionModelWithProjection.from_pretrained(
         ...     model_id, subfolder="image_encoder", torch_dtype=torch.float32
         ... )
         >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
@@ -109,14 +109,30 @@ def prompt_clean(text):
 
 
 def retrieve_latents(
-    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+    encoder_output: torch.Tensor,
+    latents_mean: torch.Tensor,
+    latents_std: torch.Tensor,
+    generator: Optional[torch.Generator] = None,
+    sample_mode: str = "sample",
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.sample(generator)
     elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
-        return encoder_output.latents
+        return (encoder_output.latents - latents_mean) * latents_std
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
 
@@ -155,7 +171,7 @@ def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        image_encoder: CLIPVisionModel,
+        image_encoder: CLIPVisionModelWithProjection,
         image_processor: CLIPImageProcessor,
         transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
@@ -385,13 +401,6 @@ def prepare_latents(
         )
         video_condition = video_condition.to(device=device, dtype=dtype)
 
-        if isinstance(generator, list):
-            latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator]
-            latents = latent_condition = torch.cat(latent_condition)
-        else:
-            latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
-            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
-
         latents_mean = (
             torch.tensor(self.vae.config.latents_mean)
             .view(1, self.vae.config.z_dim, 1, 1, 1)
@@ -401,7 +410,14 @@ def prepare_latents(
             latents.device, latents.dtype
         )
 
-        latent_condition = (latent_condition - latents_mean) * latents_std
+        if isinstance(generator, list):
+            latent_condition = [
+                retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
+            ]
+            latent_condition = torch.cat(latent_condition)
+        else:
+            latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
+            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
 
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0
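For anyone trying the change locally, here is a minimal sketch (not part of the diff) of how the reworked retrieve_latents is called with the per-channel normalization statistics, mirroring the prepare_latents hunk above. It assumes this patch is applied; the module import path, the dummy conditioning tensor, and its shape are illustrative assumptions, not values from the PR.

# Illustrative sketch only: assumes the patch above is applied.
import torch
from diffusers import AutoencoderKLWan
from diffusers.pipelines.wan.pipeline_wan_i2v import retrieve_latents

vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", subfolder="vae", torch_dtype=torch.float32
)

# Per-channel normalization constants from the VAE config, shaped to broadcast
# over (batch, channels, frames, height, width) latents, as in prepare_latents.
latents_mean = torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1)
latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1)

# Stand-in conditioning video: one 480x832 RGB frame (placeholder values).
video_condition = torch.zeros(1, 3, 1, 480, 832)

# The helper now normalizes the latent distribution before sampling from it.
latent_condition = retrieve_latents(vae.encode(video_condition), latents_mean, latents_std)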