Skip to content

Commit 1357931

Browse files
authored
[Single File] Add single file support for Wan T2V/I2V (#10991)
* update * update * update * update * update * update * update
1 parent a2d3d6a commit 1357931

File tree

8 files changed

+518
-50
lines changed

8 files changed

+518
-50
lines changed

docs/source/en/api/pipelines/wan.md

+16
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,22 @@ pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler
4545
pipe.scheduler = <CUSTOM_SCHEDULER_HERE>
4646
```
4747

48+
### Using single file loading with Wan
49+
50+
The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
51+
method.
52+
53+
54+
```python
55+
import torch
56+
from diffusers import WanPipeline, WanTransformer3DModel
57+
58+
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
59+
transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
60+
61+
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
62+
```
63+
4864
## WanPipeline
4965

5066
[[autodoc]] WanPipeline

src/diffusers/loaders/single_file_model.py

+10
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
convert_mochi_transformer_checkpoint_to_diffusers,
4040
convert_sd3_transformer_checkpoint_to_diffusers,
4141
convert_stable_cascade_unet_single_file_to_diffusers,
42+
convert_wan_transformer_to_diffusers,
43+
convert_wan_vae_to_diffusers,
4244
create_controlnet_diffusers_config_from_ldm,
4345
create_unet_diffusers_config_from_ldm,
4446
create_vae_diffusers_config_from_ldm,
@@ -117,6 +119,14 @@
117119
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
118120
"default_subfolder": "transformer",
119121
},
122+
"WanTransformer3DModel": {
123+
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
124+
"default_subfolder": "transformer",
125+
},
126+
"AutoencoderKLWan": {
127+
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
128+
"default_subfolder": "vae",
129+
},
120130
}
121131

122132

src/diffusers/loaders/single_file_utils.py

+330-45
Large diffs are not rendered by default.

src/diffusers/models/attention_processor.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,9 @@ def __init__(
284284
self.norm_added_q = RMSNorm(dim_head, eps=eps)
285285
self.norm_added_k = RMSNorm(dim_head, eps=eps)
286286
elif qk_norm == "rms_norm_across_heads":
287-
# Wanx applies qk norm across all heads
288-
self.norm_added_q = RMSNorm(dim_head * heads, eps=eps)
287+
# Wan applies qk norm across all heads
288+
# Wan doesn't apply a norm to the added q, so norm_added_q is left as None (only norm_added_k is created)
289+
self.norm_added_q = None
289290
self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
290291
else:
291292
raise ValueError(

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch.utils.checkpoint
2121

2222
from ...configuration_utils import ConfigMixin, register_to_config
23+
from ...loaders import FromOriginalModelMixin
2324
from ...utils import logging
2425
from ...utils.accelerate_utils import apply_forward_hook
2526
from ..activations import get_activation
@@ -655,7 +656,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
655656
return x
656657

657658

658-
class AutoencoderKLWan(ModelMixin, ConfigMixin):
659+
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
659660
r"""
660661
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
661662
Introduced in [Wan 2.1].

src/diffusers/models/transformers/transformer_wan.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import torch.nn.functional as F
2121

2222
from ...configuration_utils import ConfigMixin, register_to_config
23-
from ...loaders import PeftAdapterMixin
23+
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
2424
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
2525
from ..attention import FeedForward
2626
from ..attention_processor import Attention
@@ -288,7 +288,7 @@ def forward(
288288
return hidden_states
289289

290290

291-
class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
291+
class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
292292
r"""
293293
A Transformer model for video-like data used in the Wan model.
294294
@@ -329,6 +329,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
329329
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
330330
_no_split_modules = ["WanTransformerBlock"]
331331
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
332+
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
332333

333334
@register_to_config
334335
def __init__(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import gc
17+
import unittest
18+
19+
from diffusers import (
20+
AutoencoderKLWan,
21+
)
22+
from diffusers.utils.testing_utils import (
23+
backend_empty_cache,
24+
enable_full_determinism,
25+
require_torch_accelerator,
26+
torch_device,
27+
)
28+
29+
30+
enable_full_determinism()
31+
32+
33+
@require_torch_accelerator
34+
class AutoencoderKLWanSingleFileTests(unittest.TestCase):
35+
model_class = AutoencoderKLWan
36+
ckpt_path = (
37+
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
38+
)
39+
repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
40+
41+
def setUp(self):
42+
super().setUp()
43+
gc.collect()
44+
backend_empty_cache(torch_device)
45+
46+
def tearDown(self):
47+
super().tearDown()
48+
gc.collect()
49+
backend_empty_cache(torch_device)
50+
51+
def test_single_file_components(self):
52+
model = self.model_class.from_pretrained(self.repo_id, subfolder="vae")
53+
model_single_file = self.model_class.from_single_file(self.ckpt_path)
54+
55+
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
56+
for param_name, param_value in model_single_file.config.items():
57+
if param_name in PARAMS_TO_IGNORE:
58+
continue
59+
assert (
60+
model.config[param_name] == param_value
61+
), f"{param_name} differs between single file loading and pretrained loading"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import gc
17+
import unittest
18+
19+
import torch
20+
21+
from diffusers import (
22+
WanTransformer3DModel,
23+
)
24+
from diffusers.utils.testing_utils import (
25+
backend_empty_cache,
26+
enable_full_determinism,
27+
require_big_gpu_with_torch_cuda,
28+
require_torch_accelerator,
29+
torch_device,
30+
)
31+
32+
33+
enable_full_determinism()
34+
35+
36+
@require_torch_accelerator
37+
class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase):
38+
model_class = WanTransformer3DModel
39+
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
40+
repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
41+
42+
def setUp(self):
43+
super().setUp()
44+
gc.collect()
45+
backend_empty_cache(torch_device)
46+
47+
def tearDown(self):
48+
super().tearDown()
49+
gc.collect()
50+
backend_empty_cache(torch_device)
51+
52+
def test_single_file_components(self):
53+
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
54+
model_single_file = self.model_class.from_single_file(self.ckpt_path)
55+
56+
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
57+
for param_name, param_value in model_single_file.config.items():
58+
if param_name in PARAMS_TO_IGNORE:
59+
continue
60+
assert (
61+
model.config[param_name] == param_value
62+
), f"{param_name} differs between single file loading and pretrained loading"
63+
64+
65+
@require_big_gpu_with_torch_cuda
66+
@require_torch_accelerator
67+
class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase):
68+
model_class = WanTransformer3DModel
69+
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors"
70+
repo_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
71+
torch_dtype = torch.float8_e4m3fn
72+
73+
def setUp(self):
74+
super().setUp()
75+
gc.collect()
76+
backend_empty_cache(torch_device)
77+
78+
def tearDown(self):
79+
super().tearDown()
80+
gc.collect()
81+
backend_empty_cache(torch_device)
82+
83+
def test_single_file_components(self):
84+
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer", torch_dtype=self.torch_dtype)
85+
model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=self.torch_dtype)
86+
87+
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
88+
for param_name, param_value in model_single_file.config.items():
89+
if param_name in PARAMS_TO_IGNORE:
90+
continue
91+
assert (
92+
model.config[param_name] == param_value
93+
), f"{param_name} differs between single file loading and pretrained loading"

0 commit comments

Comments
 (0)