[LoRA] support more comyui loras for Flux 🚨 #10985


Merged
34 commits merged on Apr 9, 2025

Changes from all commits (34 commits)
812b4e1
support more comyui loras.
sayakpaul Mar 6, 2025
367153d
fix
sayakpaul Mar 7, 2025
5c4976b
fixes
sayakpaul Mar 7, 2025
a6d8f3f
revert changes in LoRA base.
sayakpaul Mar 7, 2025
1074836
no position_embedding
sayakpaul Mar 7, 2025
ca88a5e
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 7, 2025
4b9d2df
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 7, 2025
2b6990f
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 8, 2025
cc51f5c
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 8, 2025
ba0f8a3
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 10, 2025
1c98875
🚨 introduce a breaking change to let peft handle module ambiguity
sayakpaul Mar 12, 2025
05ccc90
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 12, 2025
fc25e1c
styling
sayakpaul Mar 12, 2025
0560dcc
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 12, 2025
34226af
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 13, 2025
78ae954
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 14, 2025
13ecc86
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 16, 2025
2f2100a
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 18, 2025
0d88427
remove position embeddings.
sayakpaul Mar 18, 2025
ea0d131
improvements.
sayakpaul Mar 18, 2025
2cb82f3
style
sayakpaul Mar 18, 2025
5e11a89
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 18, 2025
09b2a0f
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 19, 2025
c30a1e4
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 20, 2025
30f8f74
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 20, 2025
3a6eefc
make info instead of NotImplementedError
sayakpaul Mar 20, 2025
e2f51de
Update src/diffusers/loaders/peft.py
sayakpaul Mar 20, 2025
171fa24
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 21, 2025
090468c
add example.
sayakpaul Mar 21, 2025
90bf93d
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Mar 25, 2025
7532406
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Apr 8, 2025
f754663
robust checks
sayakpaul Apr 8, 2025
b5c136f
updates
sayakpaul Apr 8, 2025
df28778
Merge branch 'main' into support-comyui-flux-loras
sayakpaul Apr 9, 2025
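For context before the file diffs, a minimal usage sketch of what this PR enables: loading a ComfyUI-exported Flux LoRA through the standard diffusers loader. The checkpoint path and prompt are hypothetical; the pipeline and loader calls are the existing diffusers APIs.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
# A LoRA exported from ComfyUI (hypothetical file); the conversion utilities below remap its
# `diffusion_model.*` / `text_encoders.clip_l.*` keys to the diffusers/PEFT format on load.
pipe.load_lora_weights("path/to/comfyui_flux_lora.safetensors")
image = pipe("a tiny astronaut hatching from an egg on the moon", num_inference_steps=28).images[0]
image.save("flux_comfyui_lora.png")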
229 changes: 218 additions & 11 deletions src/diffusers/loaders/lora_conversion_utils.py
@@ -13,15 +13,22 @@
# limitations under the License.

import re
from typing import List

import torch

from ..utils import is_peft_version, logging
from ..utils import is_peft_version, logging, state_dict_all_zero


logger = logging.get_logger(__name__)


def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
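# Quick illustration (editor's sketch, not part of the diff): for a 1-D weight
# torch.arange(6.), chunk(2) yields shift = [0., 1., 2.] and scale = [3., 4., 5.],
# so swap_scale_shift returns tensor([3., 4., 5., 0., 1., 2.]).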


def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
# 1. get all state_dict_keys
all_keys = list(state_dict.keys())
@@ -313,6 +320,7 @@ def _convert_text_encoder_lora_key(key, lora_name):
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")

return diffusers_name


@@ -331,8 +339,7 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):


# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
# All credits go to `kohya-ss`.
# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
def _convert_kohya_flux_lora_to_diffusers(state_dict):
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
if sds_key + ".lora_down.weight" not in sds_sd:
@@ -341,7 +348,8 @@ def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):

# scale weight by alpha and dim
rank = down_weight.shape[0]
alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item() # alpha is scalar
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
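# Note (illustrative, not part of the diff): some exports (e.g. certain ComfyUI-trained
# checkpoints) may omit the ".alpha" entry; the fallback above then yields alpha == rank,
# i.e. scale == 1.0, which matches PEFT's lora_alpha / r scaling when lora_alpha == r.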

# calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
@@ -362,7 +370,10 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
sd_lora_rank = down_weight.shape[0]

# scale weight by alpha and dim
alpha = sds_sd.pop(sds_key + ".alpha")
default_alpha = torch.tensor(
sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
)
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
scale = alpha / sd_lora_rank

# calculate scale_down and scale_up
@@ -516,10 +527,103 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
f"transformer.single_transformer_blocks.{i}.norm.linear",
)

# TODO: alphas.
def assign_remaining_weights(assignments, source):
for lora_key in ["lora_A", "lora_B"]:
orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
for target_fmt, source_fmt, transform in assignments:
target_key = target_fmt.format(lora_key=lora_key)
source_key = source_fmt.format(orig_lora_key=orig_lora_key)
value = source.pop(source_key)
if transform:
value = transform(value)
ait_sd[target_key] = value
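# For instance (illustrative note, not part of the diff), the img_in assignment used below,
# ("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
# pops "lora_unet_img_in.lora_down.weight" from `source` and stores it as
# ait_sd["x_embedder.lora_A.weight"] (and "lora_up" -> "lora_B"); a non-None `transform`
# such as swap_scale_shift is applied to the tensor before assignment.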

if any("guidance_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
None,
),
],
sds_sd,
)

if any("img_in" in k for k in sds_sd):
assign_remaining_weights(
[
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
],
sds_sd,
)

if any("txt_in" in k for k in sds_sd):
assign_remaining_weights(
[
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
],
sds_sd,
)

if any("time_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
"lora_unet_time_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
"lora_unet_time_in_out_layer.{orig_lora_key}.weight",
None,
),
],
sds_sd,
)

if any("vector_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.text_embedder.linear_1.{lora_key}.weight",
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.text_embedder.linear_2.{lora_key}.weight",
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
None,
),
],
sds_sd,
)

if any("final_layer" in k for k in sds_sd):
# Notice the swap in processing for "final_layer".
assign_remaining_weights(
[
(
"norm_out.linear.{lora_key}.weight",
"lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight",
swap_scale_shift,
),
("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None),
],
sds_sd,
)
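# Editor's note (not part of the diff): the swap is needed because the original Flux
# final_layer emits its adaLN modulation as (shift, scale) while the diffusers norm_out
# layer expects (scale, shift); swap_scale_shift reorders the two halves accordingly.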

remaining_keys = list(sds_sd.keys())
te_state_dict = {}
if remaining_keys:
if not all(k.startswith("lora_te") for k in remaining_keys):
if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
for key in remaining_keys:
if not key.endswith("lora_down.weight"):
@@ -680,10 +784,98 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
if has_peft_state_dict:
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
return state_dict

# Another weird one.
has_mixture = any(
k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
)

# ComfyUI.
if not has_mixture:
state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}

has_position_embedding = any("position_embedding" in k for k in state_dict)
if has_position_embedding:
zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
if zero_status_pe:
logger.info(
"The `position_embedding` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the curret state dict to make loading possible."
)

else:
logger.info(
"The state_dict has position_embedding LoRA params and we currently do not support them. "
"Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}

has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
if has_t5xxl:
zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
if zero_status_t5:
logger.info(
"The `t5xxl` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the curret state dict to make loading possible."
)
else:
logger.info(
"T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out."
"Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}

has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
if has_diffb:
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
if zero_status_diff_b:
logger.info(
"The `diff_b` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the curret state dict to make loading possible."
)
else:
logger.info(
"`diff_b` keys found in the state dict which are currently unsupported. "
"So, we will filter out those keys. Open an issue if this is a problem - "
"https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if ".diff_b" not in k}

has_norm_diff = any(".norm" in k and ".diff" in k for k in state_dict)
if has_norm_diff:
zero_status_diff = state_dict_all_zero(state_dict, ".diff")
if zero_status_diff:
logger.info(
"The `diff` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the curret state dict to make loading possible."
)
else:
logger.info(
"Normalization diff keys found in the state dict which are currently unsupported. "
"So, we will filter out those keys. Open an issue if this is a problem - "
"https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if ".norm" not in k and ".diff" not in k}

limit_substrings = ["lora_down", "lora_up"]
if any("alpha" in k for k in state_dict):
limit_substrings.append("alpha")

state_dict = {
_custom_replace(k, limit_substrings): v
for k, v in state_dict.items()
if k.startswith(("lora_unet_", "lora_te_"))
}

if any("text_projection" in k for k in state_dict):
logger.info(
"`text_projection` keys found in the `state_dict` which are unexpected. "
"So, we will filter out those keys. Open an issue if this is a problem - "
"https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}

if has_mixture:
return _convert_mixture_state_dict_to_diffusers(state_dict)

@@ -798,6 +990,26 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
return new_state_dict


def _custom_replace(key: str, substrings: List[str]) -> str:
# Replaces the "."s with "_"s up to (but not including) the first occurrence of any of `substrings`.
# Example (with "lora_down" among the substrings):
# lora_unet.foo.bar.lora_down.weight -> lora_unet_foo_bar.lora_down.weight
pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"

match = re.search(pattern, key)
if match:
start_sub = match.start()
if start_sub > 0 and key[start_sub - 1] == ".":
boundary = start_sub - 1
else:
boundary = start_sub
left = key[:boundary].replace(".", "_")
right = key[boundary:]
return left + right
else:
return key.replace(".", "_")
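# Illustrative calls (editor's sketch, not part of the diff), mirroring how the function is
# used with `limit_substrings` earlier in this file:
#   _custom_replace("lora_unet_double_blocks.0.img_attn.qkv.lora_down.weight", ["lora_down", "lora_up"])
#       -> "lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight"
#   _custom_replace("lora_te_text_model.encoder.alpha", ["lora_down", "lora_up", "alpha"])
#       -> "lora_te_text_model_encoder.alpha"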


def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
converted_state_dict = {}
original_state_dict_keys = list(original_state_dict.keys())
@@ -806,11 +1018,6 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
inner_dim = 3072
mlp_ratio = 4.0

def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight

for lora_key in ["lora_A", "lora_B"]:
## time_text_embed.timestep_embedder <- time_in
converted_state_dict[
55 changes: 11 additions & 44 deletions src/diffusers/loaders/peft.py
@@ -58,59 +58,24 @@
}


def _maybe_adjust_config(config):
"""
We may run into some ambiguous configuration values when a model has module names, sharing a common prefix
(`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This
method removes the ambiguity by following what is described here:
https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
"""
# Track keys that have been explicitly removed to prevent re-adding them.
deleted_keys = set()

def _maybe_raise_error_for_ambiguity(config):
sayakpaul (Member, Author) commented:
This is a breaking change, but in this case I would very much prefer it, as otherwise it becomes increasingly difficult and cumbersome to support the LoRAs out in the wild.

rank_pattern = config["rank_pattern"].copy()
target_modules = config["target_modules"]
original_r = config["r"]

for key in list(rank_pattern.keys()):
key_rank = rank_pattern[key]

# try to detect ambiguity
# `target_modules` can also be a str, in which case this loop would loop
# over the chars of the str. The technically correct way to match LoRA keys
# in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
# But this cuts it for now.
exact_matches = [mod for mod in target_modules if mod == key]
substring_matches = [mod for mod in target_modules if key in mod and mod != key]
ambiguous_key = key

if exact_matches and substring_matches:
# if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
config["r"] = key_rank
# remove the ambiguous key from `rank_pattern` and record it as deleted
del config["rank_pattern"][key]
deleted_keys.add(key)
# For substring matches, add them with the original rank only if they haven't been assigned already
for mod in substring_matches:
if mod not in config["rank_pattern"] and mod not in deleted_keys:
config["rank_pattern"][mod] = original_r

# Update the rest of the target modules with the original rank if not already set and not deleted
for mod in target_modules:
if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
config["rank_pattern"][mod] = original_r

# Handle alphas to deal with cases like:
# https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
if has_different_ranks:
config["lora_alpha"] = config["r"]
alpha_pattern = {}
for module_name, rank in config["rank_pattern"].items():
alpha_pattern[module_name] = rank
config["alpha_pattern"] = alpha_pattern

return config
if is_peft_version("<", "0.14.1"):
raise ValueError(
"There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
)
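# Illustrative scenario (module names are made up): a LoRA targeting both "proj_out" and
# "single_transformer_blocks.0.proj_out" at different ranks is ambiguous. Instead of the
# previous heuristic rewrite of `rank_pattern`, loading such a LoRA now raises on
# peft < 0.14.1; newer peft resolves the ambiguity via "^"-anchored rank_pattern keys
# (see https://github.com/huggingface/peft/pull/2419).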


class PeftAdapterMixin:
@@ -286,16 +251,18 @@ def load_lora_adapter(
# Cannot figure out rank from lora layers that don't have at least 2 dimensions.
# Bias layers in LoRA only have a single dimension
if "lora_B" in key and val.ndim > 1:
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
rank[key] = val.shape[1]
# Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
# We may run into some ambiguous configuration values when a model has module
# names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
# for example) and they have different LoRA ranks.
rank[f"^{key}"] = val.shape[1]

if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
_maybe_raise_error_for_ambiguity(lora_config_kwargs)

if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]: