
[LoRA] Improve warning messages when LoRA loading becomes a no-op #10187


Merged
merged 36 commits into main from improve-lora-warning-msg on Mar 10, 2025
Changes from 6 commits
Commits (36)
cd88a4b
updates
sayakpaul Dec 11, 2024
b694ca4
updates
sayakpaul Dec 11, 2024
1db7503
updates
sayakpaul Dec 11, 2024
6134491
updates
sayakpaul Dec 11, 2024
db827b5
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Dec 11, 2024
ac29785
notebooks revert
sayakpaul Dec 11, 2024
3f4a3fc
resolve conflicts
sayakpaul Dec 15, 2024
b6db978
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Dec 19, 2024
c44f7a3
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Dec 20, 2024
876132a
fix-copies.
sayakpaul Dec 20, 2024
bb7c09a
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Dec 25, 2024
e6043a0
seeing
sayakpaul Dec 25, 2024
7ca7493
fix
sayakpaul Dec 25, 2024
ec44f9a
revert
sayakpaul Dec 25, 2024
343b2d2
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Dec 25, 2024
615e372
fixes
sayakpaul Dec 25, 2024
e2e3ea0
fixes
sayakpaul Dec 25, 2024
f9dd64c
fixes
sayakpaul Dec 25, 2024
a01cb45
remove print
sayakpaul Dec 25, 2024
da96621
fix
sayakpaul Dec 25, 2024
a91138d
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Dec 27, 2024
83ad82b
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Jan 2, 2025
be187da
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Jan 5, 2025
726e492
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Jan 7, 2025
3efdc58
fix conflicts
sayakpaul Jan 13, 2025
cf50148
conflicts ii.
sayakpaul Jan 13, 2025
b2afc10
updates
sayakpaul Jan 13, 2025
96eced3
fixes
sayakpaul Jan 13, 2025
b4be719
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Jan 13, 2025
8bf1173
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Feb 9, 2025
0e43b55
Merge branch 'main' into improve-lora-warning-msg
hlky Mar 6, 2025
1e4dbbc
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Mar 9, 2025
9eb460f
better filtering of prefix.
sayakpaul Mar 10, 2025
279ee91
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Mar 10, 2025
6240876
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Mar 10, 2025
cf9027a
Merge branch 'main' into improve-lora-warning-msg
sayakpaul Mar 10, 2025
239 changes: 113 additions & 126 deletions src/diffusers/loaders/lora_pipeline.py
@@ -294,22 +294,15 @@ def load_lora_into_unet(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
if not only_text_encoder:
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
unet.load_lora_adapter(
state_dict,
prefix=cls.unet_name,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
Member Author

We're handling all of this within the load_lora_adapter() method now, which I think is more appropriate because:

  1. It takes care of logging when users load a LoRA into a model via pipe.load_lora_weights().
  2. It also covers users who load LoRAs directly into a model with the load_lora_adapter() method (e.g., unet.load_lora_adapter()).

This helps avoid duplication. I have also run the integration tests, and nothing breaks because of this change.
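
For context, a minimal sketch of the two entry points described above. The repo ids below are placeholders rather than anything referenced in this PR, and prefix="unet" assumes the LoRA file stores its keys under a "unet." prefix:

```python
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")

# 1. Pipeline-level loading. This is internally routed through each component's
#    load_lora_adapter(), which now also emits the "No LoRA keys found ..." info
#    message when a component has nothing to load.
pipe.load_lora_weights("some-user/some-sd15-lora")

# 2. Model-level loading. Calling load_lora_adapter() directly on a model goes
#    through the same code path, so the logging behaves consistently here too.
pipe.unet.load_lora_adapter("some-user/some-sd15-lora", prefix="unet")
```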

# Load the layers corresponding to UNet.
unet.load_lora_adapter(
state_dict,
prefix=cls.unet_name,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)

@classmethod
def load_lora_into_text_encoder(
@@ -462,6 +455,11 @@ def load_lora_into_text_encoder(
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />

else:
logger.info(
f"No LoRA keys found in the provided state dict for {text_encoder.__class__.__name__}. Please open an issue if you think this is unexpected - https://github.com/huggingface/diffusers/issues/new."
)

@classmethod
def save_lora_weights(
cls,
@@ -660,18 +658,16 @@ def load_lora_weights(
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
Comment on lines -562 to -573
Member Author

Same philosophy as explained above.

self.load_lora_into_text_encoder(
state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)

text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
if len(text_encoder_2_state_dict) > 0:
@@ -836,22 +832,15 @@ def load_lora_into_unet(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
if not only_text_encoder:
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
unet.load_lora_adapter(
state_dict,
prefix=cls.unet_name,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
# Load the layers corresponding to UNet.
unet.load_lora_adapter(
state_dict,
prefix=cls.unet_name,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -1005,6 +994,11 @@ def load_lora_into_text_encoder(
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />

else:
logger.info(
f"No LoRA keys found in the provided state dict for {text_encoder.__class__.__name__}. Please open an issue if you think this is unexpected - https://github.com/huggingface/diffusers/issues/new."
)

@classmethod
def save_lora_weights(
cls,
@@ -1288,43 +1282,35 @@ def load_lora_weights(
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")

transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
if len(transformer_state_dict) > 0:
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name)
if not hasattr(self, "transformer")
else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)

text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_state_dict,
network_alphas=None,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.load_lora_into_text_encoder(
state_dict,
network_alphas=None,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)

text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
if len(text_encoder_2_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_2_state_dict,
network_alphas=None,
text_encoder=self.text_encoder_2,
prefix="text_encoder_2",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.load_lora_into_text_encoder(
state_dict,
network_alphas=None,
text_encoder=self.text_encoder_2,
prefix="text_encoder_2",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)

@classmethod
def load_lora_into_transformer(
@@ -1353,7 +1339,6 @@ def load_lora_into_transformer(
)

# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
@@ -1514,6 +1499,11 @@ def load_lora_into_text_encoder(
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />

else:
logger.info(
f"No LoRA keys found in the provided state dict for {text_encoder.__class__.__name__}. Please open an issue if you think this is unexpected - https://github.com/huggingface/diffusers/issues/new."
)

@classmethod
def save_lora_weights(
cls,
@@ -1844,7 +1834,7 @@ def load_lora_weights(
raise ValueError("Invalid LoRA checkpoint.")

transformer_lora_state_dict = {
k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
k: state_dict.get(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
}
transformer_norm_state_dict = {
k: state_dict.pop(k)
@@ -1864,15 +1854,14 @@ def load_lora_weights(
"To get a comprehensive list of parameter names that were modified, enable debug logging."
)

if len(transformer_lora_state_dict) > 0:
self.load_lora_into_transformer(
transformer_lora_state_dict,
network_alphas=network_alphas,
transformer=transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.load_lora_into_transformer(
state_dict,
network_alphas=network_alphas,
transformer=transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)

if len(transformer_norm_state_dict) > 0:
transformer._transformer_norm_layers = self._load_norm_into_transformer(
discard_original_layers=False,
)

text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.load_lora_into_text_encoder(
state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)

@classmethod
def load_lora_into_transformer(
@@ -1925,17 +1912,13 @@ def load_lora_into_transformer(
)

# Load the layers corresponding to transformer.
keys = list(state_dict.keys())
transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
if transformer_present:
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
transformer.load_lora_adapter(
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)

@classmethod
def _load_norm_into_transformer(
@@ -2143,6 +2126,11 @@ def load_lora_into_text_encoder(
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />

else:
logger.info(
f"No LoRA keys found in the provided state dict for {text_encoder.__class__.__name__}. Please open an issue if you think this is unexpected - https://github.com/huggingface/diffusers/issues/new."
)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
def save_lora_weights(
@@ -2413,17 +2401,13 @@ def load_lora_into_transformer(
)

# Load the layers corresponding to transformer.
keys = list(state_dict.keys())
transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
if transformer_present:
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
transformer.load_lora_adapter(
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2577,6 +2561,11 @@ def load_lora_into_text_encoder(
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />

else:
logger.info(
f"No LoRA keys found in the provided state dict for {text_encoder.__class__.__name__}. Please open an issue if you think this is unexpected - https://github.com/huggingface/diffusers/issues/new."
)

@classmethod
def save_lora_weights(
cls,
@@ -2816,7 +2805,6 @@ def load_lora_into_transformer(
)

# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
@@ -3124,7 +3112,6 @@ def load_lora_into_transformer(
)

# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
12 changes: 12 additions & 0 deletions src/diffusers/loaders/peft.py
@@ -253,8 +253,15 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
if len(model_keys) > 0:
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
else:
state_dict = {}

if len(state_dict) > 0:
if prefix is None:
component_name = "unet" if "UNet" in self.__class__.__name__ else "transformer"
else:
component_name = prefix
logger.info(f"Loading {component_name}.")
if adapter_name in getattr(self, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
@@ -351,6 +358,11 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />

else:
logger.info(
f"No LoRA keys found in the provided state dict for {self.__class__.__name__}. Please open an issue if you think this is unexpected - https://github.com/huggingface/diffusers/issues/new."
)

def save_lora_adapter(
self,
save_directory,
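
For readers skimming the diff, here is a standalone sketch of the control flow this PR adds to load_lora_adapter() in src/diffusers/loaders/peft.py. It is simplified and illustrative; the real method also resolves the checkpoint, builds the PEFT config, and injects the adapter, all of which are omitted here:

```python
import logging

logger = logging.getLogger(__name__)


def load_lora_adapter_sketch(model, state_dict, prefix="transformer"):
    """Illustrative reduction of the prefix filtering and no-op logging added in this PR."""
    # Keep only the keys that belong to this model (and strip the prefix) when a
    # prefix is given; with prefix=None the whole state dict is used as-is.
    if prefix is not None:
        model_keys = [k for k in state_dict if k.startswith(f"{prefix}.")]
        state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}

    if len(state_dict) > 0:
        # Something to load: log which component is being populated.
        if prefix is None:
            component_name = "unet" if "UNet" in model.__class__.__name__ else "transformer"
        else:
            component_name = prefix
        logger.info(f"Loading {component_name}.")
        # ... the real method goes on to inject the LoRA layers into `model`.
    else:
        # Nothing matched the prefix: loading would be a silent no-op, so the PR
        # makes that explicit instead.
        logger.info(
            f"No LoRA keys found in the provided state dict for {model.__class__.__name__}. "
            "Please open an issue if you think this is unexpected - "
            "https://github.com/huggingface/diffusers/issues/new."
        )
```

Called with a state dict that contains only, say, text_encoder.* keys and prefix="transformer", the else branch fires and the user now sees an explicit info message instead of a silent no-op.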