
ComfyUI-compatible FLUX.1-dev LoRA fails to load #10804

Closed
@AmericanPresidentJimmyCarter

Description

Describe the bug

https://civitai.com/models/677200?modelVersionId=758070

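This LoRA fails to load. Specifically, model.load_state_dict (called from PEFT's set_peft_model_state_dict) raises size-mismatch errors when loading the LoRA state dict: the checkpoint's lora_A/lora_B weights have rank 24, while the adapter layers created in the transformer have rank 4.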
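The shapes in the error follow PEFT's layout, where lora_A is (rank, in_features) and lora_B is (out_features, rank), so each mismatch is really a rank mismatch. Below is a minimal pure-Python sketch, assuming the state-dict shapes are available as a plain {key: shape} dict. infer_lora_ranks and rank_mismatches are hypothetical helpers, not diffusers or PEFT APIs; the sample shapes are copied from the log.

```python
def infer_lora_ranks(shapes):
    """Return {module_prefix: rank} from a {state_dict_key: shape} mapping.

    Assumes PEFT-style keys: lora_A is (rank, in_features),
    lora_B is (out_features, rank).
    """
    ranks = {}
    for key, shape in shapes.items():
        if ".lora_A." not in key:
            continue
        prefix, suffix = key.split(".lora_A.", 1)
        rank = shape[0]
        b_shape = shapes.get(f"{prefix}.lora_B.{suffix}")
        # A well-formed pair must agree on the rank dimension.
        if b_shape is not None and b_shape[1] != rank:
            raise ValueError(f"{prefix}: lora_A rank {rank} != lora_B rank {b_shape[1]}")
        ranks[prefix] = rank
    return ranks


def rank_mismatches(checkpoint, model):
    """Return {module_prefix: (checkpoint_rank, model_rank)} where they differ."""
    ck, md = infer_lora_ranks(checkpoint), infer_lora_ranks(model)
    return {m: (ck[m], md[m]) for m in ck.keys() & md.keys() if ck[m] != md[m]}


# Shapes copied from the log (adapter name "wow_details").
checkpoint = {
    "transformer_blocks.0.attn.to_q.lora_A.wow_details.weight": (24, 3072),
    "transformer_blocks.0.attn.to_q.lora_B.wow_details.weight": (3072, 24),
}
model = {
    "transformer_blocks.0.attn.to_q.lora_A.wow_details.weight": (4, 3072),
    "transformer_blocks.0.attn.to_q.lora_B.wow_details.weight": (3072, 4),
}
print(rank_mismatches(checkpoint, model))
```

The model-side rank comes from the LoraConfig the loader builds; PEFT's LoraConfig has a rank_pattern field for per-module ranks. Whether that mapping is what goes wrong here is an assumption the log alone does not confirm.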

Reproduction

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')

pipe.load_lora_weights('/data/models/wow_details.safetensors', adapter_name="wow_details")

prompt = "a tiny astronaut hatching from an egg on the moon"
out = pipe(
    prompt=prompt,
    guidance_scale=3.5,
    height=1024,
    width=1024,
    num_inference_steps=25,
).images[0]
out.save("image.png")
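To check which ranks wow_details.safetensors actually stores without loading torch or the pipeline, the safetensors header can be read directly. The file starts with an 8-byte little-endian unsigned integer giving the header length, followed by a JSON header mapping each tensor name to its dtype, shape, and data offsets (plus an optional "__metadata__" entry). This is a minimal sketch. The helper names are hypothetical, and the lora_A / lora_down filter assumes PEFT-style or Kohya-style key naming, since the report does not show the file's actual key format.

```python
import json
import struct
import sys


def shapes_from_header(header_bytes):
    """Decode a safetensors JSON header into {tensor_name: shape}."""
    header = json.loads(header_bytes.decode("utf-8"))
    header.pop("__metadata__", None)  # optional str->str metadata, not a tensor
    return {name: tuple(info["shape"]) for name, info in header.items()}


def read_safetensors_shapes(path):
    """Read only the header (u64 little-endian length + JSON), not tensor data."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        return shapes_from_header(f.read(header_len))


if __name__ == "__main__" and len(sys.argv) > 1:
    # Print the rank-carrying "down" projections (naming convention assumed).
    for name, shape in sorted(read_safetensors_shapes(sys.argv[1]).items()):
        if "lora_A" in name or "lora_down" in name:
            print(name, shape)
```

Usage, with a hypothetical script name: python inspect_lora.py /data/models/wow_details.safetensors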

Logs

pipe.load_lora_weights(lora_fn, adapter_name=name, low_cpu_mem_usage=True)
env/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py:1872: in load_lora_weights
    self.load_lora_into_transformer(
env/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py:1936: in load_lora_into_transformer
    transformer.load_lora_adapter(
env/lib/python3.11/site-packages/diffusers/loaders/peft.py:327: in load_lora_adapter
    incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
env/lib/python3.11/site-packages/peft/utils/save_and_load.py:432: in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False)

  File "/root/app/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2584, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for FluxTransformer2DModel:
        size mismatch for transformer_blocks.0.attn.to_q.lora_A.wow_details.weight: copying a param with shape torch.Size([24, 3072]) from checkpoint, the shape in current model is torch.Size([4, 3072]).
        size mismatch for transformer_blocks.0.attn.to_q.lora_B.wow_details.weight: copying a param with shape torch.Size([3072, 24]) from checkpoint, the shape in current model is torch.Size([3072, 4]).
        size mismatch for transformer_blocks.0.attn.to_k.lora_A.wow_details.weight: copying a param with shape torch.Size([24, 3072]) from checkpoint, the shape in current model is torch.Size([4, 3072]).
        size mismatch for transformer_blocks.0.attn.to_k.lora_B.wow_details.weight: copying a param with shape torch.Size([3072, 24]) from checkpoint, the shape in current model is torch.Size([3072, 4]).
        size mismatch for transformer_blocks.0.attn.to_v.lora_A.wow_details.weight: copying a param with shape torch.Size([24, 3072]) from checkpoint, the shape in current model is torch.Size([4, 3072]).
        size mismatch for transformer_blocks.0.attn.to_v.lora_B.wow_details.weight: copying a param with shape torch.Size([3072, 24]) from checkpoint, the shape in current model is torch.Size([3072, 4]).
        size mismatch for transformer_blocks.1.attn.to_q.lora_A.wow_details.weight: copying a param with shape torch.Size([24, 3072]) from checkpoint, the shape in current model is torch.Size([4, 3072]).
        size mismatch for transformer_blocks.1.attn.to_q.lora_B.wow_details.weight: copying a param with shape torch.Size([3072, 24]) from checkpoint, the shape in current model is torch.Size([3072, 4]).
...
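The full error (truncated above) lists one line per mismatched tensor, which gets long for a model with many blocks. A minimal sketch for condensing it, assuming the error text is available as a string and has not been hard-wrapped; MISMATCH and summarize_mismatches are hypothetical names, and the regex targets the exact message format shown in this log. Applied to the lines shown, it yields two groups: lora_A (24, 3072) vs. (4, 3072) and lora_B (3072, 24) vs. (3072, 4), i.e. checkpoint rank 24 against model rank 4.

```python
import re

# One "size mismatch" line from torch's load_state_dict RuntimeError.
MISMATCH = re.compile(
    r"size mismatch for (?P<key>\S+): copying a param with shape "
    r"torch\.Size\(\[(?P<ckpt>[\d, ]+)\]\) from checkpoint, "
    r"the shape in current model is torch\.Size\(\[(?P<model>[\d, ]+)\]\)"
)


def parse_size(text):
    """Turn '24, 3072' into (24, 3072)."""
    return tuple(int(x) for x in text.split(","))


def summarize_mismatches(log):
    """Group mismatched keys by (checkpoint_shape, model_shape)."""
    groups = {}
    for m in MISMATCH.finditer(log):
        pair = (parse_size(m["ckpt"]), parse_size(m["model"]))
        groups.setdefault(pair, []).append(m["key"])
    return groups


sample_log = """
size mismatch for transformer_blocks.0.attn.to_q.lora_A.wow_details.weight: copying a param with shape torch.Size([24, 3072]) from checkpoint, the shape in current model is torch.Size([4, 3072]).
size mismatch for transformer_blocks.0.attn.to_q.lora_B.wow_details.weight: copying a param with shape torch.Size([3072, 24]) from checkpoint, the shape in current model is torch.Size([3072, 4]).
"""

for (ckpt, model), keys in summarize_mismatches(sample_log).items():
    print(ckpt, "->", model, f"({len(keys)} tensors)")
```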

System Info

80 GB A100 VM

Dev env:

[tool.poetry.dependencies]
python = ">=3.10,<3.12"
torch = {version = "^2.5.1+cu124", source = "pytorch"}
torchvision = {version = "^0.20.1+cu124", source = "pytorch"}
diffusers = {git = "https://github.com/huggingface/diffusers.git", rev = "cb342b745aa57798b759c0ba5b80c045a5dafbad"}
transformers = "4.48.2"
datasets = "^3.0.1"
bitsandbytes = "^0.44.1"
wandb = "^0.18.2"
requests = "^2.32.3"
pillow = "^10.4.0"
opencv-python = "^4.10.0.84"
deepspeed = "^0.15.1"
accelerate = "^0.34.2"
safetensors = "^0.4.5"
compel = "^2.0.1"
clip-interrogator = "^0.6.0"
open-clip-torch = "^2.26.1"
iterutils = "^0.1.6"
scipy = "^1.11.1"
boto3 = "^1.35.24"
pandas = "^2.2.3"
botocore = "^1.35.24"
urllib3 = "<1.27"
triton-library = "^1.0.0rc4"
torchsde = "^0.2.5"
torchmetrics = "^1.1.1"
colorama = "^0.4.6"
numpy = "1.26"
peft = {git = "https://github.com/huggingface/peft.git", rev = "1e2d6b5832401e07e917604dfb080ec474818f2b"}
tensorboard = "^2.17.1"
triton = {version = "^3.0.0", source = "pytorch"}
sentencepiece = "^0.2.0"
optimum-quanto = {git = "https://github.com/huggingface/optimum-quanto"}
lycoris-lora = {git = "https://github.com/kohakublueleaf/lycoris", rev = "dev"}
torch-optimi = "^0.2.1"
toml = "^0.10.2"
torchao = {version = "^0.5.0+cu124", source = "pytorch"}
rollbar = "^1.0.0"
runpod = "^1.7.0"
comet-ml = "^3.44.4"
pillow-heif = "^0.18.0"
pillow-avif-plugin = "^1.4.6"
kornia = "^0.7.3"
realesrgan = "^0.3.0"
einops = "^0.8.0"
controlnet-aux = "^0.0.9"
ultralytics = "^8.3.27"
insightface = "^0.7.3"
onnxruntime-gpu = "^1.20.1"
xformers = "^0.0.28.post3"
scepter = "^1.3.1"
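Only a handful of these packages touch LoRA loading. A minimal sketch for reporting just those, using the standard library's importlib.metadata; the package list is an assumption about what matters for this bug. For diffusers itself, the diffusers-cli env command prints a fuller environment summary.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Assumed to be the packages relevant to LoRA loading in this report.
RELEVANT = ("torch", "diffusers", "peft", "transformers", "safetensors", "accelerate")


def collect_versions(packages=RELEVANT):
    """Return {package: installed version string, or None if missing}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    print("python", platform.python_version())
    for name, ver in collect_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```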

Who can help?

@sayakpaul @DN6

Metadata


Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
