
[LoRA] restrict certain keys to be checked for peft config update. #10808


Merged
merged 9 commits into main from restrict-config-update-peft on Feb 24, 2025

Conversation

sayakpaul
Member

@sayakpaul sayakpaul commented Feb 17, 2025

What does this PR do?

Fixes: #10804

Relies on huggingface/peft#2382

Code:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

pipe.load_lora_weights(
    "sayakpaul/different-lora-from-civitai", 
    weight_name="wow_details.safetensors", 
    adapter_name="wow_details"
)

prompt = "a tiny astronaut hatching from an egg on the moon"
out = pipe(
    prompt=prompt,
    guidance_scale=3.5,
    height=1024,
    width=1024,
    num_inference_steps=25,
).images[0]
out.save("image.png")

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@hlky
Contributor

hlky commented Feb 17, 2025

The original peft config is like

{
    "r": 4,
    "lora_alpha": 24,
    "rank_pattern": {
        "transformer_blocks.0.attn.to_out.0": 24,
        "transformer_blocks.0.attn.to_q": 24,
        "transformer_blocks.0.attn.to_k": 24,
        "transformer_blocks.0.attn.to_v": 24,
        "transformer_blocks.0.attn.to_add_out": 24,
...

rank_pattern only contains keys with different rank 24 and r is 4.

After adjustment

{
    "r": 24,
    "lora_alpha": 24,
    "rank_pattern": {
        "transformer_blocks.0.attn.to_out.0": 24,
        "transformer_blocks.0.attn.to_add_out": 24,
...
        "single_transformer_blocks.0.attn.to_q": 4,
        "single_transformer_blocks.30.proj_mlp": 4,
        "single_transformer_blocks.32.attn.to_k": 4,
...

r becomes 24, the problem keys (transformer_blocks.0.attn.to_q etc.) get removed, and all other target_modules are added to rank_pattern with rank 4, which also restores transformer_blocks.0.attn.to_q etc. with rank 4 on later iterations and causes the reported issue.

Similar for the Control LoRAs that this adjustment was originally added for:

{
    "r": 128,
    "lora_alpha": 128,
    "rank_pattern": {
        "proj_out": 64
    },

Becomes

{
    "r": 64,
    "lora_alpha": 128,
    "rank_pattern": {
        "single_transformer_blocks.33.proj_out": 128,
        "single_transformer_blocks.24.proj_out": 128,
        "single_transformer_blocks.21.proj_out": 128,
...

Except in that case, "proj_out" is not added to rank_pattern, so it will use rank 64 from r instead.

It seems like we already have the correct rank from the original config, no? In the original (Control LoRA) case we adjust the config so that rank_pattern defines all the rank-128 keys and change r to 64, when originally we had rank_pattern with proj_out: 64 and everything else at 128. For this case it's the same: we have all the rank-24 keys defined in rank_pattern and r is 4 for everything else.

https://github.com/huggingface/peft/blob/1e2d6b5832401e07e917604dfb080ec474818f2b/src/peft/utils/other.py#L770-L772

peft is treating rank_pattern as a pattern, but we aren't using it as a pattern; when we say proj_out in the original config, that's the whole key.

If we change get_pattern_key to rf"^{key}$", then black-forest-labs/FLUX.1-Canny-dev-lora works without _maybe_adjust_config, and so does the lora from this issue.
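
To make the false positive concrete, here is a small, standalone Python sketch (not the actual peft or diffusers code) that mimics the two matching behaviours on the FLUX.1-Canny-dev-lora numbers quoted above; RANK_PATTERN and the module names come from that example, everything else is made up for illustration.

import re

RANK_PATTERN = {"proj_out": 64}  # original Canny LoRA pattern, before _maybe_adjust_config
DEFAULT_R = 128

def rank_with_suffix_match(module_name):
    # Mimics peft's current get_pattern_key: a rank_pattern key matches any module
    # whose name ends with ".<key>"; otherwise fall back to an exact lookup.
    for key in RANK_PATTERN:
        if re.match(rf".*\.{key}$", module_name):
            return RANK_PATTERN[key]
    return RANK_PATTERN.get(module_name, DEFAULT_R)

def rank_with_full_match(module_name):
    # The alternative discussed here: rf"^{key}$" means the key must equal the whole name.
    for key in RANK_PATTERN:
        if re.match(rf"^{key}$", module_name):
            return RANK_PATTERN[key]
    return DEFAULT_R

print(rank_with_suffix_match("proj_out"))                               # 64 (intended)
print(rank_with_suffix_match("single_transformer_blocks.33.proj_out"))  # 64 -> false positive, should be 128
print(rank_with_full_match("proj_out"))                                 # 64
print(rank_with_full_match("single_transformer_blocks.33.proj_out"))    # 128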

@sayakpaul
Member Author

Your deductions are correct, thank you!

Help me understand this:

peft is treating rank_pattern as a pattern but we aren't using it as a pattern, when we say proj_out in the original config that's the whole key

We infer this in get_peft_kwargs(). There's no direct way for us to specify a rank_pattern, which is why the adjustment bit was added during the Control LoRA work.

If we change get_pattern_key to rf"^{key}$" then black-forest-labs/FLUX.1-Canny-dev-lora works without _maybe_adjust_config and the lora from this issue.

Feel free to push it directly to my branch and we can run the tests to confirm nothing broke.

@hlky
Contributor

hlky commented Feb 17, 2025

It would be a change in peft (and removing _maybe_adjust_config here). Currently, with black-forest-labs/FLUX.1-Canny-dev-lora, before _maybe_adjust_config the peft config defines the default rank r as 128 and has "rank_pattern": {"proj_out": 64}. When replacing modules, peft checks rank_pattern with the regex .*\.{key}$ (any characters, then a literal dot, then the key at the end of the module name), so proj_out is ambiguous to peft, but it's the full key that we want to be rank 64. So we adjust the config, essentially reversing the default and the rank pattern, and define the full keys of everything we want to be a non-default rank. I'm therefore wondering whether peft could offer an option to treat the passed keys in rank_pattern as entire keys rather than suffixes, and how common it is to pass a suffix/actual pattern in rank_pattern.

As another example, for this lora we originally have a default of 4 and "transformer_blocks.0.attn.to_q" at 24, but that key is ambiguous: it could also refer to "single_transformer_blocks.0.attn.to_q". So we make the default 24 and define "single_transformer_blocks.0.attn.to_q" instead, but then there's a bug which adds transformer_blocks.0.attn.to_q again with rank 4. Skipping to_q fixes the issue, but it could happen with other keys. If we can get peft to use transformer_blocks.0.attn.to_q as the whole key, then we don't need to adjust the config.

https://github.com/huggingface/peft/blob/1e2d6b5832401e07e917604dfb080ec474818f2b/src/peft/tuners/lora/model.py#L175-L190
https://github.com/huggingface/peft/blob/1e2d6b5832401e07e917604dfb080ec474818f2b/src/peft/utils/other.py#L770-L772

@sayakpaul
Member Author

but it's the full key that we want to be rank 64 so we adjust the config essentially reversing default and rank pattern, defining full keys of everything we want to be a non-default rank.

This was done with suggestions from @BenjaminBossan here:
#9985 (comment)

So I'm wondering if this could be an option from peft to use the passed keys in rank_pattern as entire keys not suffixes, and how common it is to pass a suffix/actual pattern in rank_pattern.

I think it's a bit hard to gauge but I will let @BenjaminBossan comment further. In my experience, I would say this is fairly divided.

@BenjaminBossan
Member

BenjaminBossan commented Feb 17, 2025

Always this cursèd problem of missing metadata :-/

r is 24, problem keys transformer_blocks.0.attn.to_q etc get removed, and all other target_modules are added to rank_pattern with rank 4 which also restores transformer_blocks.0.attn.to_q etc with rank 4 on later iterations and causes the reported issue.

In general, I think we have to accept that without metadata, we cannot guarantee that the identical config is restored, but it should of course still lead to the same adapter model. Do we know why transformer_blocks.0.attn.to_q is being removed? If we can prevent that, would it solve the problem?

peft is treating rank_pattern as a pattern but we aren't using it as a pattern, when we say proj_out in the original config that's the whole key
If we change get_pattern_key to rf"^{key}$" then black-forest-labs/FLUX.1-Canny-dev-lora works without _maybe_adjust_config and the lora from this issue.

We can't change the base logic here in PEFT, as it could potentially break existing code. Theoretically, we could think about having a flag on LoraConfig that determines how to parse the rank_pattern keys, but I think this would be overly specific. Therefore, I'm considering another option: Maybe we can add a special prefix to a key to "mark" it as being the full key. Something along the lines of:

import re

FULLY_QUALIFIED_KEY_PREFIX = "FULL-NAME-"

def get_pattern_key(pattern_keys, key_to_match):
    """Match a substring of key_to_match in pattern keys"""
    # A key carrying the prefix is treated as the fully qualified module name.
    for key in pattern_keys:
        if key.startswith(FULLY_QUALIFIED_KEY_PREFIX) and key[len(FULLY_QUALIFIED_KEY_PREFIX):] == key_to_match:
            return key

    # Otherwise keep the existing suffix-style pattern match.
    return next(filter(lambda key: re.match(rf".*\.{key}$", key_to_match), pattern_keys), key_to_match)

Diffusers would have to be adjusted to use this new prefix. We would also need to ensure that the PEFT and diffusers version match, otherwise adding the prefix in diffusers would result in no match on the PEFT side. WDYT, would it solve the issue? I'm also open to different proposals.
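
For illustration, a quick usage sketch of the proposed helper, assuming the get_pattern_key snippet above has been run; the pattern keys and module names ("foo", "bar.foo", "x.bar") are made-up examples.

pattern_keys = ["FULL-NAME-foo", "bar"]
print(get_pattern_key(pattern_keys, "foo"))      # 'FULL-NAME-foo' -> strict match on the outer "foo"
print(get_pattern_key(pattern_keys, "bar.foo"))  # 'bar.foo'       -> no false positive; the caller falls back to the default rank
print(get_pattern_key(pattern_keys, "x.bar"))    # 'bar'           -> the ordinary suffix match still works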

@sayakpaul
Member Author

Thanks for the discussions! Appreciate that.

Do we know why transformer_blocks.0.attn.to_q is being removed? If we can prevent that, would it solve the problem?

It's this line that is the culprit I think:

exact_matches = [mod for mod in target_modules if mod == key]

Diffusers would have to be adjusted to use this new prefix. We would also need to ensure that the PEFT and diffusers version match, otherwise adding the prefix in diffusers would result in no match on the PEFT side. WDYT, would it solve the issue? I'm also open to different proposals.

This sounds like a bit too much of diffusers-specific changes on the peft side whereas a simple fix (for now) exists. Also, note that these changes are usually motivated by community checkpoints which we cannot always control, hence the guesswork.

@BenjaminBossan
Member

This sounds like a bit too much of diffusers-specific changes on the peft side whereas a simple fix (for now) exists. Also, note that these changes are usually motivated by community checkpoints which we cannot always control, hence the guesswork.

I wouldn't say it's diffusers-specific, there could be other model architectures where users may want to add a specific rank pattern but where the target module key is not unique. Therefore, I think a more general solution would be preferable, even if it's more complicated.

_NO_CONFIG_UPDATE_KEYS = ["to_k", "to_q", "to_v"]

Isn't this error prone?

@sayakpaul
Member Author

It is error-prone, yes.

FULLY_QUALIFIED_KEY_PREFIX

How do we decide a reasonable default for this variable?

API-wise, could you walk me through the changes that might be needed on the diffusers side? I would like to gauge the changes a bit. get_pattern_key() isn't used in the diffusers codebase. Or maybe I am not fully understanding the changes on the diffusers side.

@BenjaminBossan
Member

BenjaminBossan commented Feb 17, 2025

API-wise, could you walk me through the changes that might be needed on the diffusers side? I would like to gauge the changes a bit. get_pattern_key() isn't used in the diffusers codebase. Or maybe I am not fully understanding the changes on the diffusers side.

If I'm not missing something, the change would be that instead of creating a rank_pattern dict like

{key0: rank0, key1: rank1, ...}

the dict would be

{FULLY_QUALIFIED_KEY_PREFIX + key0: rank0, FULLY_QUALIFIED_KEY_PREFIX + key1: rank1, ...}

This would signal to PEFT that we should consider this key to be the full key, not a pattern to match.

I'm not sure if this could lead to bad performance, given that rank_pattern could be quite large and get_pattern_key would be doing more work on the PEFT side, but I think not (it's not quadratic or worse).
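
As a rough sketch of the diffusers-side change being discussed (hypothetical helper and variable names, not actual diffusers APIs; only the "FULL-NAME-" prefix comes from the PEFT proposal above):

FULLY_QUALIFIED_KEY_PREFIX = "FULL-NAME-"

def build_rank_pattern(inferred_ranks, default_r):
    # Keep only the non-default ranks and mark each key as a full module name.
    return {
        FULLY_QUALIFIED_KEY_PREFIX + name: rank
        for name, rank in inferred_ranks.items()
        if rank != default_r
    }

# e.g. with the Canny LoRA ranks inferred from the state dict:
print(build_rank_pattern({"proj_out": 64, "single_transformer_blocks.33.proj_out": 128}, default_r=128))
# {'FULL-NAME-proj_out': 64}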

@sayakpaul
Member Author

Thanks. I am still wondering about a realistic example for FULLY_QUALIFIED_KEY_PREFIX. What could that be in the case of Flux? Also, I'm wondering whether this should be conditioned in diffusers on some configuration value or whether it should always be applied.

@BenjaminBossan
Member

If I understood the issue correctly, it is caused by trying to set a special rank pattern for a certain key but the same key (to_q) is also matching another key because PEFT interprets it as a pattern. Thus if diffusers would let PEFT know that it's actually the full name and not a pattern, the false match could be avoided.

@BenjaminBossan
Member

To be clear, if you think the proposed solution in this PR is good enough to work for most users, I'd also be fine with going that way and not having to change PEFT, as it's less work overall :)

@sayakpaul
Member Author

If I understood the issue correctly, it is caused by trying to set a special rank pattern for a certain key but the same key (to_q) is also matching another key because PEFT interprets it as a pattern. Thus if diffusers would let PEFT know that it's actually the full name and not a pattern, the false match could be avoided.

I see. This can work. Will just have to figure out the right API calls inside of diffusers when the PEFT support has landed. I think this is much better as a solution.

To be clear, if you think the proposed solution in this PR is good enough to work for most users, I'd also be fine with going that way and not having to change PEFT, as it's less work overall :)

My solution was admittedly a dirty hack. But after the discussions, I am leaning towards your proposal.

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Feb 17, 2025
See huggingface/diffusers#10808 for context.

Right now, if we have a key in rank_pattern or alpha_pattern, it is
interpreted as a pattern to be matched against the module names of the
model (basically it is an endswith match). The issue with this logic is
that we may sometimes have false matches. E.g. if we have a model with
modules "foo" and "bar", and if "bar" also has a sub-module called
"foo", it is impossible to target just the outer "foo" but not
"bar.foo". (Note: It is already possible target only "bar.foo" and not
"foo" by using "bar.foo" as key)

This PR adds the possibility to indicate to PEFT that a key should be
considered to be the "fully qualified" key, i.e. a strict match should
be performed instead of a pattern match. For this, users need to prefix
the string "FULL-NAME-" before the actual key. In our example, that
would be "FULL-NAME-foo".

Notice that since the prefix contains "-", there is no possibility that
it could accidentally match a valid module name.
@BenjaminBossan
Member

I created a draft PR to implement this: huggingface/peft#2382. Let's check that, with the right adjustments on the diffusers code, this is solving the original issue before proceeding.

@hlky
Contributor

hlky commented Feb 17, 2025

Always this cursèd problem of missing metadata :-/

There is no missing metadata, right? We're detecting the correct default rank and the "fully qualified" keys to set to a different rank. Adding a magic key seems like a weird solution to me.

@BenjaminBossan
Member

There is no missing metadata, right? we're detecting the correct default rank and "fully qualified" keys to set as a different rank.

I just meant that in general, many LoRA training frameworks don't save any LoraConfig-equivalent metadata with the checkpoint, which means that it needs to be inferred.

Adding a magic key seems like a weird solution to me.

It's not beautiful, but what would be your suggested solution?

@hlky
Contributor

hlky commented Feb 17, 2025

it needs to be inferred

I think it's being inferred for both loras mentioned in this issue. Are there known circumstances when the metadata cannot be inferred correctly?

what would be your suggested solution

Maybe just a flag in LoraConfig or a separate list for fully qualified keys. As far as I can tell we're always using fully qualified keys with rank_pattern (and target_modules) from diffusers.

@sayakpaul
Member Author

Are there known circumstances when the metadata cannot be inferred correctly?

As far as I can tell from my experience of dealing with several LoRA checkpoints from the community, metadata (rank, alpha, etc.) can always be inferred from the LoRA checkpoints.

As far as I can tell we're always using fully qualified keys with rank_pattern (and target_modules) from diffusers.

Yes, this is true, especially when we're loading non-diffusers LoRA checkpoints. However, I think @BenjaminBossan wanted a more general solution to deal with this in peft:

@BenjaminBossan, in huggingface/peft#2382, what if we add a flag instead, as @hlky suggested? What would be the downside of that?

@hlky
Contributor

hlky commented Feb 18, 2025

It seems there are 2 separate paths for peft/LoraConfig then: one for "training", where we may want to specify a pattern/suffix like proj_out, and one for "inference", where we need to use the inferred keys as-is. Enforcing a regex makes "training" somewhat easier but has limitations/impossible configs and makes "inference" awkward. Could a future major version consider accepting a regex directly? That would support both cases: we'd pass .*\.proj_out$ ourselves if we wanted the .proj_out suffix and just proj_out if we wanted the fully qualified key proj_out.

A flag would feel more natural but I have no major issues with a magic key prefix.

@sayakpaul
Member Author

I can then base this PR off of huggingface/peft#2382 until the regex flag lands in the future. Is that fine by you, @hlky @BenjaminBossan?

@sayakpaul sayakpaul marked this pull request as draft February 18, 2025 08:21
@sayakpaul
Member Author

@BenjaminBossan

One more thing that came to my mind is that if we make these changes, no matter if it's the magic prefix or the flag solution, could we run into trouble for other packages that load diffusers LoRA checkpoints and need to be aware of the change? I suppose not, as they probably use other formats than PEFT, but I just wanted to bring up the possibility.

Well, we're going through hurdles because the checkpoint under consideration is a non-diffusers and non-peft checkpoint. So, I think we're good for now.

LMK if the current changes are good for you.

@hlky

(note: @sayakpaul a possible fix for the issue in that function is storing deleted keys to ensure they aren't re-added by other iterations)

Thanks for the heads-up. Yeah, I thought about it too but didn't bring it up here as we had already found a better solution.

@BenjaminBossan
Member

WDYM? We couldn't just do something like this?

What I meant is that it would be cool if we could do:

config = LoraConfig(..., rank_pattern={"my-string-key": 8, re.compile("my-regex-key"): 16})

That way, we would invert control so that users can decide how they want to match by passing a regex as the key. However, we would not be able to JSON-serialize this config.

If not a flag, an additional fully_qualified_rank_keys would work and support 2.

You mean something like:

config = LoraConfig(..., rank_pattern={"foo": 8, "bar": 16}, fully_qualified_rank_keys=["foo"])?

I'm not a big fan of this pattern, as it requires 2 independent data structures to be kept in sync. It's easier to make a mistake as a user, and PEFT would need to implement checks for consistency.

However, I don't think this will be a particularly common use case.

I agree.

WDYM? It's not checkpoint specific

I mean if this adapter is saved in the PEFT format, it will contain the magic string. If anyone wants to load this without PEFT, they need to be aware of that. As mentioned, I don't think it's a big deal but I just wanted to mention it in case I'm missing something.

LMK if the current changes are good for you.

_FULL_NAME_PREFIX_FOR_PEFT = "FULL-NAME"

In the PEFT PR, I went with "FULL-NAME-". But we can also agree on a different prefix, I don't care too much. If we agree, I can merge the PEFT PR and the prefix can be imported anyway. On the diffusers side, I think we would need a PEFT version guard though, right?

@hlky
Contributor

hlky commented Feb 18, 2025

it will contain the magic string

Oh, we will need to handle that then.

@sayakpaul
Member Author

sayakpaul commented Feb 19, 2025

In the PEFT PR, I went with "FULL-NAME-". But we can also agree on a different prefix, I don't care too much. If we agree, I can merge the PEFT PR and the prefix can be imported anyway. On the diffusers side, I think we would need a PEFT version guard though, right?

Yeah, I didn't version-guard yet. I think we may have to support both codepaths based on the peft version being used:

  • Adjusting the LoRA configs
  • Current updates

Do you see any easy way other than this?

Or we could try to import the prefix constant first from PEFT and in the except block just define it ourselves. This is of course still a bit risky, but I don't see an easy way out either.

@BenjaminBossan
Member

Hmm, I would have added a check for PEFT > 0.14.0 (see the sketch after this list) and:

  • if True, add the prefix
  • if False, either:
    • don't add the prefix and see what happens
    • don't add the prefix and use your original workaround
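
A minimal sketch of that version gate, with hypothetical helper names (the actual guard that landed in the PR may look different); only the > 0.14.0 cutoff and the "FULL-NAME-" prefix come from the discussion above.

from packaging import version

def supports_full_name_prefix(peft_version):
    # PEFT releases newer than 0.14.0 are assumed to understand the prefix.
    return version.parse(peft_version) > version.parse("0.14.0")

def prepare_rank_pattern(rank_pattern, peft_version):
    if supports_full_name_prefix(peft_version):
        # New codepath: mark every inferred key as fully qualified.
        return {"FULL-NAME-" + key: rank for key, rank in rank_pattern.items()}
    # Old codepath: keep the keys as-is and rely on the original workaround
    # (_maybe_adjust_config plus the _NO_CONFIG_UPDATE_KEYS check).
    return rank_pattern

print(prepare_rank_pattern({"proj_out": 64}, peft_version="0.14.0"))  # {'proj_out': 64}
print(prepare_rank_pattern({"proj_out": 64}, peft_version="0.15.0"))  # {'FULL-NAME-proj_out': 64}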

@sayakpaul
Member Author

don't add the prefix and see what happens

Well, it would certainly fail without _maybe_adjust_lora_config(). Is that what you meant?

@BenjaminBossan
Member

Well, it would certainly fail without _maybe_adjust_lora_config(). Is that what you meant?

Yes, it would mean that for the error to go away, users would need the next PEFT version. But we could also adopt the workaround for older PEFT versions and the new prefix-based solution if the installed PEFT version supports it.

@sayakpaul
Member Author

But we could also adopt the workaround for older PEFT versions and the new prefix-based solution if the installed PEFT version supports it.

Yeah, for the older peft version I will defer to what I had in this PR (a constant with the attention modules and checking against it in the config update method). Do you have any other ideas?

@BenjaminBossan
Member

Do you have any other ideas?

No, unfortunately not :(

@sayakpaul sayakpaul requested a review from hlky February 22, 2025 17:12
@sayakpaul
Member Author

@BenjaminBossan ready for another review.

Contributor

@hlky hlky left a comment


LGTM

@@ -54,6 +54,7 @@
"SanaTransformer2DModel": lambda model_cls, weights: weights,
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
}
_NO_CONFIG_UPDATE_KEYS = ["to_k", "to_q", "to_v"]
Contributor


Suggested change
_NO_CONFIG_UPDATE_KEYS = ["to_k", "to_q", "to_v"]

Member Author


We need it when the PEFT version doesn't contain the required prefix.

Contributor


Got it

a possible fix for the issue in that function is storing deleted keys to ensure they aren't re-added by other iterations

If this issue reoccurs with other keys before the minimum PEFT version is increased, this can be applied.
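
For reference, a minimal, simplified sketch of that fallback fix (hypothetical names, not the merged _maybe_adjust_config): remember which ambiguous keys were deliberately deleted so later iterations cannot re-add them with the default rank.

def adjust(rank_pattern, target_modules, r):
    pattern = dict(rank_pattern)
    original_r = r
    deleted = set()  # keys removed on purpose
    for key in list(pattern):
        key_rank = pattern[key]
        exact = [m for m in target_modules if m == key]
        ambiguous = [m for m in target_modules if key in m and m != key]
        if exact and ambiguous:
            r = key_rank  # promote the ambiguous rank to the new default
            del pattern[key]
            deleted.add(key)
            for m in target_modules:
                # the fix: never re-add a key that was deliberately deleted
                if m != key and m not in pattern and m not in deleted:
                    pattern[m] = original_r
    return pattern, r

rank_pattern = {"transformer_blocks.0.attn.to_q": 24, "transformer_blocks.0.attn.to_k": 24}
target_modules = [
    "transformer_blocks.0.attn.to_q", "transformer_blocks.0.attn.to_k",
    "single_transformer_blocks.0.attn.to_q", "single_transformer_blocks.0.attn.to_k",
]
new_pattern, new_r = adjust(rank_pattern, target_modules, r=4)
print(new_r)        # 24
print(new_pattern)  # only the single_transformer_blocks keys at rank 4; deleted keys stay deleted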

Member Author


Check now?

@sayakpaul sayakpaul marked this pull request as ready for review February 22, 2025 18:40
@sayakpaul sayakpaul requested a review from hlky February 22, 2025 18:40
Member

@BenjaminBossan BenjaminBossan left a comment


It's a bit of a messy situation, but I think this PR finds a nice solution to the issue. LGTM. If the test script passes both with and without the new feature in PEFT, it's good to be merged.

@@ -71,30 +74,27 @@ def _maybe_adjust_config(config):
key_rank = rank_pattern[key]

# try to detect ambiguity
# `target_modules` can also be a str, in which case this loop would loop
Member


Any specific reason why this was removed?

Member Author


Let me add this back.

@sayakpaul sayakpaul merged commit b0550a6 into main Feb 24, 2025
15 checks passed
@sayakpaul sayakpaul deleted the restrict-config-update-peft branch February 24, 2025 11:25
@sayakpaul
Member Author

@AmericanPresidentJimmyCarter can you check if this solves #10804

@burgalon
Contributor

After installing diffusers@b0550a66cc3c882a1b88470df7e26103208b13de and trying to load a Comfy LoRA (https://civitai.com/models/631986/xlabs-flux-realism-lora), I'm still getting

ValueError: Incompatible keys detected:

diffusion_model.double_blocks.0.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.0.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.0.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.0.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.0.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.0.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.0.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.0.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.1.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.1.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.1.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.1.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.1.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.1.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.1.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.1.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.10.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.10.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.10.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.10.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.10.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.10.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.10.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.10.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.11.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.11.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.11.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.11.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.11.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.11.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.11.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.11.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.12.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.12.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.12.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.12.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.12.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.12.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.12.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.12.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.13.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.13.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.13.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.13.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.13.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.13.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.13.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.13.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.14.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.14.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.14.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.14.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.14.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.14.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.14.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.14.txt_attn.qkv.lora_up.weight, 
diffusion_model.double_blocks.15.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.15.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.15.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.15.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.15.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.15.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.15.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.15.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.16.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.16.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.16.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.16.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.16.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.16.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.16.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.16.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.17.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.17.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.17.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.17.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.17.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.17.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.17.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.17.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.18.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.18.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.18.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.18.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.18.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.18.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.18.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.18.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.2.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.2.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.2.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.2.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.2.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.2.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.2.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.2.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.3.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.3.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.3.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.3.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.3.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.3.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.3.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.3.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.4.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.4.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.4.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.4.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.4.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.4.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.4.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.4.txt_attn.qkv.lora_up.weight, 
diffusion_model.double_blocks.5.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.5.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.5.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.5.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.5.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.5.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.5.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.5.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.6.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.6.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.6.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.6.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.6.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.6.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.6.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.6.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.7.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.7.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.7.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.7.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.7.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.7.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.7.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.7.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.8.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.8.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.8.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.8.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.8.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.8.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.8.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.8.txt_attn.qkv.lora_up.weight, diffusion_model.double_blocks.9.img_attn.proj.lora_down.weight, diffusion_model.double_blocks.9.img_attn.proj.lora_up.weight, diffusion_model.double_blocks.9.img_attn.qkv.lora_down.weight, diffusion_model.double_blocks.9.img_attn.qkv.lora_up.weight, diffusion_model.double_blocks.9.txt_attn.proj.lora_down.weight, diffusion_model.double_blocks.9.txt_attn.proj.lora_up.weight, diffusion_model.double_blocks.9.txt_attn.qkv.lora_down.weight, diffusion_model.double_blocks.9.txt_attn.qkv.lora_up.weight

....

@sayakpaul
Member Author

Could you open a new issue thread? I just ran the xlabs integration tests and they are passing:

def test_flux_xlabs(self):

def test_flux_xlabs_load_lora_with_single_blocks(self):

(I ensured I pulled in the latest changes in main)

@AmericanPresidentJimmyCarter
Contributor

@AmericanPresidentJimmyCarter can you check if this solves #10804

Confirmed that it works.


Successfully merging this pull request may close these issues.

comfy-ui compatible FLUX1.dev LoRA fails to load