-
Notifications
You must be signed in to change notification settings - Fork 6k
Add controlnet and vae from single file #4084
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
1d07807
881ddcc
b8272fa
5fd2904
b092e04
bdd138b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -621,8 +621,8 @@ def convert_ldm_unet_checkpoint( | |
def convert_ldm_vae_checkpoint(checkpoint, config): | ||
# extract state dict for VAE | ||
vae_state_dict = {} | ||
vae_key = "first_stage_model." | ||
keys = list(checkpoint.keys()) | ||
vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else "" | ||
for key in keys: | ||
if key.startswith(vae_key): | ||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) | ||
|
@@ -1064,7 +1064,7 @@ def convert_controlnet_checkpoint( | |
if cross_attention_dim is not None: | ||
ctrlnet_config["cross_attention_dim"] = cross_attention_dim | ||
|
||
controlnet_model = ControlNetModel(**ctrlnet_config) | ||
controlnet = ControlNetModel(**ctrlnet_config) | ||
|
||
# Some controlnet ckpt files are distributed independently from the rest of the | ||
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ | ||
|
@@ -1082,9 +1082,9 @@ def convert_controlnet_checkpoint( | |
skip_extract_state_dict=skip_extract_state_dict, | ||
) | ||
|
||
controlnet_model.load_state_dict(converted_ctrl_checkpoint) | ||
controlnet.load_state_dict(converted_ctrl_checkpoint) | ||
|
||
return controlnet_model | ||
return controlnet | ||
|
||
|
||
def download_from_original_stable_diffusion_ckpt( | ||
|
@@ -1182,7 +1182,7 @@ def download_from_original_stable_diffusion_ckpt( | |
) | ||
|
||
if pipeline_class is None: | ||
pipeline_class = StableDiffusionPipeline | ||
pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline | ||
|
||
if prediction_type == "v-prediction": | ||
prediction_type = "v_prediction" | ||
|
@@ -1289,8 +1289,7 @@ def download_from_original_stable_diffusion_ckpt( | |
if controlnet is None: | ||
controlnet = "control_stage_config" in original_config.model.params | ||
|
||
if controlnet: | ||
controlnet_model = convert_controlnet_checkpoint( | ||
controlnet = convert_controlnet_checkpoint( | ||
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema | ||
) | ||
Comment on lines
-1292
to
1294
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. iamwavecut#1 this might have introduced a bug @patrickvonplaten There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It did indeed haha. Solved as explained in iamwavecut#1 |
||
|
||
|
@@ -1401,13 +1400,13 @@ def download_from_original_stable_diffusion_ckpt( | |
|
||
if stable_unclip is None: | ||
if controlnet: | ||
pipe = StableDiffusionControlNetPipeline( | ||
pipe = pipeline_class( | ||
vae=vae, | ||
text_encoder=text_model, | ||
tokenizer=tokenizer, | ||
unet=unet, | ||
scheduler=scheduler, | ||
controlnet=controlnet_model, | ||
controlnet=controlnet, | ||
safety_checker=None, | ||
feature_extractor=None, | ||
requires_safety_checker=False, | ||
|
@@ -1504,12 +1503,12 @@ def download_from_original_stable_diffusion_ckpt( | |
feature_extractor = None | ||
|
||
if controlnet: | ||
pipe = StableDiffusionControlNetPipeline( | ||
pipe = pipeline_class( | ||
vae=vae, | ||
text_encoder=text_model, | ||
tokenizer=tokenizer, | ||
unet=unet, | ||
controlnet=controlnet_model, | ||
controlnet=controlnet, | ||
scheduler=scheduler, | ||
safety_checker=safety_checker, | ||
feature_extractor=feature_extractor, | ||
|
@@ -1624,7 +1623,7 @@ def download_controlnet_from_original_ckpt( | |
if "control_stage_config" not in original_config.model.params: | ||
raise ValueError("`control_stage_config` not present in original config") | ||
|
||
controlnet_model = convert_controlnet_checkpoint( | ||
controlnet = convert_controlnet_checkpoint( | ||
checkpoint, | ||
original_config, | ||
checkpoint_path, | ||
|
@@ -1635,4 +1634,4 @@ def download_controlnet_from_original_ckpt( | |
cross_attention_dim=cross_attention_dim, | ||
) | ||
|
||
return controlnet_model | ||
return controlnet |
Uh oh!
There was an error while loading. Please reload this page.