Allow users to save SDXL LoRA weights for only one text encoder #7607


Merged

7 changes: 6 additions & 1 deletion src/diffusers/loaders/lora.py
@@ -1406,6 +1406,9 @@ def save_lora_weights(
             text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                 encoder LoRA state dict because it comes from 🤗 Transformers.
+            text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
+                encoder LoRA state dict because it comes from 🤗 Transformers.

Comment on lines +1409 to +1411
Member:

It's there:

text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,

You need to look at the right class:
https://github.com/huggingface/diffusers/blob/7e39516627c69b71f8b21a2b53689028d4733b72/src/diffusers/loaders/lora.py#L1288C7-L1288C39

Contributor Author:

Yes, this is what I aimed to fix.

I just added the missing parameter to the documentation (to match the signature) and split the following if condition:

if text_encoder_lora_layers and text_encoder_2_lora_layers:

Member:

Thanks for explaining! LGTM.

Contributor Author:

Thanks for looking into it in record time :)
Cheers

             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful during distributed training and you
                 need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -1432,8 +1435,10 @@ def pack_weights(layers, prefix):
         if unet_lora_layers:
             state_dict.update(pack_weights(unet_lora_layers, "unet"))

-        if text_encoder_lora_layers and text_encoder_2_lora_layers:
+        if text_encoder_lora_layers:
             state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+
+        if text_encoder_2_lora_layers:
             state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

         cls.write_lora_layers(
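With the condition split in two, callers can pass LoRA layers for just one of SDXL's two text encoders and still have them written out. A minimal usage sketch, assuming this change is in place; the state dicts and their key names below are illustrative placeholders for whatever a LoRA training loop would actually produce:

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Illustrative placeholder state dicts; in a real run these would come from
# training. The key names here are made up, not the exact keys diffusers emits.
unet_lora_state_dict = {"dummy_unet_lora.weight": torch.zeros(4, 320)}
text_encoder_one_lora_state_dict = {"dummy_te1_lora.weight": torch.zeros(4, 768)}

# Previously, omitting text_encoder_2_lora_layers also dropped
# text_encoder_lora_layers, because both were gated behind a single `and`
# condition. With the split conditions, whatever is passed gets packed and
# saved, and the missing encoder is simply skipped.
StableDiffusionXLPipeline.save_lora_weights(
    save_directory="sdxl-lora-one-text-encoder",
    unet_lora_layers=unet_lora_state_dict,
    text_encoder_lora_layers=text_encoder_one_lora_state_dict,
    # text_encoder_2_lora_layers intentionally left as None
)
```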