
Commit d55f411

Authored by yupeng1111, github-actions[bot], and yiyixuxu

fix wan i2v pipeline bugs (#10975)

* fix wan i2v pipeline bugs

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <[email protected]>

1 parent 748cb0f · commit d55f411

File tree

  src/diffusers/pipelines/wan/pipeline_wan.py
  src/diffusers/pipelines/wan/pipeline_wan_i2v.py

2 files changed: +35 -14 lines changed

src/diffusers/pipelines/wan/pipeline_wan.py

Lines changed: 8 additions & 5 deletions
@@ -45,27 +45,30 @@
     Examples:
         ```python
         >>> import torch
-        >>> from diffusers import AutoencoderKLWan, WanPipeline
         >>> from diffusers.utils import export_to_video
+        >>> from diffusers import AutoencoderKLWan, WanPipeline
+        >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler

         >>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
         >>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
         >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
         >>> pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+        >>> flow_shift = 5.0  # 5.0 for 720P, 3.0 for 480P
+        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
         >>> pipe.to("cuda")

-        >>> prompt = "A cat walks on the grass, realistic"
+        >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
         >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

         >>> output = pipe(
         ...     prompt=prompt,
         ...     negative_prompt=negative_prompt,
-        ...     height=480,
-        ...     width=832,
+        ...     height=720,
+        ...     width=1280,
         ...     num_frames=81,
         ...     guidance_scale=5.0,
         ... ).frames[0]
-        >>> export_to_video(output, "output.mp4", fps=15)
+        >>> export_to_video(output, "output.mp4", fps=16)
         ```
 """
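Note: the updated T2V docstring recommends a resolution-dependent flow_shift for the UniPC scheduler. Below is a minimal sketch (not part of this commit) of how that choice could be wrapped in a helper; the pick_flow_shift function and its 720-pixel threshold are assumptions based only on the comment in the example above.

```python
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler


def pick_flow_shift(height: int) -> float:
    # Hypothetical helper: the docstring comment suggests 5.0 for 720P and 3.0 for 480P.
    return 5.0 if height >= 720 else 3.0


model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)

# Re-create the scheduler with the shift matched to the target height (720P here).
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=pick_flow_shift(720))
```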

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 27 additions & 9 deletions
@@ -19,7 +19,7 @@
 import PIL
 import regex as re
 import torch
-from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
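The import swap matters because the two classes return different things: CLIPVisionModelWithProjection's forward exposes a single pooled, projected image_embeds vector, while CLIPVisionModel exposes the per-patch transformer hidden states that the I2V conditioning path reads. A small illustrative sketch of that difference, using an arbitrary public CLIP checkpoint rather than the Wan image encoder:

```python
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel, CLIPVisionModelWithProjection

# Illustrative checkpoint only; the Wan pipelines load their own image_encoder subfolder.
checkpoint = "openai/clip-vit-base-patch32"
processor = CLIPImageProcessor.from_pretrained(checkpoint)
inputs = processor(images=Image.new("RGB", (224, 224)), return_tensors="pt")

with torch.no_grad():
    projected = CLIPVisionModelWithProjection.from_pretrained(checkpoint)(**inputs).image_embeds
    hidden = CLIPVisionModel.from_pretrained(checkpoint)(**inputs, output_hidden_states=True).hidden_states

print(projected.shape)   # (1, projection_dim): one pooled vector per image
print(hidden[-2].shape)  # (1, num_patches + 1, hidden_size): per-patch features
```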
@@ -46,29 +46,47 @@
     Examples:
         ```python
         >>> import torch
+        >>> import numpy as np
         >>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
         >>> from diffusers.utils import export_to_video, load_image
+        >>> from transformers import CLIPVisionModel

-        >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-1.3B-720P-Diffusers
+        >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
         >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
+        >>> image_encoder = CLIPVisionModel.from_pretrained(
+        ...     model_id, subfolder="image_encoder", torch_dtype=torch.float32
+        ... )
         >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
-        >>> pipe = WanImageToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+        >>> pipe = WanImageToVideoPipeline.from_pretrained(
+        ...     model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
+        ... )
         >>> pipe.to("cuda")

-        >>> height, width = 480, 832
         >>> image = load_image(
         ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
-        ... ).resize((width, height))
+        ... )
+        >>> max_area = 480 * 832
+        >>> aspect_ratio = image.height / image.width
+        >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+        >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+        >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+        >>> image = image.resize((width, height))
         >>> prompt = (
         ...     "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
         ...     "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
         ... )
         >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

         >>> output = pipe(
-        ...     image=image, prompt=prompt, negative_prompt=negative_prompt, num_frames=81, guidance_scale=5.0
+        ...     image=image,
+        ...     prompt=prompt,
+        ...     negative_prompt=negative_prompt,
+        ...     height=height,
+        ...     width=width,
+        ...     num_frames=81,
+        ...     guidance_scale=5.0,
         ... ).frames[0]
-        >>> export_to_video(output, "output.mp4", fps=15)
+        >>> export_to_video(output, "output.mp4", fps=16)
         ```
 """
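For readers skimming the new example: the resize arithmetic keeps the input aspect ratio, caps the pixel count at the 480x832 budget, and snaps both sides to a multiple of mod_value. A worked sketch with assumed numbers (mod_value = 16, i.e. a spatial VAE scale factor of 8 times a patch size of 2, and a hypothetical 1280x720 input image):

```python
import numpy as np

mod_value = 16        # assumed: vae_scale_factor_spatial (8) * patch_size[1] (2)
max_area = 480 * 832  # pixel budget used in the example above

image_height, image_width = 720, 1280      # hypothetical input image
aspect_ratio = image_height / image_width  # 0.5625

height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value  # 464
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value   # 832

# The aspect ratio is roughly preserved, both sides are multiples of 16,
# and the total pixel count stays at or below max_area.
assert height % mod_value == 0 and width % mod_value == 0
assert height * width <= max_area
print(height, width)  # 464 832
```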

@@ -137,7 +155,7 @@ def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        image_encoder: CLIPVisionModelWithProjection,
+        image_encoder: CLIPVisionModel,
         image_processor: CLIPImageProcessor,
         transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
@@ -204,7 +222,7 @@ def _get_t5_prompt_embeds(
     def encode_image(self, image: PipelineImageInput):
         image = self.image_processor(images=image, return_tensors="pt").to(self.device)
         image_embeds = self.image_encoder(**image, output_hidden_states=True)
-        return image_embeds.hidden_states[-1]
+        return image_embeds.hidden_states[-2]

     # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
     def encode_prompt(
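The encode_image change relies on output_hidden_states=True returning one hidden-state tensor per encoder layer; indexing with -2 selects the penultimate layer instead of the final one. A minimal sketch of that behaviour, assuming the 480P checkpoint's image_encoder and image_processor subfolders are available:

```python
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
image_processor = CLIPImageProcessor.from_pretrained(model_id, subfolder="image_processor")

inputs = image_processor(images=Image.new("RGB", (512, 512)), return_tensors="pt")  # placeholder image
with torch.no_grad():
    outputs = image_encoder(**inputs, output_hidden_states=True)

# hidden_states holds the embedding output plus one tensor per layer;
# [-1] is the last layer, [-2] the penultimate layer now used for conditioning.
penultimate = outputs.hidden_states[-2]
print(penultimate.shape)  # (batch, num_patches + 1, hidden_size)
```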
