Describe the bug
When hotswapping multiple Flux LoRAs with load_lora_weights(..., hotswap=True), the swap fails with a RuntimeError about unexpected keys:
RuntimeError: Hot swapping the adapter did not succeed, unexpected keys found: transformer_blocks.13.norm1.linear.lora_B.weight,
Reproduction
Download two Flux Dev LoRAs (this example uses http://base-weights.weights.com/cm9dm38e4061uon15341k47ss.zip and http://base-weights.weights.com/cm9dnj1840088n214rn9uych4.zip)
Unzip them and load the safetensors into memory, as sketched below
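For reference, the state dicts used in the script were loaded roughly like this (the .safetensors file names are placeholders for wherever the archives were unzipped):

from safetensors.torch import load_file

# Placeholder paths - point these at the actual files from the unzipped archives.
first_state_dict = load_file("first_lora.safetensors")
second_state_dict = load_file("second_lora.safetensors")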
import logging
import time

import torch
from diffusers import FluxPipeline

logger = logging.getLogger(__name__)


class DownloadedLora:
    def __init__(self, state_dict):
        self.state_dict = state_dict

    @property
    def model(self):
        # Return a clone of the state dict to avoid modifying the original.
        return {k: v.clone().detach() for k, v in self.state_dict.items()}


def test_lora_hotswap():
    logger.info("Initializing flux model")
    # todo - compile https://github.com/huggingface/diffusers/pull/9453 when this gets merged
    flux_base_model: FluxPipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    )
    flux_base_model = flux_base_model.to("cuda")
    # Pre-allocate LoRA layers at rank 128 so adapters of differing ranks can be swapped in.
    flux_base_model.enable_lora_hotswap(target_rank=128)

    # Set the state dicts of the two downloaded loras (loaded above from the unzipped safetensors).
    first_lora = DownloadedLora(state_dict=first_state_dict)
    second_lora = DownloadedLora(state_dict=second_state_dict)

    # We need to load three loras as that is the limit of what we support - each name is "1", "2", "3".
    # These will then be enabled or disabled.
    flux_base_model.load_lora_weights(first_lora.model, adapter_name="1")
    flux_base_model.load_lora_weights(second_lora.model, adapter_name="2")
    flux_base_model.load_lora_weights(second_lora.model, adapter_name="3")
    logger.info("Initialized base flux model")

    should_compile = False
    if should_compile:
        flux_base_model.image_encoder = torch.compile(flux_base_model.image_encoder)
        flux_base_model.text_encoder = torch.compile(flux_base_model.text_encoder)
        flux_base_model.text_encoder_2 = torch.compile(flux_base_model.text_encoder_2)
        flux_base_model.vae = torch.compile(flux_base_model.vae)
        flux_base_model.transformer = torch.compile(
            flux_base_model.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True
        )

    for i in range(5):
        start_time = time.time()
        image = flux_base_model("An image of a cat", num_inference_steps=4, guidance_scale=3.0).images[0]
        if i == 0:
            logger.info(f"Warmup: {time.time() - start_time}")
        else:
            logger.info(f"Inference time: {time.time() - start_time}")
        utc_seconds = int(time.time())
        image.save(f"hotswap_{utc_seconds}.png")
        if i == 1:
            logger.info("Hotswapping lora one")
            flux_base_model.load_lora_weights(first_lora.model, adapter_name="1", hotswap=True)
        if i == 2:
            logger.info("Hotswapping lora two")
            flux_base_model.load_lora_weights(second_lora.model, adapter_name="2", hotswap=True)
            flux_base_model.load_lora_weights(first_lora.model, adapter_name="1", hotswap=True)
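As a quick diagnostic (not part of the original repro), comparing the key sets of the two state dicts can help rule out that the adapters simply target different modules, which would line up with the unexpected-keys error in the logs below:

# Hypothetical sanity check: LoRA keys present in one state dict but not the other.
# A non-empty diff means the adapters target different modules, which is what the
# hotswap RuntimeError about unexpected keys suggests.
only_in_first = set(first_state_dict) - set(second_state_dict)
only_in_second = set(second_state_dict) - set(first_state_dict)
print(f"{len(only_in_first)} keys only in first lora, {len(only_in_second)} only in second")
for key in sorted(only_in_second)[:10]:
    print(key)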
Logs
2025-04-12 04:47:18 | INFO | Initialized base flux model
100%|██████████| 4/4 [00:01<00:00, 3.64it/s]
2025-04-12 04:47:21 | INFO | Warmup: 2.4211995601654053
100%|██████████| 4/4 [00:01<00:00, 3.79it/s]
2025-04-12 04:47:23 | INFO | Inference time: 1.2886595726013184
2025-04-12 04:47:23 | INFO | Hotswapping lora one
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/team/replay/python/hosted/utils/testing.py", line 708, in <module>
main()
File "/home/team/replay/python/hosted/utils/testing.py", line 704, in main
test_lora_hotswap()
File "/home/team/replay/python/hosted/utils/testing.py", line 667, in test_lora_hotswap
flux_base_model.load_lora_weights(first_lora.model, adapter_name="1", hotswap=True)
File "/home/team/.local/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py", line 1808, in load_lora_weights
self.load_lora_into_transformer(
File "/home/team/.local/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py", line 1899, in load_lora_into_transformer
transformer.load_lora_adapter(
File "/home/team/.local/lib/python3.11/site-packages/diffusers/loaders/peft.py", line 371, in load_lora_adapter
hotswap_adapter_from_state_dict(
File "/home/team/.local/lib/python3.11/site-packages/peft/utils/hotswap.py", line 431, in hotswap_adapter_from_state_dict
raise RuntimeError(msg)
RuntimeError: Hot swapping the adapter did not succeed, unexpected keys found: transformer_blocks.14.ff.net.0.proj.lora_B.weight, single_transformer_blocks.7.attn.to_v.lora_B.weight, ...
System Info
- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux-5.10.0-34-cloud-amd64-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.11.11
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.28.1
- Transformers version: 4.50.3
- Accelerate version: 1.6.0
- PEFT version: 0.15.0
- Bitsandbytes version: 0.45.3
- Safetensors version: 0.5.3
- xFormers version: 0.0.29.post3
- Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
- Using GPU in script?:
- Using distributed or parallel set-up in script?: