@@ -92,6 +92,8 @@
         "double_blocks.0.img_attn.norm.key_norm.scale",
         "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
     ],
+    "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
+    "autoencoder-dc-sana": "encoder.project_in.conv.bias",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
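The two new `CHECKPOINT_KEY_NAMES` entries act as fingerprints for the DC-AE family: the `autoencoder-dc` key is present in every original DC-AE state dict, while the `autoencoder-dc-sana` key (a bare `conv.bias` rather than the doubly wrapped `conv.conv.bias`) appears only in the SANA-flavored export. Detection is plain key membership, as in this sketch (the helper is hypothetical, written to mirror the checks added to `infer_diffusers_model_type` below):

```python
# Hypothetical helper mirroring the membership checks in infer_diffusers_model_type.
def detect_dc_ae_family(state_dict: dict) -> str | None:
    if "decoder.stages.1.op_list.0.main.conv.conv.bias" not in state_dict:
        return None  # not an original-format DC-AE checkpoint
    if "encoder.project_in.conv.bias" in state_dict:
        return "autoencoder-dc-f32c32-sana"
    # f32c32 / f64c128 / f128c512 are disambiguated later from tensor shapes
    return "autoencoder-dc"
```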
@@ -138,6 +140,10 @@
     "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
     "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
     "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
+    "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
+    "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
+    "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
+    "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
 }
 
 # Use to configure model sample size when original config is provided
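The `DIFFUSERS_DEFAULT_PIPELINE_PATHS` entries pair each inferred model type with a Hub repo that already hosts a diffusers-format config; single-file loading falls back to that repo's config when the caller supplies none. A rough sketch of how the two tables cooperate (import path taken from the current `diffusers` layout; assumes a minimal dict carrying only the fingerprint keys falls through to the DC-AE branch):

```python
import torch
from diffusers.loaders.single_file_utils import (
    DIFFUSERS_DEFAULT_PIPELINE_PATHS,
    infer_diffusers_model_type,
)

# Toy state dict carrying only the DC-AE fingerprint keys; a real checkpoint
# holds full weights, but these two keys alone select the SANA branch.
checkpoint = {
    "decoder.stages.1.op_list.0.main.conv.conv.bias": torch.zeros(1),
    "encoder.project_in.conv.bias": torch.zeros(1),
}

model_type = infer_diffusers_model_type(checkpoint)  # expected: "autoencoder-dc-f32c32-sana"
print(DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type])
# {'pretrained_model_name_or_path': 'mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers'}
```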
@@ -564,6 +570,23 @@ def infer_diffusers_model_type(checkpoint):
             model_type = "flux-dev"
         else:
             model_type = "flux-schnell"
+
+    elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
+        encoder_key = "encoder.project_in.conv.conv.bias"
+        decoder_key = "decoder.project_in.main.conv.weight"
+
+        if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint:
+            model_type = "autoencoder-dc-f32c32-sana"
+
+        elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32:
+            model_type = "autoencoder-dc-f32c32"
+
+        elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128:
+            model_type = "autoencoder-dc-f64c128"
+
+        else:
+            model_type = "autoencoder-dc-f128c512"
+
     else:
         model_type = "v1"
 
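Within the new branch, the SANA flavor is settled by its fingerprint key alone; the three `mix` checkpoints share identical key names and are told apart by tensor shapes. `shape[-1]` of the encoder's first conv bias is the stem width, and `shape[1]` of the decoder's first conv weight is the latent channel count, matching the `cNN` suffix (32, 128, or 512). An illustration with dummy tensors (full shapes are assumptions for illustration; only the probed dimensions matter):

```python
import torch

# Stand-in tensors shaped like the two probed entries of an f64c128 checkpoint.
checkpoint = {
    "encoder.project_in.conv.conv.bias": torch.zeros(64),                # stem width 64
    "decoder.project_in.main.conv.weight": torch.zeros(512, 128, 3, 3),  # 128 latent channels in
}

encoder_width = checkpoint["encoder.project_in.conv.conv.bias"].shape[-1]
latent_channels = checkpoint["decoder.project_in.main.conv.weight"].shape[1]
assert (encoder_width, latent_channels) == (64, 128)  # → "autoencoder-dc-f64c128"
```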
@@ -2198,3 +2221,75 @@ def swap_scale_shift(weight):
     )
 
     return converted_state_dict
+
+
+def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+    def remap_qkv_(key: str, state_dict):
+        qkv = state_dict.pop(key)
+        q, k, v = torch.chunk(qkv, 3, dim=0)
+        parent_module, _, _ = key.rpartition(".qkv.conv.weight")
+        state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
+        state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
+        state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
+
+    def remap_proj_conv_(key: str, state_dict):
+        parent_module, _, _ = key.rpartition(".proj.conv.weight")
+        state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
+
+    AE_KEYS_RENAME_DICT = {
+        # common
+        "main.": "",
+        "op_list.": "",
+        "context_module": "attn",
+        "local_module": "conv_out",
+        # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
+        # If there were more scales, there would be more layers, so a loop would be better to handle this
+        "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
+        "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
+        "depth_conv.conv": "conv_depth",
+        "inverted_conv.conv": "conv_inverted",
+        "point_conv.conv": "conv_point",
+        "point_conv.norm": "norm",
+        "conv.conv.": "conv.",
+        "conv1.conv": "conv1",
+        "conv2.conv": "conv2",
+        "conv2.norm": "norm",
+        "proj.norm": "norm_out",
+        # encoder
+        "encoder.project_in.conv": "encoder.conv_in",
+        "encoder.project_out.0.conv": "encoder.conv_out",
+        "encoder.stages": "encoder.down_blocks",
+        # decoder
+        "decoder.project_in.conv": "decoder.conv_in",
+        "decoder.project_out.0": "decoder.norm_out",
+        "decoder.project_out.2.conv": "decoder.conv_out",
+        "decoder.stages": "decoder.up_blocks",
+    }
+
+    AE_F32C32_F64C128_F128C512_KEYS = {
+        "encoder.project_in.conv": "encoder.conv_in.conv",
+        "decoder.project_out.2.conv": "decoder.conv_out.conv",
+    }
+
+    AE_SPECIAL_KEYS_REMAP = {
+        "qkv.conv.weight": remap_qkv_,
+        "proj.conv.weight": remap_proj_conv_,
+    }
+    if "encoder.project_in.conv.bias" not in converted_state_dict:
+        AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS)
+
+    for key in list(converted_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
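The converter makes two passes over the key set: a string-substitution pass driven by `AE_KEYS_RENAME_DICT` (with two overrides merged in for the non-SANA checkpoints, whose stem convs keep an extra `conv` wrapper), then a handler pass that splits the fused `qkv` weight into separate `to_q`/`to_k`/`to_v` projections; the `.squeeze()` calls drop the trailing singleton dims of the 1×1 convolutions so the chunks land as linear-layer weights. With detection, default config, and conversion in place, loading an original checkpoint should reduce to one call; a usage sketch, with the Hub URL assumed for illustration:

```python
from diffusers import AutoencoderDC

# URL assumed for illustration; any original-format (non-diffusers) DC-AE
# safetensors file should route through the detection and conversion above.
ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
ae = AutoencoderDC.from_single_file(ckpt_path)
```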