
Commit 75aab34

dulacp, sayakpaul, and yiyixuxu authored
Allow users to save SDXL LoRA weights for only one text encoder (#7607)
SDXL LoRA weights for the text encoders should be decoupled on save. The method checks that at least one of the unet, text_encoder, and text_encoder_2 LoRA weights is passed, but this was not reflected in the implementation.

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
1 parent 35358a2 commit 75aab34
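With this change, `save_lora_weights` accepts any non-empty subset of the three LoRA layer dicts. A minimal usage sketch (the `unet_lora_layers` and `text_encoder_2_lora_layers` variables are hypothetical placeholders for trained LoRA state dicts, not part of the commit):

from diffusers import StableDiffusionXLPipeline

# Before this commit, passing text_encoder_2_lora_layers without
# text_encoder_lora_layers caused both text encoders' weights to be
# silently skipped; now each dict is packed and saved independently.
StableDiffusionXLPipeline.save_lora_weights(
    save_directory="sdxl-lora",  # output folder for the serialized weights
    unet_lora_layers=unet_lora_layers,  # placeholder: trained UNet LoRA layers
    text_encoder_2_lora_layers=text_encoder_2_lora_layers,  # only the second encoder
)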

File tree

1 file changed: +6 -1 lines changed

src/diffusers/loaders/lora.py

Lines changed: 6 additions & 1 deletion
@@ -1406,6 +1406,9 @@ def save_lora_weights(
             text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                 encoder LoRA state dict because it comes from 🤗 Transformers.
+            text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
+                encoder LoRA state dict because it comes from 🤗 Transformers.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful during distributed training and you
                 need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -1432,8 +1435,10 @@ def pack_weights(layers, prefix):
         if unet_lora_layers:
             state_dict.update(pack_weights(unet_lora_layers, "unet"))
 
-        if text_encoder_lora_layers and text_encoder_2_lora_layers:
+        if text_encoder_lora_layers:
             state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+
+        if text_encoder_2_lora_layers:
             state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
 
         cls.write_lora_layers(
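For reference, `pack_weights` (the helper named in the second hunk's context header) namespaces each sub-model's weights before they are merged into one state dict. A sketch consistent with the diff, paraphrased rather than copied from the commit:

import torch

def pack_weights(layers, prefix):
    # Accept either an nn.Module (take its state dict) or a plain tensor dict,
    # then prefix every key so the three sub-models can share one file, e.g.
    # "text_encoder_2.<module path>.lora_A.weight".
    layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}

Because each `if` now runs independently, a call with only `text_encoder_2_lora_layers` produces a state dict whose keys all start with "text_encoder_2.", which `cls.write_lora_layers` then serializes as usual.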
