|
62 | 62 | "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
|
63 | 63 | "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
|
64 | 64 | "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias",
|
65 |
| - "controlnet": "control_model.time_embed.0.weight", |
| 65 | + "controlnet": [ |
| 66 | + "control_model.time_embed.0.weight", |
| 67 | + "controlnet_cond_embedding.conv_in.weight", |
| 68 | + ], |
| 69 | + # TODO: find non-Diffusers keys for controlnet_xl |
| 70 | + "controlnet_xl": "add_embedding.linear_1.weight", |
| 71 | + "controlnet_xl_large": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", |
| 72 | + "controlnet_xl_mid": "down_blocks.1.attentions.0.norm.weight", |
66 | 73 | "playground-v2-5": "edm_mean",
|
67 | 74 | "inpainting": "model.diffusion_model.input_blocks.0.0.weight",
|
68 | 75 | "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
|
|
96 | 103 | "inpainting": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-inpainting"},
|
97 | 104 | "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
|
98 | 105 | "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
|
| 106 | + "controlnet_xl_large": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0"}, |
| 107 | + "controlnet_xl_mid": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-mid"}, |
| 108 | + "controlnet_xl_small": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-small"}, |
99 | 109 | "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
|
100 | 110 | "v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"},
|
101 | 111 | "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
|
@@ -481,8 +491,16 @@ def infer_diffusers_model_type(checkpoint):
|
481 | 491 | elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
|
482 | 492 | model_type = "upscale"
|
483 | 493 |
|
484 |
| - elif CHECKPOINT_KEY_NAMES["controlnet"] in checkpoint: |
485 |
| - model_type = "controlnet" |
| 494 | + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["controlnet"]): |
| 495 | + if CHECKPOINT_KEY_NAMES["controlnet_xl"] in checkpoint: |
| 496 | + if CHECKPOINT_KEY_NAMES["controlnet_xl_large"] in checkpoint: |
| 497 | + model_type = "controlnet_xl_large" |
| 498 | + elif CHECKPOINT_KEY_NAMES["controlnet_xl_mid"] in checkpoint: |
| 499 | + model_type = "controlnet_xl_mid" |
| 500 | + else: |
| 501 | + model_type = "controlnet_xl_small" |
| 502 | + else: |
| 503 | + model_type = "controlnet" |
486 | 504 |
|
487 | 505 | elif (
|
488 | 506 | CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
|
@@ -1072,6 +1090,9 @@ def convert_controlnet_checkpoint(
|
1072 | 1090 | config,
|
1073 | 1091 | **kwargs,
|
1074 | 1092 | ):
|
| 1093 | + # Return checkpoint if it's already been converted |
| 1094 | + if "time_embedding.linear_1.weight" in checkpoint: |
| 1095 | + return checkpoint |
1075 | 1096 | # Some controlnet ckpt files are distributed independently from the rest of the
|
1076 | 1097 | # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
|
1077 | 1098 | if "time_embed.0.weight" in checkpoint:
|
|
0 commit comments