Closed
Description
Describe the bug
PR #8483 added support for SD3 lora loading, but it does so a) without any checks, b) documentation on required params.
For SD3, lora_lora_weights triggers new method load_lora_into_transformer
which expects to extract values from kwargs
and without having default values.
Documentation does not mention required kwargs at all and attempting to load lora without them results in runtime error.
for reproduction, download any sd3 early loras available on civitai, for example:
- https://civitai.com/models/512239/pixel-art-medium-128?modelVersionId=569272
- https://civitai.com/models/513204/stable-diffusion-3-famous-folks?modelVersionId=571990
- https://civitai.com/models/513371/school-yearbook-photos-sd3?modelVersionId=570533
Reproduction
import warnings
import torch
import diffusers
import rich.traceback
rich.traceback.install()
warnings.filterwarnings(action="ignore", category=FutureWarning)
cache_dir = '/mnt/models/Diffusers'
pipe = diffusers.StableDiffusion3Pipeline.from_single_file(
'/mnt/models/stable-diffusion/sd3/sd3_medium_incl_clips.safetensors',
torch_dtype = torch.float16,
text_encoder_3 = None,
cache_dir = cache_dir,
)
pipe.load_lora_weights('/mnt/models/Lora/sd3/famous-folks.safetensors')
pipe.to('cuda')
result = pipe(
prompt='A photo of a cat',
width=1024,
height=1024,
)
image = result.images[0]
image.save('test.png')
Logs
Fetching 20 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 3238.09it/s]
Loading pipeline components...: 50%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 4/8 [00:00<00:00, 25.37it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:01<00:00, 4.76it/s]
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/vlado/dev/sdnext/tmp/sd3.py:17 in <module> │
│ │
│ 14 │ text_encoder_3 = None, │
│ 15 │ cache_dir = cache_dir, │
│ 16 ) │
│ ❱ 17 pipe.load_lora_weights('/mnt/models/Lora/sd3/famous-folks.safetensors') │
│ 18 │
│ 19 """ │
│ 20 import transformers │
│ │
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/loaders/lora.py:1387 in │
│ load_lora_weights │
│ │
│ 1384 │ │ if not is_correct_format: │
│ 1385 │ │ │ raise ValueError("Invalid LoRA checkpoint.") │
│ 1386 │ │ │
│ ❱ 1387 │ │ self.load_lora_into_transformer( │
│ 1388 │ │ │ state_dict, │
│ 1389 │ │ │ transformer=getattr(self, self.transformer_name) if not hasattr(self, "trans │
│ 1390 │ │ │ adapter_name=adapter_name, │
│ │
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/loaders/lora.py:1555 in │
│ load_lora_into_transformer │
│ │
│ 1552 │ │ │ │ if "lora_B" in key: │
│ 1553 │ │ │ │ │ rank[key] = val.shape[1] │
│ 1554 │ │ │ │
│ ❱ 1555 │ │ │ lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_sta │
│ 1556 │ │ │ if "use_dora" in lora_config_kwargs: │
│ 1557 │ │ │ │ if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): │
│ 1558 │ │ │ │ │ raise ValueError( │
│ │
│ /home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/utils/peft_utils.py:153 in │
│ get_peft_kwargs │
│ │
│ 150 def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True): │
│ 151 │ rank_pattern = {} │
│ 152 │ alpha_pattern = {} │
│ ❱ 153 │ r = lora_alpha = list(rank_dict.values())[0] │
│ 154 │ │
│ 155 │ if len(set(rank_dict.values())) > 1: │
│ 156 │ │ # get the rank occuring the most number of times │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
IndexError: list index out of range
System Info
diffusers==0.30.0.dev (06/15/2024)