
SD3 lora support non functional #8579

Closed
@vladmandic

Description

Describe the bug

PR #8483 added support for SD3 LoRA loading, but it does so a) without any checks and b) without documenting the required params.

For SD3, load_lora_weights triggers the new method load_lora_into_transformer, which expects to extract values from kwargs that have no default values.

The documentation does not mention the required kwargs at all, and attempting to load a LoRA without them results in a runtime error.
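
A check of the kind this report asks for could look like the sketch below. This is only an illustration of the missing validation, not diffusers code; the helper name _has_peft_lora_keys is hypothetical.

def _has_peft_lora_keys(state_dict: dict) -> bool:
    # hypothetical guard: confirm the checkpoint actually contains
    # PEFT-style LoRA keys (the "lora_B" pattern the loader scans for)
    # before it is handed to load_lora_into_transformer
    return any("lora_B" in key for key in state_dict)

With such a guard in place, load_lora_weights could raise a descriptive ValueError instead of the bare IndexError shown in the logs below.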

For reproduction, download any early SD3 LoRA available on Civitai, for example the famous-folks LoRA used in the script below:

Reproduction

import warnings
import torch
import diffusers
import rich.traceback

rich.traceback.install()
warnings.filterwarnings(action="ignore", category=FutureWarning)
cache_dir = '/mnt/models/Diffusers'

pipe = diffusers.StableDiffusion3Pipeline.from_single_file(
    '/mnt/models/stable-diffusion/sd3/sd3_medium_incl_clips.safetensors',
    torch_dtype = torch.float16,
    text_encoder_3 = None,
    cache_dir = cache_dir,
)
pipe.load_lora_weights('/mnt/models/Lora/sd3/famous-folks.safetensors')  # fails here with IndexError (see logs below)
pipe.to('cuda')
result = pipe(
    prompt='A photo of a cat',
    width=1024,
    height=1024,
)
image = result.images[0]
image.save('test.png')

Logs

Fetching 20 files: 100%|████████████| 20/20 [00:00<00:00, 3238.09it/s]
Loading pipeline components...:  50%|██████      | 4/8 [00:00<00:00, 25.37it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|████████████| 8/8 [00:01<00:00,  4.76it/s]
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/vlado/dev/sdnext/tmp/sd3.py:17 in <module>                                                 │
│                                                                                                  │
│   14 │   text_encoder_3 = None,                                                                  │
│   15 │   cache_dir = cache_dir,                                                                  │
│   16 )                                                                                           │
│ ❱ 17 pipe.load_lora_weights('/mnt/models/Lora/sd3/famous-folks.safetensors')                     │
│   18                                                                                             │
│   19 """
│   20 import transformers                                                                         │
│                                                                                                  │
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/loaders/lora.py:1387 in       │
│ load_lora_weights                                                                                │
│                                                                                                  │
│   1384 │   │   if not is_correct_format:                                                         │
│   1385 │   │   │   raise ValueError("Invalid LoRA checkpoint.")                                  │
│   1386 │   │                                                                                     │
│ ❱ 1387 │   │   self.load_lora_into_transformer(                                                  │
│   1388 │   │   │   state_dict,                                                                   │
│   1389 │   │   │   transformer=getattr(self, self.transformer_name) if not hasattr(self, "trans  │
│   1390 │   │   │   adapter_name=adapter_name,                                                    │
│                                                                                                  │
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/loaders/lora.py:1555 in       │
│ load_lora_into_transformer                                                                       │
│                                                                                                  │
│   1552 │   │   │   │   if "lora_B" in key:                                                       │
│   1553 │   │   │   │   │   rank[key] = val.shape[1]                                              │
│   1554 │   │   │                                                                                 │
│ ❱ 1555 │   │   │   lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_sta  │
│   1556 │   │   │   if "use_dora" in lora_config_kwargs:                                          │
│   1557 │   │   │   │   if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):      │
│   1558 │   │   │   │   │   raise ValueError(                                                     │
│                                                                                                  │
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/utils/peft_utils.py:153 in    │
│ get_peft_kwargs                                                                                  │
│                                                                                                  │
│   150 def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):         │
│   151 │   rank_pattern = {}                                                                      │
│   152 │   alpha_pattern = {}                                                                     │
│ ❱ 153 │   r = lora_alpha = list(rank_dict.values())[0]                                           │
│   154 │                                                                                          │
│   155 │   if len(set(rank_dict.values())) > 1:                                                   │
│   156 │   │   # get the rank occuring the most number of times                                   │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
IndexError: list index out of range
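
The IndexError follows directly from the code in the traceback: get_peft_kwargs builds its rank dict only from keys containing "lora_B", so a checkpoint whose keys use a different naming scheme leaves the dict empty and the unguarded list indexing fails. A minimal standalone sketch of that failure path (the key name is illustrative, not taken from a real checkpoint):

import torch

# early Civitai-style key that does not match the "lora_B" pattern
state_dict = {"lora_unet_block_0.lora_down.weight": torch.zeros(4, 16)}

rank = {}
for key, val in state_dict.items():
    if "lora_B" in key:  # never matches here
        rank[key] = val.shape[1]

# rank is empty, so this line (peft_utils.py:153 in the traceback)
# raises IndexError: list index out of range
r = lora_alpha = list(rank.values())[0]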

System Info

diffusers==0.30.0.dev (06/15/2024)

Who can help?

@yiyixuxu @sayakpaul @DN6

Labels

bug: Something isn't working
