
Commit 6bfacf0

sayakpaul and hlky authored
[LoRA] support more comyui loras for Flux 🚨 (#10985)
* support more comyui loras.
* fix
* fixes
* revert changes in LoRA base.
* no position_embedding
* 🚨 introduce a breaking change to let peft handle module ambiguity
* styling
* remove position embeddings.
* improvements.
* style
* make info instead of NotImplementedError
* Update src/diffusers/loaders/peft.py (Co-authored-by: hlky <[email protected]>)
* add example.
* robust checks
* updates

---------

Co-authored-by: hlky <[email protected]>
1 parent f685981 commit 6bfacf0
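The "add example" item above refers to loading such checkpoints directly. A minimal usage sketch of the loading path this commit extends; the LoRA filename below is a hypothetical placeholder, not a file shipped with the PR:

# Load a ComfyUI-style Flux LoRA through the converter touched by this commit.
# "comfyui_flux_lora.safetensors" is a hypothetical local file; any ComfyUI/Kohya
# Flux LoRA export would exercise the same code path in lora_conversion_utils.py.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("comfyui_flux_lora.safetensors")
image = pipe("a photo of a corgi wearing sunglasses", num_inference_steps=28).images[0]
image.save("corgi.png")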

File tree

4 files changed
+244 -55 lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 218 additions & 11 deletions

@@ -13,15 +13,22 @@
 # limitations under the License.
 
 import re
+from typing import List
 
 import torch
 
-from ..utils import is_peft_version, logging
+from ..utils import is_peft_version, logging, state_dict_all_zero
 
 
 logger = logging.get_logger(__name__)
 
 
+def swap_scale_shift(weight):
+    shift, scale = weight.chunk(2, dim=0)
+    new_weight = torch.cat([scale, shift], dim=0)
+    return new_weight
+
+
 def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
     # 1. get all state_dict_keys
     all_keys = list(state_dict.keys())
@@ -313,6 +320,7 @@ def _convert_text_encoder_lora_key(key, lora_name):
     # Be aware that this is the new diffusers convention and the rest of the code might
     # not utilize it yet.
     diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+
     return diffusers_name
 
 
@@ -331,8 +339,7 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
 
 
 # The utilities under `_convert_kohya_flux_lora_to_diffusers()`
-# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
-# All credits go to `kohya-ss`.
+# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
 def _convert_kohya_flux_lora_to_diffusers(state_dict):
     def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
         if sds_key + ".lora_down.weight" not in sds_sd:
@@ -341,7 +348,8 @@ def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
 
         # scale weight by alpha and dim
         rank = down_weight.shape[0]
-        alpha = sds_sd.pop(sds_key + ".alpha").item()  # alpha is scalar
+        default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()  # alpha is scalar
         scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
 
         # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
@@ -362,7 +370,10 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
         sd_lora_rank = down_weight.shape[0]
 
         # scale weight by alpha and dim
-        alpha = sds_sd.pop(sds_key + ".alpha")
+        default_alpha = torch.tensor(
+            sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
+        )
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
         scale = alpha / sd_lora_rank
 
         # calculate scale_down and scale_up
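The two hunks above change the alpha handling: when a ".alpha" entry is missing (common in ComfyUI exports), the converter now falls back to the rank instead of raising a KeyError, which makes the rescaling a no-op. A small sketch of that arithmetic, with made-up shapes:

# Made-up LoRA shapes; only the fallback arithmetic is being illustrated.
import torch

down_weight = torch.randn(16, 3072)  # LoRA down projection with rank 16
rank = down_weight.shape[0]

# Old behaviour: sds_sd.pop(key + ".alpha") raised KeyError when alpha was absent.
# New behaviour: alpha defaults to the rank, so scale == 1.0 and the weights pass through unscaled.
default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device)
alpha = default_alpha.item()
scale = alpha / rank
assert scale == 1.0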
@@ -516,10 +527,103 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
                 f"transformer.single_transformer_blocks.{i}.norm.linear",
             )
 
+        # TODO: alphas.
+        def assign_remaining_weights(assignments, source):
+            for lora_key in ["lora_A", "lora_B"]:
+                orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+                for target_fmt, source_fmt, transform in assignments:
+                    target_key = target_fmt.format(lora_key=lora_key)
+                    source_key = source_fmt.format(orig_lora_key=orig_lora_key)
+                    value = source.pop(source_key)
+                    if transform:
+                        value = transform(value)
+                    ait_sd[target_key] = value
+
+        if any("guidance_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    (
+                        "time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
+                        "lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                    (
+                        "time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
+                        "lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                ],
+                sds_sd,
+            )
+
+        if any("img_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    ("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
+                ],
+                sds_sd,
+            )
+
+        if any("txt_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    ("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
+                ],
+                sds_sd,
+            )
+
+        if any("time_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    (
+                        "time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
+                        "lora_unet_time_in_in_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                    (
+                        "time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
+                        "lora_unet_time_in_out_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                ],
+                sds_sd,
+            )
+
+        if any("vector_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    (
+                        "time_text_embed.text_embedder.linear_1.{lora_key}.weight",
+                        "lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                    (
+                        "time_text_embed.text_embedder.linear_2.{lora_key}.weight",
+                        "lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                ],
+                sds_sd,
+            )
+
+        if any("final_layer" in k for k in sds_sd):
+            # Notice the swap in processing for "final_layer".
+            assign_remaining_weights(
+                [
+                    (
+                        "norm_out.linear.{lora_key}.weight",
+                        "lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight",
+                        swap_scale_shift,
+                    ),
+                    ("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None),
+                ],
+                sds_sd,
+            )
+
         remaining_keys = list(sds_sd.keys())
         te_state_dict = {}
         if remaining_keys:
-            if not all(k.startswith("lora_te") for k in remaining_keys):
+            if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
                 raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
             for key in remaining_keys:
                 if not key.endswith("lora_down.weight"):
@@ -680,10 +784,98 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
     if has_peft_state_dict:
         state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
         return state_dict
+
     # Another weird one.
     has_mixture = any(
         k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
     )
+
+    # ComfyUI.
+    if not has_mixture:
+        state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
+        state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
+
+        has_position_embedding = any("position_embedding" in k for k in state_dict)
+        if has_position_embedding:
+            zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
+            if zero_status_pe:
+                logger.info(
+                    "The `position_embedding` LoRA params are all zeros which make them ineffective. "
+                    "So, we will purge them out of the curret state dict to make loading possible."
+                )
+
+            else:
+                logger.info(
+                    "The state_dict has position_embedding LoRA params and we currently do not support them. "
+                    "Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}
+
+        has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
+        if has_t5xxl:
+            zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
+            if zero_status_t5:
+                logger.info(
+                    "The `t5xxl` LoRA params are all zeros which make them ineffective. "
+                    "So, we will purge them out of the curret state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out."
+                    "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
+
+        has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
+        if has_diffb:
+            zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
+            if zero_status_diff_b:
+                logger.info(
+                    "The `diff_b` LoRA params are all zeros which make them ineffective. "
+                    "So, we will purge them out of the curret state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "`diff_b` keys found in the state dict which are currently unsupported. "
+                    "So, we will filter out those keys. Open an issue if this is a problem - "
+                    "https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if ".diff_b" not in k}
+
+        has_norm_diff = any(".norm" in k and ".diff" in k for k in state_dict)
+        if has_norm_diff:
+            zero_status_diff = state_dict_all_zero(state_dict, ".diff")
+            if zero_status_diff:
+                logger.info(
+                    "The `diff` LoRA params are all zeros which make them ineffective. "
+                    "So, we will purge them out of the curret state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "Normalization diff keys found in the state dict which are currently unsupported. "
+                    "So, we will filter out those keys. Open an issue if this is a problem - "
+                    "https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if ".norm" not in k and ".diff" not in k}
+
+        limit_substrings = ["lora_down", "lora_up"]
+        if any("alpha" in k for k in state_dict):
+            limit_substrings.append("alpha")
+
+        state_dict = {
+            _custom_replace(k, limit_substrings): v
+            for k, v in state_dict.items()
+            if k.startswith(("lora_unet_", "lora_te_"))
+        }
+
+        if any("text_projection" in k for k in state_dict):
+            logger.info(
+                "`text_projection` keys found in the `state_dict` which are unexpected. "
+                "So, we will filter out those keys. Open an issue if this is a problem - "
+                "https://github.com/huggingface/diffusers/issues/new."
+            )
+            state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}
+
     if has_mixture:
         return _convert_mixture_state_dict_to_diffusers(state_dict)
 
@@ -798,6 +990,26 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
     return new_state_dict
 
 
+def _custom_replace(key: str, substrings: List[str]) -> str:
+    # Replaces the "."s with "_"s upto the `substrings`.
+    # Example:
+    # lora_unet.foo.bar.lora_A.weight -> lora_unet_foo_bar.lora_A.weight
+    pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
+
+    match = re.search(pattern, key)
+    if match:
+        start_sub = match.start()
+        if start_sub > 0 and key[start_sub - 1] == ".":
+            boundary = start_sub - 1
+        else:
+            boundary = start_sub
+        left = key[:boundary].replace(".", "_")
+        right = key[boundary:]
+        return left + right
+    else:
+        return key.replace(".", "_")
+
+
 def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     converted_state_dict = {}
     original_state_dict_keys = list(original_state_dict.keys())
@@ -806,11 +1018,6 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     inner_dim = 3072
     mlp_ratio = 4.0
 
-    def swap_scale_shift(weight):
-        shift, scale = weight.chunk(2, dim=0)
-        new_weight = torch.cat([scale, shift], dim=0)
-        return new_weight
-
     for lora_key in ["lora_A", "lora_B"]:
         ## time_text_embed.timestep_embedder <- time_in
         converted_state_dict[
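To make the effect of the new `_custom_replace` helper concrete, here is a standalone re-implementation with a couple of illustrative keys (the keys below are examples, not taken from a specific checkpoint):

# Self-contained copy of the flattening logic, for demonstration only;
# the canonical version lives in lora_conversion_utils.py above.
import re
from typing import List

def custom_replace(key: str, substrings: List[str]) -> str:
    # Replace "." with "_" only up to the first LoRA-specific substring.
    pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
    match = re.search(pattern, key)
    if not match:
        return key.replace(".", "_")
    start = match.start()
    boundary = start - 1 if start > 0 and key[start - 1] == "." else start
    return key[:boundary].replace(".", "_") + key[boundary:]

print(custom_replace("lora_unet_double_blocks.0.img_attn.qkv.lora_down.weight", ["lora_down", "lora_up"]))
# lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight
print(custom_replace("lora_te_text_model.encoder.layers.0.mlp.fc1.alpha", ["lora_down", "lora_up", "alpha"]))
# lora_te_text_model_encoder_layers_0_mlp_fc1.alpha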

src/diffusers/loaders/peft.py

Lines changed: 11 additions & 44 deletions

@@ -58,59 +58,24 @@
 }
 
 
-def _maybe_adjust_config(config):
-    """
-    We may run into some ambiguous configuration values when a model has module names, sharing a common prefix
-    (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This
-    method removes the ambiguity by following what is described here:
-    https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
-    """
-    # Track keys that have been explicitly removed to prevent re-adding them.
-    deleted_keys = set()
-
+def _maybe_raise_error_for_ambiguity(config):
     rank_pattern = config["rank_pattern"].copy()
     target_modules = config["target_modules"]
-    original_r = config["r"]
 
     for key in list(rank_pattern.keys()):
-        key_rank = rank_pattern[key]
-
         # try to detect ambiguity
         # `target_modules` can also be a str, in which case this loop would loop
        # over the chars of the str. The technically correct way to match LoRA keys
         # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
         # But this cuts it for now.
         exact_matches = [mod for mod in target_modules if mod == key]
         substring_matches = [mod for mod in target_modules if key in mod and mod != key]
-        ambiguous_key = key
 
         if exact_matches and substring_matches:
-            # if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
-            config["r"] = key_rank
-            # remove the ambiguous key from `rank_pattern` and record it as deleted
-            del config["rank_pattern"][key]
-            deleted_keys.add(key)
-            # For substring matches, add them with the original rank only if they haven't been assigned already
-            for mod in substring_matches:
-                if mod not in config["rank_pattern"] and mod not in deleted_keys:
-                    config["rank_pattern"][mod] = original_r
-
-        # Update the rest of the target modules with the original rank if not already set and not deleted
-        for mod in target_modules:
-            if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
-                config["rank_pattern"][mod] = original_r
-
-    # Handle alphas to deal with cases like:
-    # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
-    has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
-    if has_different_ranks:
-        config["lora_alpha"] = config["r"]
-        alpha_pattern = {}
-        for module_name, rank in config["rank_pattern"].items():
-            alpha_pattern[module_name] = rank
-        config["alpha_pattern"] = alpha_pattern
-
-    return config
+            if is_peft_version("<", "0.14.1"):
+                raise ValueError(
+                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
+                )
 
 
 class PeftAdapterMixin:
@@ -286,16 +251,18 @@ def load_lora_adapter(
                 # Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
                 # Bias layers in LoRA only have a single dimension
                 if "lora_B" in key and val.ndim > 1:
-                    # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
-                    rank[key] = val.shape[1]
+                    # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
+                    # We may run into some ambiguous configuration values when a model has module
+                    # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
+                    # for example) and they have different LoRA ranks.
+                    rank[f"^{key}"] = val.shape[1]
 
             if network_alphas is not None and len(network_alphas) >= 1:
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
-            # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
-            lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
+            _maybe_raise_error_for_ambiguity(lora_config_kwargs)
 
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
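For context on the `^` prefix added to the rank keys above: one module name can be contained in another (`proj_out` vs. `single_transformer_blocks.0.proj_out`) while the two need different LoRA ranks, and an unanchored key cannot tell them apart. The toy matcher below (illustrative only, not PEFT's actual implementation, which lives in `peft>=0.14.1`) shows why anchoring the key to the start of the module path removes the ambiguity:

# Toy illustration only; module names and ranks are made up.
import re

module_names = ["proj_out", "single_transformer_blocks.0.proj_out"]

# Unanchored, suffix-style matching: the bare key "proj_out" matches both modules,
# so a single rank entry for it is ambiguous.
print([m for m in module_names if re.search(r"(^|\.)proj_out$", m)])
# ['proj_out', 'single_transformer_blocks.0.proj_out']

# Anchored matching, the idea behind rank[f"^{key}"] above: the pattern must match
# from the start of the module path, so each entry resolves to exactly one module.
print([m for m in module_names if re.fullmatch(r"proj_out", m)])
# ['proj_out']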
