
Commit d486f0e

[LoRA serialization] fix: duplicate unet prefix problem. (#5991)
* fix: duplicate unet prefix problem.
* Update src/diffusers/loaders/lora.py

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 3351270 commit d486f0e
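
In short: before this commit, `unet_lora_state_dict` in `training_utils.py` already prefixed every key with "unet.", and `save_lora_weights` in `lora.py` prefixed the keys a second time, so serialized checkpoints ended up with doubled keys such as "unet.unet.<module>". The commit keeps the prefixing in a single place (the new `pack_weights` helper inside `save_lora_weights`) and deprecates loading of the doubled keys. A minimal sketch of the key shapes, with a made-up module name for illustration:

    # Hypothetical LoRA parameter name produced during training.
    module_key = "down_blocks.0.attentions.0.to_q.lora.down.weight"

    # Before: both helpers added a "unet." prefix, so it was duplicated on save.
    before = f"unet.unet.{module_key}"

    # After: only save_lora_weights (via pack_weights) adds the prefix, once.
    after = f"unet.{module_key}"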

File tree: 2 files changed, +18 -21 lines

  src/diffusers/loaders/lora.py
  src/diffusers/training_utils.py

src/diffusers/loaders/lora.py

Lines changed: 17 additions & 20 deletions
@@ -391,6 +391,10 @@ def load_lora_into_unet(
         # their prefixes.
         keys = list(state_dict.keys())

+        if all(key.startswith("unet.unet") for key in keys):
+            deprecation_message = "Keys starting with 'unet.unet' are deprecated."
+            deprecate("unet.unet keys", "0.27", deprecation_message)
+
         if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
@@ -407,8 +411,9 @@ def load_lora_into_unet(
         else:
             # Otherwise, we're dealing with the old format. This means the `state_dict` should only
             # contain the module names of the `unet` as its keys WITHOUT any prefix.
-            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-            logger.warn(warn_message)
+            if not USE_PEFT_BACKEND:
+                warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
+                logger.warn(warn_message)

         if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
             from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
@@ -800,29 +805,21 @@ def save_lora_weights(
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
-        # Create a flat dictionary.
         state_dict = {}

-        # Populate the dictionary.
-        if unet_lora_layers is not None:
-            weights = (
-                unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
-            )
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict

-            unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
-            state_dict.update(unet_lora_state_dict)
+        if not (unet_lora_layers or text_encoder_lora_layers):
+            raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")

-        if text_encoder_lora_layers is not None:
-            weights = (
-                text_encoder_lora_layers.state_dict()
-                if isinstance(text_encoder_lora_layers, torch.nn.Module)
-                else text_encoder_lora_layers
-            )
+        if unet_lora_layers:
+            state_dict.update(pack_weights(unet_lora_layers, "unet"))

-            text_encoder_lora_state_dict = {
-                f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
-            }
-            state_dict.update(text_encoder_lora_state_dict)
+        if text_encoder_lora_layers:
+            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))

         # Save the model
         cls.write_lora_layers(
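
A rough usage sketch of the new helper (a standalone copy for illustration; in the library it is defined inline inside `save_lora_weights`, and the example key names are made up):

    import torch

    def pack_weights(layers, prefix):
        # Accept either an nn.Module or an already-flat dict of tensors.
        layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
        # Flatten into "<prefix>.<module_name>" keys, adding the prefix exactly once.
        return {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}

    # With a plain dict:
    packed = pack_weights({"to_q.lora.down.weight": torch.zeros(4, 8)}, "unet")
    # -> {'unet.to_q.lora.down.weight': tensor(...)}

    # With an nn.Module, its state_dict() is flattened the same way:
    packed_te = pack_weights(torch.nn.Linear(8, 4), "text_encoder")
    # -> {'text_encoder.weight': ..., 'text_encoder.bias': ...}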

src/diffusers/training_utils.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
                 current_lora_layer_sd = lora_layer.state_dict()
                 for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
                     # The matrix name can either be "down" or "up".
-                    lora_state_dict[f"unet.{name}.lora.{lora_layer_matrix_name}"] = lora_param
+                    lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param

     return lora_state_dict
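
With this change, `unet_lora_state_dict` returns un-prefixed module keys, and the single "unet." prefix is added only at serialization time by `save_lora_weights`. A hedged end-to-end sketch (it assumes LoRA layers have already been attached to the UNet during training, and the model id is just an example):

    from diffusers import StableDiffusionPipeline
    from diffusers.training_utils import unet_lora_state_dict

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # Keys now look like "<module>.lora.down.weight" -- no "unet." prefix here.
    unet_lora = unet_lora_state_dict(pipe.unet)

    # save_lora_weights packs them under a single "unet." prefix before saving.
    StableDiffusionPipeline.save_lora_weights(
        save_directory="lora-out",
        unet_lora_layers=unet_lora,
    )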
