[LoRA] restrict certain keys to be checked for peft config update. #10808

Merged · 9 commits · Feb 24, 2025
35 changes: 26 additions & 9 deletions src/diffusers/loaders/peft.py
@@ -63,6 +63,9 @@ def _maybe_adjust_config(config):
method removes the ambiguity by following what is described here:
https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
"""
# Track keys that have been explicitly removed to prevent re-adding them.
deleted_keys = set()

rank_pattern = config["rank_pattern"].copy()
target_modules = config["target_modules"]
original_r = config["r"]
@@ -80,21 +83,22 @@ def _maybe_adjust_config(config):
ambiguous_key = key

if exact_matches and substring_matches:
# if ambiguous we update the rank associated with the ambiguous key (`proj_out`, for example)
# if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
config["r"] = key_rank
# remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead
# remove the ambiguous key from `rank_pattern` and record it as deleted
del config["rank_pattern"][key]
deleted_keys.add(key)
# For substring matches, add them with the original rank only if they haven't been assigned already
for mod in substring_matches:
# avoid overwriting if the module already has a specific rank
if mod not in config["rank_pattern"]:
if mod not in config["rank_pattern"] and mod not in deleted_keys:
config["rank_pattern"][mod] = original_r

# update the rest of the keys with the `original_r`
# Update the rest of the target modules with the original rank if not already set and not deleted
for mod in target_modules:
if mod != ambiguous_key and mod not in config["rank_pattern"]:
if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
config["rank_pattern"][mod] = original_r

# handle alphas to deal with cases like
# Handle alphas to deal with cases like:
# https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
if has_different_ranks:
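
For context, here is a minimal sketch of the ambiguity this hunk resolves and of why the new `deleted_keys` set is needed: a `rank_pattern` key such as `proj_out` can name one target module exactly while also being a substring of other module names. The module names and ranks below are invented for illustration; this is not the diffusers implementation.

```python
# Toy config: "proj_out" names one target module exactly *and* is a
# substring of "blocks.0.proj_out", so its rank entry is ambiguous.
config = {
    "r": 4,
    "rank_pattern": {"proj_out": 8},
    "target_modules": ["proj_out", "blocks.0.proj_out", "to_k"],
}

deleted_keys = set()
original_r = config["r"]

for key, key_rank in list(config["rank_pattern"].items()):
    exact_matches = [m for m in config["target_modules"] if m == key]
    substring_matches = [m for m in config["target_modules"] if key in m and m != key]

    if exact_matches and substring_matches:
        # Ambiguous: promote the key's rank to the global `r`, drop the key,
        # and remember the deletion so later passes cannot re-add it.
        config["r"] = key_rank
        del config["rank_pattern"][key]
        deleted_keys.add(key)
        for mod in substring_matches:
            if mod not in config["rank_pattern"] and mod not in deleted_keys:
                config["rank_pattern"][mod] = original_r

# Remaining target modules fall back to the original rank, again skipping deleted keys.
for mod in config["target_modules"]:
    if mod not in config["rank_pattern"] and mod not in deleted_keys:
        config["rank_pattern"][mod] = original_r

print(config["r"])             # 8
print(config["rank_pattern"])  # {"blocks.0.proj_out": 4, "to_k": 4}
```

Without the `deleted_keys` guard, a key removed while resolving one ambiguous pattern could be re-inserted with `original_r` by a later pass, which is exactly the situation the added checks prevent.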
@@ -187,6 +191,11 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer

try:
from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
except ImportError:
FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None

cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
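
The guarded import above doubles as a feature check: `FULLY_QUALIFIED_PATTERN_KEY_PREFIX` exists only in newer PEFT releases that can treat pattern keys as full layer names, and the `None` fallback signals an older PEFT. Below is a hedged sketch of how such a check can be used; the `qualify` helper is hypothetical and only illustrates the idea.

```python
try:
    from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
except ImportError:  # older PEFT without fully qualified pattern keys
    FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None


def qualify(key: str) -> str:
    """Return `key` prefixed so PEFT treats it as a full layer name, when supported."""
    if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
        return f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"
    return key  # older PEFT: keep the key as a plain pattern
```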
@@ -251,14 +260,22 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
# Cannot figure out rank from LoRA layers that don't have at least 2 dimensions.
# Bias layers in LoRA only have a single dimension
if "lora_B" in key and val.ndim > 1:
rank[key] = val.shape[1]
# Support for treating layer patterns as full layer names was added later
# in PEFT, so we handle both cases accordingly.
# TODO: once the minimum PEFT version for Diffusers is raised,
# `_maybe_adjust_config()` should be removed.
if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
else:
rank[key] = val.shape[1]

if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)

if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
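
To make the `lora_B` handling above concrete: in a LoRA pair, `lora_A` has shape `(r, in_features)` and `lora_B` has shape `(out_features, r)`, so `val.shape[1]` of a 2-D `lora_B` weight is the adapter rank, while 1-D bias tensors carry no rank information and are skipped. A small sketch with invented state-dict keys:

```python
import torch

# Invented keys, purely for illustration.
state_dict = {
    "blocks.0.attn.to_q.lora_A.weight": torch.zeros(8, 320),  # (r, in_features)
    "blocks.0.attn.to_q.lora_B.weight": torch.zeros(320, 8),  # (out_features, r)
    "blocks.0.attn.to_q.lora_B.bias": torch.zeros(320),       # 1-D bias: no rank info
}

rank = {}
for key, val in state_dict.items():
    # Only 2-D lora_B weights carry the rank; skip 1-D LoRA biases.
    if "lora_B" in key and val.ndim > 1:
        rank[key] = val.shape[1]

print(rank)  # {"blocks.0.attn.to_q.lora_B.weight": 8}
```

When `FULLY_QUALIFIED_PATTERN_KEY_PREFIX` is available, the hunk stores the prefixed key instead and skips `_maybe_adjust_config()` entirely, presumably because fully qualified keys no longer need the pattern disambiguation, in line with the TODO above about removing the helper once the minimum PEFT version is raised.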