[LoRA] Improve warning messages when LoRA loading becomes a no-op #10187
Merged
Commits (36):
- cd88a4b: updates (sayakpaul)
- b694ca4: updates (sayakpaul)
- 1db7503: updates (sayakpaul)
- 6134491: updates (sayakpaul)
- db827b5: Merge branch 'main' into improve-lora-warning-msg (sayakpaul)
- ac29785: notebooks revert (sayakpaul)
- 3f4a3fc: resolve conflicts (sayakpaul)
- b6db978: Merge branch 'main' into improve-lora-warning-msg (sayakpaul)
- c44f7a3: Merge branch 'main' into improve-lora-warning-msg (sayakpaul)
- 876132a: fix-copies. (sayakpaul)
- bb7c09a: Merge branch 'main' into improve-lora-warning-msg (sayakpaul)
- e6043a0: seeing (sayakpaul)
- 7ca7493: fix (sayakpaul)
- ec44f9a: revert (sayakpaul)
- 343b2d2: Merge branch 'main' into improve-lora-warning-msg (sayakpaul)
- 615e372: fixes (sayakpaul)
- e2e3ea0: fixes (sayakpaul)
- f9dd64c: fixes (sayakpaul)
- a01cb45: remove print (sayakpaul)
- da96621: fix (sayakpaul)
- a91138d: Merge branch 'main' into improve-lora-warning-msg (sayakpaul)
- 83ad82b: Merge branch 'main' into improve-lora-warning-msg (sayakpaul)
- be187da: Merge branch 'main' into improve-lora-warning-msg (sayakpaul)
- 726e492: Merge branch 'main' into improve-lora-warning-msg (sayakpaul)
- 3efdc58: fix conflicts (sayakpaul)
- cf50148: conflicts ii. (sayakpaul)
- b2afc10: updates (sayakpaul)
- 96eced3: fixes (sayakpaul)
- b4be719: Merge branch 'main' into improve-lora-warning-msg (sayakpaul)
- 8bf1173: Merge branch 'main' into improve-lora-warning-msg (sayakpaul)
- 0e43b55: Merge branch 'main' into improve-lora-warning-msg (hlky)
- 1e4dbbc: Merge branch 'main' into improve-lora-warning-msg (sayakpaul)
- 9eb460f: better filtering of prefix. (sayakpaul)
- 279ee91: Merge branch 'main' into improve-lora-warning-msg (sayakpaul)
- 6240876: Merge branch 'main' into improve-lora-warning-msg (sayakpaul)
- cf9027a: Merge branch 'main' into improve-lora-warning-msg (sayakpaul)
```diff
@@ -294,22 +294,15 @@ def load_lora_into_unet(
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
             )

-        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
-        # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
-        # their prefixes.
-        keys = list(state_dict.keys())
-        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
-        if not only_text_encoder:
-            # Load the layers corresponding to UNet.
-            logger.info(f"Loading {cls.unet_name}.")
-            unet.load_lora_adapter(
-                state_dict,
-                prefix=cls.unet_name,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        # Load the layers corresponding to UNet.
+        unet.load_lora_adapter(
+            state_dict,
+            prefix=cls.unet_name,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

     @classmethod
     def load_lora_into_text_encoder(
```
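This pattern repeats throughout the diff: instead of pre-filtering the state dict and guarding the call at each call site, the caller now always passes the full `state_dict` along with a `prefix`, and `load_lora_adapter()` does the filtering itself. A minimal sketch of that kind of prefix filtering, written here for illustration only (the actual diffusers implementation differs in detail):

```python
def filter_lora_keys(state_dict: dict, prefix: str) -> dict:
    """Keep only the entries belonging to one component (e.g. prefix='unet')
    and strip the prefix so keys line up with the component's module names."""
    pref = prefix + "."
    return {k[len(pref):]: v for k, v in state_dict.items() if k.startswith(pref)}


state_dict = {"unet.down.lora_A.weight": 0, "text_encoder.embed.lora_A.weight": 1}
assert filter_lora_keys(state_dict, "unet") == {"down.lora_A.weight": 0}
assert filter_lora_keys(state_dict, "vae") == {}  # empty -> the load is a no-op
```

When the filtered dict comes back empty, the adapter-loading method is the single place that can report the no-op, which is exactly what the new log messages below do.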
```diff
@@ -462,6 +455,11 @@ def load_lora_into_text_encoder(
                _pipeline.enable_sequential_cpu_offload()
            # Unsafe code />

+        else:
+            logger.info(
+                f"No LoRA keys found in the provided state dict for {text_encoder.__class__.__name__}. Please open an issue if you think this is unexpected - https://github.com/huggingface/diffusers/issues/new."
+            )
+
     @classmethod
     def save_lora_weights(
         cls,
```
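For orientation, the new `else` pairs with the existing key check in `load_lora_into_text_encoder`; schematically it looks like the following sketch (function and variable names are assumed here, not quoted from the file):

```python
import logging

logger = logging.getLogger(__name__)


def load_lora_into_text_encoder_sketch(text_encoder, text_encoder_lora_state_dict):
    # Schematic of the guard this PR extends; names are illustrative.
    if len(text_encoder_lora_state_dict) > 0:
        ...  # convert keys and inject LoRA layers into the text encoder
    else:
        logger.info(
            f"No LoRA keys found in the provided state dict for {text_encoder.__class__.__name__}. "
            "Please open an issue if you think this is unexpected - "
            "https://github.com/huggingface/diffusers/issues/new."
        )
```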
```diff
@@ -660,18 +658,16 @@ def load_lora_weights(
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
         )
-        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
-        if len(text_encoder_state_dict) > 0:
-            self.load_lora_into_text_encoder(
-                text_encoder_state_dict,
-                network_alphas=network_alphas,
-                text_encoder=self.text_encoder,
-                prefix="text_encoder",
-                lora_scale=self.lora_scale,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=network_alphas,
+            text_encoder=self.text_encoder,
+            prefix="text_encoder",
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
         if len(text_encoder_2_state_dict) > 0:
```

Review comment (on old lines 562 to 573): Similar philosophy as explained above.
```diff
@@ -836,22 +832,15 @@ def load_lora_into_unet(
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
             )

-        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
-        # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
-        # their prefixes.
-        keys = list(state_dict.keys())
-        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
-        if not only_text_encoder:
-            # Load the layers corresponding to UNet.
-            logger.info(f"Loading {cls.unet_name}.")
-            unet.load_lora_adapter(
-                state_dict,
-                prefix=cls.unet_name,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        # Load the layers corresponding to UNet.
+        unet.load_lora_adapter(
+            state_dict,
+            prefix=cls.unet_name,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
```
```diff
@@ -1005,6 +994,11 @@ def load_lora_into_text_encoder(
                _pipeline.enable_sequential_cpu_offload()
            # Unsafe code />

+        else:
+            logger.info(
+                f"No LoRA keys found in the provided state dict for {text_encoder.__class__.__name__}. Please open an issue if you think this is unexpected - https://github.com/huggingface/diffusers/issues/new."
+            )
+
     @classmethod
     def save_lora_weights(
         cls,
```
```diff
@@ -1288,43 +1282,35 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

-        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
-        if len(transformer_state_dict) > 0:
-            self.load_lora_into_transformer(
-                state_dict,
-                transformer=getattr(self, self.transformer_name)
-                if not hasattr(self, "transformer")
-                else self.transformer,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

-        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
-        if len(text_encoder_state_dict) > 0:
-            self.load_lora_into_text_encoder(
-                text_encoder_state_dict,
-                network_alphas=None,
-                text_encoder=self.text_encoder,
-                prefix="text_encoder",
-                lora_scale=self.lora_scale,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=None,
+            text_encoder=self.text_encoder,
+            prefix="text_encoder",
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

-        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
-        if len(text_encoder_2_state_dict) > 0:
-            self.load_lora_into_text_encoder(
-                text_encoder_2_state_dict,
-                network_alphas=None,
-                text_encoder=self.text_encoder_2,
-                prefix="text_encoder_2",
-                lora_scale=self.lora_scale,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=None,
+            text_encoder=self.text_encoder_2,
+            prefix="text_encoder_2",
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

     @classmethod
     def load_lora_into_transformer(
```
```diff
@@ -1353,7 +1339,6 @@ def load_lora_into_transformer(
             )

         # Load the layers corresponding to transformer.
-        logger.info(f"Loading {cls.transformer_name}.")
         transformer.load_lora_adapter(
             state_dict,
             network_alphas=None,
```
```diff
@@ -1514,6 +1499,11 @@ def load_lora_into_text_encoder(
                _pipeline.enable_sequential_cpu_offload()
            # Unsafe code />

+        else:
+            logger.info(
+                f"No LoRA keys found in the provided state dict for {text_encoder.__class__.__name__}. Please open an issue if you think this is unexpected - https://github.com/huggingface/diffusers/issues/new."
+            )
+
     @classmethod
     def save_lora_weights(
         cls,
```
```diff
@@ -1844,7 +1834,7 @@ def load_lora_weights(
             raise ValueError("Invalid LoRA checkpoint.")

         transformer_lora_state_dict = {
-            k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
+            k: state_dict.get(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
         }
         transformer_norm_state_dict = {
             k: state_dict.pop(k)
```
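The switch from `pop` to `get` in this hunk works together with the next one: `load_lora_into_transformer` is now called with the full `state_dict`, so building `transformer_lora_state_dict` must not consume the transformer keys. A self-contained illustration of the difference:

```python
sd = {"transformer.x.lora_A.weight": 0, "text_encoder.y.lora_A.weight": 1}

# pop() consumes the matching keys: a later pass over sd sees no transformer entries.
popped = {k: sd.pop(k) for k in list(sd) if "transformer." in k and "lora" in k}
assert "transformer.x.lora_A.weight" not in sd

sd = {"transformer.x.lora_A.weight": 0, "text_encoder.y.lora_A.weight": 1}

# get() copies the values and leaves sd intact for the subsequent
# load_lora_into_transformer(state_dict, ...) call shown below.
copied = {k: sd.get(k) for k in list(sd) if "transformer." in k and "lora" in k}
assert "transformer.x.lora_A.weight" in sd
```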
```diff
@@ -1864,15 +1854,14 @@ def load_lora_weights(
                 "To get a comprehensive list of parameter names that were modified, enable debug logging."
             )

-        if len(transformer_lora_state_dict) > 0:
-            self.load_lora_into_transformer(
-                transformer_lora_state_dict,
-                network_alphas=network_alphas,
-                transformer=transformer,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_transformer(
+            state_dict,
+            network_alphas=network_alphas,
+            transformer=transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

         if len(transformer_norm_state_dict) > 0:
             transformer._transformer_norm_layers = self._load_norm_into_transformer(
```
```diff
@@ -1881,18 +1870,16 @@ def load_lora_weights(
                 discard_original_layers=False,
             )

-        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
-        if len(text_encoder_state_dict) > 0:
-            self.load_lora_into_text_encoder(
-                text_encoder_state_dict,
-                network_alphas=network_alphas,
-                text_encoder=self.text_encoder,
-                prefix="text_encoder",
-                lora_scale=self.lora_scale,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=network_alphas,
+            text_encoder=self.text_encoder,
+            prefix="text_encoder",
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

     @classmethod
     def load_lora_into_transformer(
```
```diff
@@ -1925,17 +1912,13 @@ def load_lora_into_transformer(
             )

         # Load the layers corresponding to transformer.
-        keys = list(state_dict.keys())
-        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
-        if transformer_present:
-            logger.info(f"Loading {cls.transformer_name}.")
-            transformer.load_lora_adapter(
-                state_dict,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

     @classmethod
     def _load_norm_into_transformer(
```
```diff
@@ -2143,6 +2126,11 @@ def load_lora_into_text_encoder(
                _pipeline.enable_sequential_cpu_offload()
            # Unsafe code />

+        else:
+            logger.info(
+                f"No LoRA keys found in the provided state dict for {text_encoder.__class__.__name__}. Please open an issue if you think this is unexpected - https://github.com/huggingface/diffusers/issues/new."
+            )
+
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
     def save_lora_weights(
```
```diff
@@ -2413,17 +2401,13 @@ def load_lora_into_transformer(
             )

         # Load the layers corresponding to transformer.
-        keys = list(state_dict.keys())
-        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
-        if transformer_present:
-            logger.info(f"Loading {cls.transformer_name}.")
-            transformer.load_lora_adapter(
-                state_dict,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
```
```diff
@@ -2577,6 +2561,11 @@ def load_lora_into_text_encoder(
                _pipeline.enable_sequential_cpu_offload()
            # Unsafe code />

+        else:
+            logger.info(
+                f"No LoRA keys found in the provided state dict for {text_encoder.__class__.__name__}. Please open an issue if you think this is unexpected - https://github.com/huggingface/diffusers/issues/new."
+            )
+
     @classmethod
     def save_lora_weights(
         cls,
```
```diff
@@ -2816,7 +2805,6 @@ def load_lora_into_transformer(
             )

         # Load the layers corresponding to transformer.
-        logger.info(f"Loading {cls.transformer_name}.")
         transformer.load_lora_adapter(
             state_dict,
             network_alphas=None,
```
```diff
@@ -3124,7 +3112,6 @@ def load_lora_into_transformer(
             )

         # Load the layers corresponding to transformer.
-        logger.info(f"Loading {cls.transformer_name}.")
         transformer.load_lora_adapter(
             state_dict,
             network_alphas=None,
```
Review comment: We're handling all of this within the `load_lora_adapter()` method now, which I think is more appropriate: the same filtering and messaging applies whether users come through `pipe.load_lora_weights()` or call the `load_lora_adapter()` method directly (with something like `unet.load_lora_adapter()`). Helps to avoid duplication. I have run the integration tests, too, and nothing is breaking due to this.
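To illustrate the user-facing effect, here is a minimal sketch; the model ID and LoRA path are placeholders, not taken from this PR. Loading a file that carries no keys for a given component used to skip that component silently; it now emits the info message added above.

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Hypothetical UNet-only LoRA file: it contains no "text_encoder." keys, so the
# text encoder part of the load is a no-op. With this PR the loader logs roughly:
#   "No LoRA keys found in the provided state dict for CLIPTextModel. ..."
pipe.load_lora_weights("path/to/unet_only_lora.safetensors")
```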