[LoRA] restrict certain keys to be checked for peft config update. #10808
Changes from 6 commits
@@ -54,6 +54,7 @@
     "SanaTransformer2DModel": lambda model_cls, weights: weights,
     "Lumina2Transformer2DModel": lambda model_cls, weights: weights,
 }
+_NO_CONFIG_UPDATE_KEYS = ["to_k", "to_q", "to_v"]


 def _maybe_adjust_config(config):

Review thread on `_NO_CONFIG_UPDATE_KEYS`:
- We need it when the PEFT version doesn't contain the required prefix.
- Got it.
- If this issue reoccurs with other keys before the minimum PEFT version is increased, this can be applied.
- Check now?
@@ -68,6 +69,8 @@ def _maybe_adjust_config(config):
     original_r = config["r"]

     for key in list(rank_pattern.keys()):
+        if any(prefix in key for prefix in _NO_CONFIG_UPDATE_KEYS):
+            continue
         key_rank = rank_pattern[key]

         # try to detect ambiguity
@@ -187,6 +190,11 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer

+        try:
+            from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
+        except ImportError:
+            FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
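The try/except above is a feature probe: newer PEFT releases expose `FULLY_QUALIFIED_PATTERN_KEY_PREFIX`, while older ones raise `ImportError` and the loader falls back to `None`. A short sketch of the same version-guard pattern in isolation (the print messages are illustrative):

```python
# Version guard: probe for an optional constant and fall back to None when the
# installed PEFT predates it. Mirrors the guard added in this hunk.
try:
    from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
except ImportError:  # older PEFT releases do not define this constant
    FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None

if FULLY_QUALIFIED_PATTERN_KEY_PREFIX is not None:
    print("PEFT can treat rank_pattern keys as fully qualified layer names")
else:
    print("Falling back to the legacy handling via _maybe_adjust_config()")
```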
@@ -251,14 +259,22 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 # Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
                 # Bias layers in LoRA only have a single dimension
                 if "lora_B" in key and val.ndim > 1:
-                    rank[key] = val.shape[1]
+                    # Support to handle cases where layer patterns are treated as full layer names
+                    # was added later in PEFT. So, we handle it accordingly.
+                    # TODO: when we fix the minimal PEFT version for Diffusers,
+                    # we should remove `_maybe_adjust_config()`.
+                    if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
+                        rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
+                    else:
+                        rank[key] = val.shape[1]

             if network_alphas is not None and len(network_alphas) >= 1:
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
-            lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
+            if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
+                lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)

             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
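Per the comments in the hunk, PEFT later gained support for treating rank-pattern entries as full layer names: when the prefix constant is available, the rank keys are tagged with it and the `_maybe_adjust_config()` workaround is skipped; otherwise the legacy path is kept. A hedged sketch of the two paths follows; `build_rank_entries`, the sample key, and the prefix string are illustrative, not Diffusers or PEFT API.

```python
# Sketch of the two code paths, assuming a truthy prefix means "newer PEFT".
# build_rank_entries and the sample key are illustrative only.
def build_rank_entries(lora_b_ranks, fq_prefix):
    rank = {}
    for key, dim in lora_b_ranks.items():
        if fq_prefix:
            # Newer PEFT: mark the key as a fully qualified layer name.
            rank[f"{fq_prefix}{key}"] = dim
        else:
            # Older PEFT: plain key; _maybe_adjust_config() is applied later instead.
            rank[key] = dim
    return rank


ranks = {"transformer_blocks.0.attn.to_out.0.lora_B.weight": 8}
print(build_rank_entries(ranks, None))                          # legacy path
print(build_rank_entries(ranks, "FULLY_QUALIFIED_PREFIX."))     # hypothetical prefix value
```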