[LoRA] support more ComfyUI LoRAs for Flux 🚨 #10985
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,15 +13,22 @@ | |
# limitations under the License. | ||
|
||
import re | ||
from typing import List | ||
|
||
import torch | ||
|
||
from ..utils import is_peft_version, logging | ||
from ..utils import is_peft_version, logging, state_dict_all_zero | ||
|
||
|
||
logger = logging.get_logger(__name__) | ||
|
||
|
||
def swap_scale_shift(weight): | ||
shift, scale = weight.chunk(2, dim=0) | ||
new_weight = torch.cat([scale, shift], dim=0) | ||
return new_weight | ||
|
||
|
||
def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5): | ||
# 1. get all state_dict_keys | ||
all_keys = list(state_dict.keys()) | ||
|
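Note: `swap_scale_shift` is hoisted to module scope here so the conversion paths below can reuse it. A minimal sketch of what it does to a stacked modulation weight, using a toy tensor rather than a real checkpoint:

```python
import torch

def swap_scale_shift(weight):
    # Split the stacked (shift, scale) halves along dim 0 and re-stack as (scale, shift).
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

# Toy weight: 4 "shift" rows followed by 4 "scale" rows.
w = torch.cat([torch.zeros(4, 8), torch.ones(4, 8)], dim=0)
swapped = swap_scale_shift(w)
assert torch.equal(swapped[:4], torch.ones(4, 8))   # scale rows now come first
assert torch.equal(swapped[4:], torch.zeros(4, 8))  # shift rows follow
```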
@@ -313,6 +320,7 @@ def _convert_text_encoder_lora_key(key, lora_name): | |
# Be aware that this is the new diffusers convention and the rest of the code might | ||
# not utilize it yet. | ||
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") | ||
|
||
return diffusers_name | ||
|
||
|
||
|
@@ -331,8 +339,7 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha): | |
|
||
|
||
# The utilities under `_convert_kohya_flux_lora_to_diffusers()` | ||
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py | ||
# All credits go to `kohya-ss`. | ||
# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py | ||
def _convert_kohya_flux_lora_to_diffusers(state_dict): | ||
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): | ||
if sds_key + ".lora_down.weight" not in sds_sd: | ||
|
@@ -341,7 +348,8 @@ def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): | |
|
||
# scale weight by alpha and dim | ||
rank = down_weight.shape[0] | ||
alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar | ||
default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False) | ||
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item() # alpha is scalar | ||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here | ||
|
||
# calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2 | ||
|
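Note: the fallback above matters for ComfyUI-style checkpoints that ship no `.alpha` tensors. A small sketch of the defaulting behaviour, with a hypothetical key and a dummy tensor:

```python
import torch

# If ".alpha" is absent, alpha defaults to the rank, so scale = alpha / rank == 1.0
# and the LoRA weights pass through unscaled. "some_key" is a made-up entry.
down_weight = torch.randn(16, 3072)  # rank-16 lora_down matrix
rank = down_weight.shape[0]
sds_sd = {}  # no "some_key.alpha" stored in this state dict
default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device)
alpha = sds_sd.pop("some_key.alpha", default_alpha).item()
scale = alpha / rank
assert scale == 1.0  # missing alpha => no rescaling of the up/down weights
```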
@@ -362,7 +370,10 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): | |
sd_lora_rank = down_weight.shape[0] | ||
|
||
# scale weight by alpha and dim | ||
alpha = sds_sd.pop(sds_key + ".alpha") | ||
default_alpha = torch.tensor( | ||
sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False | ||
) | ||
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha) | ||
scale = alpha / sd_lora_rank | ||
|
||
# calculate scale_down and scale_up | ||
|
@@ -516,10 +527,103 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd): | |
f"transformer.single_transformer_blocks.{i}.norm.linear", | ||
) | ||
|
||
# TODO: alphas. | ||
def assign_remaining_weights(assignments, source): | ||
for lora_key in ["lora_A", "lora_B"]: | ||
orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up" | ||
for target_fmt, source_fmt, transform in assignments: | ||
target_key = target_fmt.format(lora_key=lora_key) | ||
source_key = source_fmt.format(orig_lora_key=orig_lora_key) | ||
value = source.pop(source_key) | ||
if transform: | ||
value = transform(value) | ||
ait_sd[target_key] = value | ||
|
||
if any("guidance_in" in k for k in sds_sd): | ||
assign_remaining_weights( | ||
[ | ||
( | ||
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight", | ||
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight", | ||
None, | ||
), | ||
( | ||
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight", | ||
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight", | ||
None, | ||
), | ||
], | ||
sds_sd, | ||
) | ||
|
||
if any("img_in" in k for k in sds_sd): | ||
assign_remaining_weights( | ||
[ | ||
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None), | ||
], | ||
sds_sd, | ||
) | ||
|
||
if any("txt_in" in k for k in sds_sd): | ||
assign_remaining_weights( | ||
[ | ||
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None), | ||
], | ||
sds_sd, | ||
) | ||
|
||
if any("time_in" in k for k in sds_sd): | ||
assign_remaining_weights( | ||
[ | ||
( | ||
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight", | ||
"lora_unet_time_in_in_layer.{orig_lora_key}.weight", | ||
None, | ||
), | ||
( | ||
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight", | ||
"lora_unet_time_in_out_layer.{orig_lora_key}.weight", | ||
None, | ||
), | ||
], | ||
sds_sd, | ||
) | ||
|
||
if any("vector_in" in k for k in sds_sd): | ||
assign_remaining_weights( | ||
[ | ||
( | ||
"time_text_embed.text_embedder.linear_1.{lora_key}.weight", | ||
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight", | ||
None, | ||
), | ||
( | ||
"time_text_embed.text_embedder.linear_2.{lora_key}.weight", | ||
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight", | ||
None, | ||
), | ||
], | ||
sds_sd, | ||
) | ||
|
||
if any("final_layer" in k for k in sds_sd): | ||
# Notice the swap in processing for "final_layer". | ||
assign_remaining_weights( | ||
[ | ||
( | ||
"norm_out.linear.{lora_key}.weight", | ||
"lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight", | ||
swap_scale_shift, | ||
), | ||
("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None), | ||
], | ||
sds_sd, | ||
) | ||
|
||
remaining_keys = list(sds_sd.keys()) | ||
te_state_dict = {} | ||
if remaining_keys: | ||
if not all(k.startswith("lora_te") for k in remaining_keys): | ||
if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys): | ||
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}") | ||
for key in remaining_keys: | ||
if not key.endswith("lora_down.weight"): | ||
|
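Note: the `assign_remaining_weights` helper above is what lets the optional ComfyUI blocks (`guidance_in`, `img_in`, `txt_in`, `time_in`, `vector_in`, `final_layer`) be remapped only when they are present. A self-contained sketch of the key substitution it performs; the tensors and the single mapping are illustrative:

```python
import torch

ait_sd = {}  # diffusers-style (ai-toolkit) keys end up here
sds_sd = {   # made-up sd-scripts/ComfyUI-style entries
    "lora_unet_img_in.lora_down.weight": torch.zeros(4, 64),
    "lora_unet_img_in.lora_up.weight": torch.zeros(3072, 4),
}

def assign_remaining_weights(assignments, source):
    for lora_key in ["lora_A", "lora_B"]:
        orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
        for target_fmt, source_fmt, transform in assignments:
            target_key = target_fmt.format(lora_key=lora_key)
            source_key = source_fmt.format(orig_lora_key=orig_lora_key)
            value = source.pop(source_key)
            if transform:
                value = transform(value)
            ait_sd[target_key] = value

assign_remaining_weights(
    [("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None)],
    sds_sd,
)
print(sorted(ait_sd))  # ['x_embedder.lora_A.weight', 'x_embedder.lora_B.weight']
```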
@@ -680,10 +784,96 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict): | |
if has_peft_state_dict: | ||
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")} | ||
return state_dict | ||
|
||
# Another weird one. | ||
has_mixture = any( | ||
k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict | ||
) | ||
|
||
# ComfyUI. | ||
state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()} | ||
state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()} | ||
|
||
has_position_embedding = any("position_embedding" in k for k in state_dict) | ||
if has_position_embedding: | ||
zero_status_pe = state_dict_all_zero(state_dict, "position_embedding") | ||
if zero_status_pe: | ||
logger.info( | ||
"The `position_embedding` LoRA params are all zeros which make them ineffective. " | ||
"So, we will purge them out of the curret state dict to make loading possible." | ||
) | ||
state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k} | ||
else: | ||
raise NotImplementedError( | ||
"The state_dict has position_embedding LoRA params and we currently do not support them. " | ||
"Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new." | ||
) | ||
|
||
has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict) | ||
if has_t5xxl: | ||
zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl") | ||
if zero_status_t5: | ||
logger.info( | ||
"The `t5xxl` LoRA params are all zeros which make them ineffective. " | ||
"So, we will purge them out of the curret state dict to make loading possible." | ||
) | ||
else: | ||
logger.info( | ||
"T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out." | ||
"Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new." | ||
) | ||
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")} | ||
|
||
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict) | ||
if has_diffb: | ||
zero_status_diff_b = state_dict_all_zero(state_dict, "diff_b") | ||
if zero_status_diff_b: | ||
logger.info( | ||
"The `diff_b` LoRA params are all zeros which make them ineffective. " | ||
"So, we will purge them out of the curret state dict to make loading possible." | ||
) | ||
else: | ||
logger.info( | ||
"`diff_b` keys found in the state dict which are currently unsupported. " | ||
"So, we will filter out those keys. Open an issue if this is a problem - " | ||
"https://github.com/huggingface/diffusers/issues/new." | ||
) | ||
state_dict = {k: v for k, v in state_dict.items() if "diff_b" not in k} | ||
|
||
has_norm_diff = any("norm" in k and "diff" in k for k in state_dict) | ||
if has_norm_diff: | ||
zero_status_diff = state_dict_all_zero(state_dict, "diff") | ||
Review comment: I'm not familiar with those other state dict formats; just wanted to ask whether it would be safer to use dots in the filter keys, e.g. |
||
if zero_status_diff: | ||
logger.info( | ||
"The `diff` LoRA params are all zeros which make them ineffective. " | ||
"So, we will purge them out of the curret state dict to make loading possible." | ||
) | ||
else: | ||
logger.info( | ||
"Normalization diff keys found in the state dict which are currently unsupported. " | ||
"So, we will filter out those keys. Open an issue if this is a problem - " | ||
"https://github.com/huggingface/diffusers/issues/new." | ||
) | ||
state_dict = {k: v for k, v in state_dict.items() if "norm" not in k and "diff" not in k} | ||
|
||
limit_substrings = ["lora_down", "lora_up"] | ||
if any("alpha" in k for k in state_dict): | ||
limit_substrings.append("alpha") | ||
|
||
state_dict = { | ||
_custom_replace(k, limit_substrings): v | ||
for k, v in state_dict.items() | ||
if k.startswith(("lora_unet_", "lora_te_")) | ||
} | ||
|
||
if any("text_projection" in k for k in state_dict): | ||
logger.info( | ||
"`text_projection` keys found in the `state_dict` which are unexpected. " | ||
"So, we will filter out those keys. Open an issue if this is a problem - " | ||
"https://github.com/huggingface/diffusers/issues/new." | ||
) | ||
state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k} | ||
|
||
if has_mixture: | ||
return _convert_mixture_state_dict_to_diffusers(state_dict) | ||
|
||
|
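Note: the ComfyUI branch above does two things: it renames the `diffusion_model.` and `text_encoders.clip_l.transformer.` prefixes, and it purges parameter groups that are entirely zeros (and therefore inert). A hedged sketch; `state_dict_all_zero` below is a stand-in for the diffusers utility imported earlier, assuming it returns True when every tensor whose key matches the filter is all zeros, and the keys/tensors are illustrative:

```python
import torch

# Stand-in for diffusers' `state_dict_all_zero` (assumed behaviour, for illustration only).
def state_dict_all_zero(state_dict, key_filter):
    keys = [k for k in state_dict if key_filter in k]
    return bool(keys) and all(torch.count_nonzero(state_dict[k]) == 0 for k in keys)

state_dict = {
    "lora_te_text_model_embeddings_position_embedding.diff": torch.zeros(77, 768),
    "lora_unet_img_in.lora_down.weight": torch.randn(4, 64),
}

if any("position_embedding" in k for k in state_dict):
    if state_dict_all_zero(state_dict, "position_embedding"):
        # All-zero params contribute nothing, so dropping them keeps loading possible.
        state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}

print(list(state_dict))  # only the effective LoRA params remain
```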
@@ -798,6 +988,23 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): | |
return new_state_dict | ||
|
||
|
||
def _custom_replace(key: str, substrings: List[str]) -> str: | ||
pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")" | ||
|
||
|
||
match = re.search(pattern, key) | ||
if match: | ||
start_sub = match.start() | ||
if start_sub > 0 and key[start_sub - 1] == ".": | ||
boundary = start_sub - 1 | ||
else: | ||
boundary = start_sub | ||
left = key[:boundary].replace(".", "_") | ||
right = key[boundary:] | ||
return left + right | ||
else: | ||
return key.replace(".", "_") | ||
|
||
|
||
def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict): | ||
converted_state_dict = {} | ||
original_state_dict_keys = list(original_state_dict.keys()) | ||
|
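Note: `_custom_replace` keeps the `lora_down` / `lora_up` / `alpha` suffix intact while flattening the dots to its left into underscores, which folds ComfyUI-style dotted keys into the `lora_unet_*` naming the rest of the converter expects. A runnable sketch with an illustrative key (real ComfyUI keys may differ slightly):

```python
import re
from typing import List

def _custom_replace(key: str, substrings: List[str]) -> str:
    # Find the first "lora_down"/"lora_up"/"alpha" marker; dots left of its leading
    # dot become underscores, while the marker suffix keeps its dots.
    pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
    match = re.search(pattern, key)
    if match:
        start_sub = match.start()
        boundary = start_sub - 1 if start_sub > 0 and key[start_sub - 1] == "." else start_sub
        return key[:boundary].replace(".", "_") + key[boundary:]
    return key.replace(".", "_")

key = "lora_unet_double_blocks.0.img_attn.qkv.lora_down.weight"
print(_custom_replace(key, ["lora_down", "lora_up", "alpha"]))
# -> lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight
```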
@@ -806,11 +1013,6 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict): | |
inner_dim = 3072 | ||
mlp_ratio = 4.0 | ||
|
||
def swap_scale_shift(weight): | ||
shift, scale = weight.chunk(2, dim=0) | ||
new_weight = torch.cat([scale, shift], dim=0) | ||
return new_weight | ||
|
||
for lora_key in ["lora_A", "lora_B"]: | ||
## time_text_embed.timestep_embedder <- time_in | ||
converted_state_dict[ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,59 +58,24 @@ | |
} | ||
|
||
|
||
def _maybe_adjust_config(config): | ||
""" | ||
We may run into some ambiguous configuration values when a model has module names, sharing a common prefix | ||
(`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This | ||
method removes the ambiguity by following what is described here: | ||
https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028. | ||
""" | ||
# Track keys that have been explicitly removed to prevent re-adding them. | ||
deleted_keys = set() | ||
|
||
def _maybe_raise_error_for_ambiguity(config): | ||
Review comment: This is a breaking change, but in this case I would very much prefer it, as otherwise it becomes increasingly difficult and cumbersome to support the LoRAs of the world. |
||
rank_pattern = config["rank_pattern"].copy() | ||
target_modules = config["target_modules"] | ||
original_r = config["r"] | ||
|
||
for key in list(rank_pattern.keys()): | ||
key_rank = rank_pattern[key] | ||
|
||
# try to detect ambiguity | ||
# `target_modules` can also be a str, in which case this loop would loop | ||
# over the chars of the str. The technically correct way to match LoRA keys | ||
# in PEFT is to use LoraModel._check_target_module_exists (lora_config, key). | ||
# But this cuts it for now. | ||
exact_matches = [mod for mod in target_modules if mod == key] | ||
substring_matches = [mod for mod in target_modules if key in mod and mod != key] | ||
ambiguous_key = key | ||
|
||
if exact_matches and substring_matches: | ||
# if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example) | ||
config["r"] = key_rank | ||
# remove the ambiguous key from `rank_pattern` and record it as deleted | ||
del config["rank_pattern"][key] | ||
deleted_keys.add(key) | ||
# For substring matches, add them with the original rank only if they haven't been assigned already | ||
for mod in substring_matches: | ||
if mod not in config["rank_pattern"] and mod not in deleted_keys: | ||
config["rank_pattern"][mod] = original_r | ||
|
||
# Update the rest of the target modules with the original rank if not already set and not deleted | ||
for mod in target_modules: | ||
if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys: | ||
config["rank_pattern"][mod] = original_r | ||
|
||
# Handle alphas to deal with cases like: | ||
# https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777 | ||
has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"] | ||
if has_different_ranks: | ||
config["lora_alpha"] = config["r"] | ||
alpha_pattern = {} | ||
for module_name, rank in config["rank_pattern"].items(): | ||
alpha_pattern[module_name] = rank | ||
config["alpha_pattern"] = alpha_pattern | ||
|
||
return config | ||
if not is_peft_version(">=", "0.14.1"): | ||
|
||
raise ValueError( | ||
"There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`." | ||
) | ||
|
||
|
||
class PeftAdapterMixin: | ||
|
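Note: the removed `_maybe_adjust_config` used to rewrite the PEFT config whenever a short module name (e.g. `proj_out`) was both an exact target and a suffix of another target with a different rank. The replacement simply raises unless `peft>=0.14.1` is installed, where the `^`-anchored `rank_pattern` keys (next hunk) remove the ambiguity. A small sketch of the kind of ambiguity being detected; the config values are made up:

```python
# Hypothetical, minimal config exhibiting the ambiguity the old helper patched over.
config = {
    "r": 4,
    "target_modules": ["proj_out", "single_transformer_blocks.0.proj_out"],
    "rank_pattern": {"proj_out": 8},  # a different rank for the short name
}

for key, key_rank in list(config["rank_pattern"].items()):
    exact_matches = [m for m in config["target_modules"] if m == key]
    substring_matches = [m for m in config["target_modules"] if key in m and m != key]
    if exact_matches and substring_matches:
        # Old behaviour: mutate r/rank_pattern/alpha_pattern to disambiguate.
        # New behaviour: raise and ask the user to upgrade peft instead.
        print(f"ambiguous key: {key} (rank {key_rank} vs default r={config['r']})")
```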
@@ -254,16 +219,18 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans | |
# Cannot figure out rank from lora layers that don't have at least 2 dimensions. | ||
# Bias layers in LoRA only have a single dimension | ||
if "lora_B" in key and val.ndim > 1: | ||
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged. | ||
rank[key] = val.shape[1] | ||
# Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol. | ||
# We may run into some ambiguous configuration values when a model has module | ||
# names sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`, | ||
# for example) and they have different LoRA ranks. | ||
rank[f"^{key}"] = val.shape[1] | ||
|
||
if network_alphas is not None and len(network_alphas) >= 1: | ||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] | ||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} | ||
|
||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) | ||
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged. | ||
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs) | ||
_maybe_raise_error_for_ambiguity(lora_config_kwargs) | ||
|
||
if "use_dora" in lora_config_kwargs: | ||
if lora_config_kwargs["use_dora"]: | ||
|
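Note: a sketch of the `^`-anchored rank extraction above. The anchoring semantics are handled by peft (see the linked peft PR #2419); the tensors and names below are made up purely to show how the dict is built from `lora_B` shapes:

```python
import torch

state_dict = {
    "proj_out.lora_B.weight": torch.randn(3072, 8),
    "single_transformer_blocks.0.proj_out.lora_B.weight": torch.randn(3072, 4),
}

rank = {}
for key, val in state_dict.items():
    if "lora_B" in key and val.ndim > 1:
        # Anchoring prevents "proj_out" from also matching the longer, differently
        # ranked "single_transformer_blocks.0.proj_out" entry downstream.
        rank[f"^{key}"] = val.shape[1]

print(rank)
# {'^proj_out.lora_B.weight': 8, '^single_transformer_blocks.0.proj_out.lora_B.weight': 4}
```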