
Commit 93dfaec

DN6 authored and sayakpaul committed
[Single File] Add single file support for AutoencoderDC (#10183)
* update
* update
* update
1 parent d1411b3 commit 93dfaec

File tree

4 files changed: +243 -0 lines changed


docs/source/en/api/models/autoencoder_dc.md

Lines changed: 20 additions & 0 deletions
@@ -37,6 +37,26 @@ from diffusers import AutoencoderDC
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32).to("cuda")
```

+## Load a model in Diffusers via `from_single_file`
+
+```python
+from diffusers import AutoencoderDC
+
+ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
+model = AutoencoderDC.from_single_file(ckpt_path)
+```
+
+The `AutoencoderDC` model has `in` and `mix` single file checkpoint variants that share the same checkpoint keys but use different scaling factors. Diffusers cannot infer the correct config from the checkpoint alone, so it defaults to the `mix` variant config. To load an `in` variant checkpoint, override the inferred config by passing the `config` argument to `from_single_file`:
+
+```python
+from diffusers import AutoencoderDC
+
+ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0/blob/main/model.safetensors"
+model = AutoencoderDC.from_single_file(ckpt_path, config="mit-han-lab/dc-ae-f128c512-in-1.0-diffusers")
+```
+
## AutoencoderDC

[[autodoc]] AutoencoderDC
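
Taken together, the new docs amount to the following end-to-end flow. This is a minimal sketch, assuming a CUDA device and a random 512x512 input (the shape is borrowed from the new tests below); the `.sample` output attribute matches how the tests invoke the model:

```python
import torch
from diffusers import AutoencoderDC

# Single-file checkpoint from the original model repo, as in the docs above.
ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
model = AutoencoderDC.from_single_file(ckpt_path).to("cuda")

# Assumed input: a random image batch; 512x512 mirrors the test fixtures.
sample = torch.randn(1, 3, 512, 512, device="cuda")

with torch.no_grad():
    # A full forward pass encodes and decodes; `.sample` holds the reconstruction.
    reconstruction = model(sample).sample

print(reconstruction.shape)  # expected: torch.Size([1, 3, 512, 512])
```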

src/diffusers/loaders/single_file_model.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
from .single_file_utils import (
    SingleFileComponentError,
    convert_animatediff_checkpoint_to_diffusers,
+    convert_autoencoder_dc_checkpoint_to_diffusers,
    convert_controlnet_checkpoint,
    convert_flux_transformer_checkpoint_to_diffusers,
    convert_ldm_unet_checkpoint,

@@ -82,6 +83,7 @@
        "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
+    "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
}
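For orientation, here is a minimal sketch of how such a registry is presumably consulted during single-file loading: the class name selects an entry, and its `checkpoint_mapping_fn` rewrites the raw state dict. The registry and function names below (other than `convert_autoencoder_dc_checkpoint_to_diffusers` and the entry shape taken from the diff) are illustrative assumptions, not the loader's actual internals:

```python
from typing import Any, Callable, Dict

from diffusers.loaders.single_file_utils import convert_autoencoder_dc_checkpoint_to_diffusers

# Hypothetical registry mirroring the entry shape added in this diff.
LOADABLE_CLASSES: Dict[str, Dict[str, Any]] = {
    "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
}


def convert_for_class(class_name: str, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    # Look up the class-specific mapping fn and rewrite the raw state-dict keys
    # into the diffusers layout before the model is instantiated.
    mapping_fn: Callable = LOADABLE_CLASSES[class_name]["checkpoint_mapping_fn"]
    return mapping_fn(checkpoint)
```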

src/diffusers/loaders/single_file_utils.py

Lines changed: 95 additions & 0 deletions
@@ -92,6 +92,8 @@
        "double_blocks.0.img_attn.norm.key_norm.scale",
        "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
    ],
+    "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
+    "autoencoder-dc-sana": "encoder.project_in.conv.bias",
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -138,6 +140,10 @@
    "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
    "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
    "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
+    "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
+    "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
+    "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
+    "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
}

# Use to configure model sample size when original config is provided
@@ -564,6 +570,23 @@ def infer_diffusers_model_type(checkpoint):
            model_type = "flux-dev"
        else:
            model_type = "flux-schnell"
+
+    elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
+        encoder_key = "encoder.project_in.conv.conv.bias"
+        decoder_key = "decoder.project_in.main.conv.weight"
+
+        if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint:
+            model_type = "autoencoder-dc-f32c32-sana"
+
+        elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32:
+            model_type = "autoencoder-dc-f32c32"
+
+        elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128:
+            model_type = "autoencoder-dc-f64c128"
+
+        else:
+            model_type = "autoencoder-dc-f128c512"
+
    else:
        model_type = "v1"
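
To make the new detection branch concrete, here is a toy reproduction: the key names and shape checks come straight from the diff above, while the miniature checkpoint is invented for illustration.

```python
import torch

CHECKPOINT_KEY_NAMES = {
    "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
    "autoencoder-dc-sana": "encoder.project_in.conv.bias",
}


def infer_dc_variant(checkpoint):
    # Mirrors the branch added to infer_diffusers_model_type above.
    encoder_key = "encoder.project_in.conv.conv.bias"
    decoder_key = "decoder.project_in.main.conv.weight"
    if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint:
        return "autoencoder-dc-f32c32-sana"
    if checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32:
        return "autoencoder-dc-f32c32"
    if checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128:
        return "autoencoder-dc-f64c128"
    return "autoencoder-dc-f128c512"


# Toy f32c32-shaped entries: encoder bias of width 64, decoder weight with 32 input channels.
toy = {
    "decoder.stages.1.op_list.0.main.conv.conv.bias": torch.zeros(64),
    "encoder.project_in.conv.conv.bias": torch.zeros(64),
    "decoder.project_in.main.conv.weight": torch.zeros(64, 32, 3, 3),
}
print(infer_dc_variant(toy))  # autoencoder-dc-f32c32
```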

@@ -2198,3 +2221,75 @@ def swap_scale_shift(weight):
    )

    return converted_state_dict
+
+
+def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+    def remap_qkv_(key: str, state_dict):
+        qkv = state_dict.pop(key)
+        q, k, v = torch.chunk(qkv, 3, dim=0)
+        parent_module, _, _ = key.rpartition(".qkv.conv.weight")
+        state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
+        state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
+        state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
+
+    def remap_proj_conv_(key: str, state_dict):
+        parent_module, _, _ = key.rpartition(".proj.conv.weight")
+        state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
+
+    AE_KEYS_RENAME_DICT = {
+        # common
+        "main.": "",
+        "op_list.": "",
+        "context_module": "attn",
+        "local_module": "conv_out",
+        # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
+        # If there were more scales, there would be more layers, so a loop would be better to handle this
+        "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
+        "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
+        "depth_conv.conv": "conv_depth",
+        "inverted_conv.conv": "conv_inverted",
+        "point_conv.conv": "conv_point",
+        "point_conv.norm": "norm",
+        "conv.conv.": "conv.",
+        "conv1.conv": "conv1",
+        "conv2.conv": "conv2",
+        "conv2.norm": "norm",
+        "proj.norm": "norm_out",
+        # encoder
+        "encoder.project_in.conv": "encoder.conv_in",
+        "encoder.project_out.0.conv": "encoder.conv_out",
+        "encoder.stages": "encoder.down_blocks",
+        # decoder
+        "decoder.project_in.conv": "decoder.conv_in",
+        "decoder.project_out.0": "decoder.norm_out",
+        "decoder.project_out.2.conv": "decoder.conv_out",
+        "decoder.stages": "decoder.up_blocks",
+    }
+
+    AE_F32C32_F64C128_F128C512_KEYS = {
+        "encoder.project_in.conv": "encoder.conv_in.conv",
+        "decoder.project_out.2.conv": "decoder.conv_out.conv",
+    }
+
+    AE_SPECIAL_KEYS_REMAP = {
+        "qkv.conv.weight": remap_qkv_,
+        "proj.conv.weight": remap_proj_conv_,
+    }
+    if "encoder.project_in.conv.bias" not in converted_state_dict:
+        AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS)
+
+    for key in list(converted_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
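
Here is a toy run of the flat rename-dict strategy the converter uses: every original key is rewritten by successive substring replacements, in dict order. The two-entry checkpoint below is invented for illustration; the rename entries are a subset of `AE_KEYS_RENAME_DICT` above.

```python
# Invented miniature "checkpoint": values are placeholders; only the key
# layout mimics the original (non-diffusers) naming.
checkpoint = {
    "encoder.project_in.conv.bias": 0.0,
    "decoder.stages.0.op_list.0.main.conv1.conv.weight": 0.0,
}

rename = {
    "main.": "",
    "op_list.": "",
    "conv1.conv": "conv1",
    "encoder.project_in.conv": "encoder.conv_in",
    "decoder.stages": "decoder.up_blocks",
}

for key in list(checkpoint.keys()):
    new_key = key
    # Apply each substring replacement in insertion order, as the converter does.
    for old, new in rename.items():
        new_key = new_key.replace(old, new)
    checkpoint[new_key] = checkpoint.pop(key)

print(sorted(checkpoint))
# ['decoder.up_blocks.0.0.conv1.weight', 'encoder.conv_in.bias']
```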
New test file

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch

from diffusers import (
    AutoencoderDC,
)
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    load_hf_numpy,
    numpy_cosine_similarity_distance,
    require_torch_accelerator,
    slow,
    torch_device,
)


enable_full_determinism()


@slow
@require_torch_accelerator
class AutoencoderDCSingleFileTests(unittest.TestCase):
    model_class = AutoencoderDC
    ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
    repo_id = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"
    main_input_name = "sample"
    base_precision = 1e-2

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32
        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return image

    def test_single_file_inference_same_as_pretrained(self):
        model_1 = self.model_class.from_pretrained(self.repo_id).to(torch_device)
        model_2 = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id).to(torch_device)

        image = self.get_sd_image(33)

        with torch.no_grad():
            sample_1 = model_1(image).sample
            sample_2 = model_2(image).sample

        assert sample_1.shape == sample_2.shape

        output_slice_1 = sample_1.flatten().float().cpu()
        output_slice_2 = sample_2.flatten().float().cpu()

        assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4

    def test_single_file_components(self):
        model = self.model_class.from_pretrained(self.repo_id)
        model_single_file = self.model_class.from_single_file(self.ckpt_path)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs between pretrained loading and single file loading"

    def test_single_file_in_type_variant_components(self):
        # `in` variant checkpoints require passing in a `config` parameter
        # in order to set the scaling factor correctly.
        # `in` and `mix` variants have the same keys and we cannot automatically infer a scaling factor.
        # We default to using the `mix` config.
        repo_id = "mit-han-lab/dc-ae-f128c512-in-1.0-diffusers"
        ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0/blob/main/model.safetensors"

        model = self.model_class.from_pretrained(repo_id)
        model_single_file = self.model_class.from_single_file(ckpt_path, config=repo_id)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs between pretrained loading and single file loading"

    def test_single_file_mix_type_variant_components(self):
        repo_id = "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"
        ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0/blob/main/model.safetensors"

        model = self.model_class.from_pretrained(repo_id)
        model_single_file = self.model_class.from_single_file(ckpt_path, config=repo_id)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs between pretrained loading and single file loading"
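
A note on the comparison metric the tests rely on: `numpy_cosine_similarity_distance` treats the two flattened outputs as vectors and measures one minus their cosine similarity. A minimal sketch of such a metric follows; the real helper lives in `diffusers.utils.testing_utils` and may differ in detail, so treat this reimplementation as illustrative only.

```python
import numpy as np


def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity: 0.0 means identical direction, 2.0 means opposite.
    a = a.flatten().astype(np.float64)
    b = b.flatten().astype(np.float64)
    cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - cos)


# Two nearly identical outputs yield a distance far below the 1e-4 test threshold.
x = np.random.rand(16)
print(cosine_similarity_distance(x, x + 1e-9))
```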

0 commit comments