[LoRA] restrict certain keys to be checked for peft config update. #10808

Merged
merged 9 commits into from Feb 24, 2025
Changes from 6 commits

20 changes: 18 additions & 2 deletions src/diffusers/loaders/peft.py
@@ -54,6 +54,7 @@
"SanaTransformer2DModel": lambda model_cls, weights: weights,
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
}
_NO_CONFIG_UPDATE_KEYS = ["to_k", "to_q", "to_v"]
Contributor

Suggested change
_NO_CONFIG_UPDATE_KEYS = ["to_k", "to_q", "to_v"]

Member Author

We need it when the installed PEFT version doesn't contain the required prefix.

Contributor

Got it.

A possible fix for the issue in that function is to store deleted keys so they aren't re-added by later iterations.

If this issue reoccurs with other keys before the minimum PEFT version is raised, that fix can be applied.
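
A rough sketch of that suggestion, for illustration only (the ambiguity check is a placeholder and the rank_pattern entries are made up; this is not what the PR implements):

# Illustrative only: remember which rank_pattern keys were deleted during the
# adjustment loop so that later iterations cannot re-add them.
config = {"rank_pattern": {"to_k": 8, "blocks.0.attn.to_k": 16}, "r": 8}

def is_ambiguous(key, rank_pattern):
    # stand-in for the real ambiguity detection in _maybe_adjust_config
    return any(other != key and key in other for other in rank_pattern)

deleted_keys = set()
rank_pattern = config["rank_pattern"]
for key in list(rank_pattern.keys()):
    if key in deleted_keys:
        continue
    if is_ambiguous(key, rank_pattern):
        deleted_keys.add(key)
        del rank_pattern[key]
        # any later step that would re-insert `key` would consult deleted_keys first

print(config)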

Member Author

Check now?



def _maybe_adjust_config(config):
@@ -68,6 +69,8 @@ def _maybe_adjust_config(config):
original_r = config["r"]

for key in list(rank_pattern.keys()):
if any(prefix in key for prefix in _NO_CONFIG_UPDATE_KEYS):
continue
key_rank = rank_pattern[key]

# try to detect ambiguity
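
As a side note (not part of the diff), a minimal sketch of what this new guard does, using a hypothetical rank_pattern:

# Minimal sketch of the guard added above; keys and ranks are hypothetical.
_NO_CONFIG_UPDATE_KEYS = ["to_k", "to_q", "to_v"]

rank_pattern = {
    "single_transformer_blocks.0.attn.to_k": 128,  # skipped: contains "to_k"
    "proj_out": 64,                                # still subject to the adjustment
}

for key in list(rank_pattern.keys()):
    if any(prefix in key for prefix in _NO_CONFIG_UPDATE_KEYS):
        continue  # attention projection entries are left untouched
    print(f"ambiguity check would run for: {key}")  # prints only "proj_out"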
@@ -187,6 +190,11 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer

try:
from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
except ImportError:
FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None

cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -251,14 +259,22 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
# Cannot figure out rank from lora layers that don't have at least 2 dimensions.
# Bias layers in LoRA only have a single dimension
if "lora_B" in key and val.ndim > 1:
rank[key] = val.shape[1]
# Support to handle cases where layer patterns are treated as full layer names
# was added later in PEFT. So, we handle it accordingly.
# TODO: when we fix the minimal PEFT version for Diffusers,
# we should remove `_maybe_adjust_config()`.
if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
else:
rank[key] = val.shape[1]

if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)

if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
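
Putting the pieces of this hunk together, a self-contained sketch (the state-dict key and shape are hypothetical) of how the rank dict and the config adjustment now depend on the installed PEFT version:

# Sketch of the version-dependent path above; the lora_B entry is made up.
try:
    from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
except ImportError:  # older PEFT releases don't expose the constant
    FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None

lora_B_shapes = {"transformer_blocks.0.attn.to_k.lora_B.weight": (3072, 64)}

rank = {}
for key, shape in lora_B_shapes.items():
    if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
        # newer PEFT: mark the key as a fully qualified module name
        rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = shape[1]
    else:
        rank[key] = shape[1]

# the _maybe_adjust_config() workaround is only needed on the older code path
needs_adjustment = FULLY_QUALIFIED_PATTERN_KEY_PREFIX is None
print(rank, needs_adjustment)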