[LoRA] support more comyui loras for Flux 🚨 #10985
Merged
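For context, a minimal sketch of the user-facing flow this PR is about: loading a ComfyUI-exported Flux LoRA through `load_lora_weights`. The repository and file names below are placeholders, not artifacts from this PR.

```python
import torch
from diffusers import FluxPipeline

# Load the base Flux pipeline (bfloat16 keeps memory manageable on a single GPU).
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Hypothetical ComfyUI-trained LoRA checkpoint; any repo/file holding a
# ComfyUI-style Flux LoRA state dict is the kind of input this PR targets.
pipe.load_lora_weights("some-user/some-comfyui-flux-lora", weight_name="lora.safetensors")

image = pipe("a cinematic photo of a red fox in the snow", num_inference_steps=28).images[0]
image.save("fox.png")
```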
Commits (34 total, all by sayakpaul; the diff below shows changes from 27 commits)
812b4e1  support more comyui loras.
367153d  fix
5c4976b  fixes
a6d8f3f  revert changes in LoRA base.
1074836  no position_embedding
ca88a5e  Merge branch 'main' into support-comyui-flux-loras
4b9d2df  Merge branch 'main' into support-comyui-flux-loras
2b6990f  Merge branch 'main' into support-comyui-flux-loras
cc51f5c  Merge branch 'main' into support-comyui-flux-loras
ba0f8a3  Merge branch 'main' into support-comyui-flux-loras
1c98875  🚨 introduce a breaking change to let peft handle module ambiguity
05ccc90  Merge branch 'main' into support-comyui-flux-loras
fc25e1c  styling
0560dcc  Merge branch 'main' into support-comyui-flux-loras
34226af  Merge branch 'main' into support-comyui-flux-loras
78ae954  Merge branch 'main' into support-comyui-flux-loras
13ecc86  Merge branch 'main' into support-comyui-flux-loras
2f2100a  Merge branch 'main' into support-comyui-flux-loras
0d88427  remove position embeddings.
ea0d131  improvements.
2cb82f3  style
5e11a89  Merge branch 'main' into support-comyui-flux-loras
09b2a0f  Merge branch 'main' into support-comyui-flux-loras
c30a1e4  Merge branch 'main' into support-comyui-flux-loras
30f8f74  Merge branch 'main' into support-comyui-flux-loras
3a6eefc  make info instead of NotImplementedError
e2f51de  Update src/diffusers/loaders/peft.py
171fa24  Merge branch 'main' into support-comyui-flux-loras
090468c  add example.
90bf93d  Merge branch 'main' into support-comyui-flux-loras
7532406  Merge branch 'main' into support-comyui-flux-loras
f754663  robust checks
b5c136f  updates
df28778  Merge branch 'main' into support-comyui-flux-loras
Diff: src/diffusers/loaders/peft.py

@@ -58,59 +58,24 @@
 }


-def _maybe_adjust_config(config):
-    """
-    We may run into some ambiguous configuration values when a model has module names, sharing a common prefix
-    (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This
-    method removes the ambiguity by following what is described here:
-    https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
-    """
-    # Track keys that have been explicitly removed to prevent re-adding them.
-    deleted_keys = set()
-
+def _maybe_raise_error_for_ambiguity(config):
     rank_pattern = config["rank_pattern"].copy()
     target_modules = config["target_modules"]
-    original_r = config["r"]

     for key in list(rank_pattern.keys()):
-        key_rank = rank_pattern[key]
-
         # try to detect ambiguity
         # `target_modules` can also be a str, in which case this loop would loop
         # over the chars of the str. The technically correct way to match LoRA keys
         # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
         # But this cuts it for now.
         exact_matches = [mod for mod in target_modules if mod == key]
         substring_matches = [mod for mod in target_modules if key in mod and mod != key]
-        ambiguous_key = key

         if exact_matches and substring_matches:
-            # if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
-            config["r"] = key_rank
-            # remove the ambiguous key from `rank_pattern` and record it as deleted
-            del config["rank_pattern"][key]
-            deleted_keys.add(key)
-            # For substring matches, add them with the original rank only if they haven't been assigned already
-            for mod in substring_matches:
-                if mod not in config["rank_pattern"] and mod not in deleted_keys:
-                    config["rank_pattern"][mod] = original_r
-
-        # Update the rest of the target modules with the original rank if not already set and not deleted
-        for mod in target_modules:
-            if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
-                config["rank_pattern"][mod] = original_r
-
-    # Handle alphas to deal with cases like:
-    # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
-    has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
-    if has_different_ranks:
-        config["lora_alpha"] = config["r"]
-        alpha_pattern = {}
-        for module_name, rank in config["rank_pattern"].items():
-            alpha_pattern[module_name] = rank
-        config["alpha_pattern"] = alpha_pattern
-
-    return config
+            if is_peft_version("<", "0.14.1"):
+                raise ValueError(
+                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
+                )


 class PeftAdapterMixin:

Review comment on the `def _maybe_raise_error_for_ambiguity(config):` line: This is a breaking change but in this case, I would very much prefer this as otherwise it is becoming increasingly difficult and cumbersome to support LoRAs of the world.
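To see what the new guard is reacting to, here is a small standalone sketch (made-up config values, not code from the PR) of an ambiguous configuration: the rank-pattern key `proj_out` is both an exact target module and a substring of `blocks.transformer.proj_out`, and the two were trained with different ranks.

```python
# Standalone illustration (not PR code) of the ambiguity that
# _maybe_raise_error_for_ambiguity detects: a rank_pattern key that is both
# an exact target module and a substring of another target module.
config = {
    "r": 4,
    "rank_pattern": {"proj_out": 8},  # `proj_out` was trained with a different rank
    "target_modules": ["proj_out", "blocks.transformer.proj_out", "to_q"],
}

for key in list(config["rank_pattern"].keys()):
    exact_matches = [m for m in config["target_modules"] if m == key]
    substring_matches = [m for m in config["target_modules"] if key in m and m != key]
    if exact_matches and substring_matches:
        # With peft < 0.14.1 diffusers now raises instead of silently rewriting the
        # config; with newer peft, the `^`-anchored rank_pattern keys (second hunk
        # below) let peft resolve the ambiguity itself.
        print(f"Ambiguous key: {key!r} -> {exact_matches + substring_matches}")
```

With peft older than 0.14.1, loading such a LoRA now raises and asks for an upgrade instead of silently rewriting `r`, `rank_pattern`, and `alpha_pattern` the way `_maybe_adjust_config` used to.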
@@ -254,16 +219,18 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 # Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
                 # Bias layers in LoRA only have a single dimension
                 if "lora_B" in key and val.ndim > 1:
-                    # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
-                    rank[key] = val.shape[1]
+                    # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
+                    # We may run into some ambiguous configuration values when a model has module
+                    # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
+                    # for example) and they have different LoRA ranks.
+                    rank[f"^{key}"] = val.shape[1]

             if network_alphas is not None and len(network_alphas) >= 1:
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
-            # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
-            lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
+            _maybe_raise_error_for_ambiguity(lora_config_kwargs)

             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
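Why the `^` prefix helps: peft resolves `rank_pattern` keys by pattern matching against module names (see peft#2419 referenced in the comment above), so an unanchored key like `proj_out` can also hit `blocks.transformer.proj_out`. The snippet below is a conceptual sketch using plain `re`, not peft's actual matcher, just to show what anchoring changes.

```python
import re

# Conceptual illustration (not peft's matcher) of why the rank keys are now
# written as f"^{key}": an anchored pattern no longer matches a longer module
# name that merely ends with the same suffix.
module_names = ["proj_out", "blocks.transformer.proj_out"]

unanchored = "proj_out"
anchored = "^proj_out"

for name in module_names:
    print(name, bool(re.search(unanchored, name)), bool(re.search(anchored, name)))
# proj_out True True
# blocks.transformer.proj_out True False
```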
Review comment: I'm not familiar with those other state dict formats; I just wanted to ask whether it would be safer to use dots in the filter keys, e.g. `.diff.` instead of `diff`, to prevent accidental matches.
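A quick illustration of the concern (made-up keys, not taken from the PR's state dicts): a bare substring like `diff` also matches unrelated parameter names, while the dotted form `.diff.` only matches keys where `diff` appears as a standalone path segment.

```python
# Made-up ComfyUI-style keys to show why ".diff." is a stricter filter than "diff".
state_dict_keys = [
    "transformer.single_blocks.0.linear1.diff.weight",  # should match
    "transformer.double_blocks.3.img_mod.diff_b",        # contains "diff" but a different suffix
    "transformer.guidance_in.in_layer.lora_A.weight",    # unrelated key
]

loose = [k for k in state_dict_keys if "diff" in k]
strict = [k for k in state_dict_keys if ".diff." in k]

print(loose)   # picks up the "diff_b" key as well
print(strict)  # only the ".diff." key
```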