Commit 41375d3

Wan Pipeline scaling fix, type hint warning, multi generator fix

1 parent: b38450d
src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 30 additions & 14 deletions
@@ -19,7 +19,7 @@
 import PIL
 import regex as re
 import torch
-from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
@@ -49,11 +49,11 @@
         >>> import numpy as np
         >>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
         >>> from diffusers.utils import export_to_video, load_image
-        >>> from transformers import CLIPVisionModel
+        >>> from transformers import CLIPVisionModelWithProjection
 
         >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
         >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
-        >>> image_encoder = CLIPVisionModel.from_pretrained(
+        >>> image_encoder = CLIPVisionModelWithProjection.from_pretrained(
         ...     model_id, subfolder="image_encoder", torch_dtype=torch.float32
         ... )
         >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
@@ -109,14 +109,30 @@ def prompt_clean(text):
 
 
 def retrieve_latents(
-    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+    encoder_output: torch.Tensor,
+    latents_mean: torch.Tensor,
+    latents_std: torch.Tensor,
+    generator: Optional[torch.Generator] = None,
+    sample_mode: str = "sample",
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.sample(generator)
     elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
-        return encoder_output.latents
+        return (encoder_output.latents - latents_mean) * latents_std
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
 
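The signature change above means call sites must now pass the VAE latent statistics into the helper instead of rescaling its return value. A minimal sketch of the difference, assuming enc_out is the output of self.vae.encode(...) and latents_mean / latents_std are the broadcastable statistics built in prepare_latents (variable names here are illustrative, not from the diff):

    # before: retrieve the latents, then scale at the call site
    latent = retrieve_latents(enc_out, generator)
    latent = (latent - latents_mean) * latents_std

    # after: the statistics are folded into the distribution inside the helper,
    # so the returned latents are already shifted and scaled
    latent = retrieve_latents(enc_out, latents_mean, latents_std, generator)
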
@@ -155,7 +171,7 @@ def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        image_encoder: CLIPVisionModel,
+        image_encoder: CLIPVisionModelWithProjection,
         image_processor: CLIPImageProcessor,
         transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
@@ -385,13 +401,6 @@ def prepare_latents(
         )
         video_condition = video_condition.to(device=device, dtype=dtype)
 
-        if isinstance(generator, list):
-            latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator]
-            latents = latent_condition = torch.cat(latent_condition)
-        else:
-            latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
-            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
-
         latents_mean = (
             torch.tensor(self.vae.config.latents_mean)
             .view(1, self.vae.config.z_dim, 1, 1, 1)
@@ -401,7 +410,14 @@ def prepare_latents(
             latents.device, latents.dtype
         )
 
-        latent_condition = (latent_condition - latents_mean) * latents_std
+        if isinstance(generator, list):
+            latent_condition = [
+                retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
+            ]
+            latent_condition = torch.cat(latent_condition)
+        else:
+            latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
+            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
 
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0
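This reordered block is the "multi generator fix" from the commit message: the per-generator encoding now runs after latents_mean and latents_std are defined, the statistics are passed through retrieve_latents, and the stray latents = latent_condition = ... assignment is gone. A hedged sketch of a call that exercises the list-of-generators branch, assuming pipe, image, and prompt are set up as in the docstring example (the seeds, device, and num_videos_per_prompt value are illustrative):

    # one generator per requested video takes the isinstance(generator, list) path
    generators = [torch.Generator(device="cuda").manual_seed(seed) for seed in (0, 1)]
    output = pipe(
        image=image,
        prompt=prompt,
        num_videos_per_prompt=2,
        generator=generators,
    )
    frames = output.frames
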
