
Hotswapping multiple LoRAs throws a PEFT key error. #11298


Description

@jonluca

Describe the bug

When trying to hotswap multiple Flux LoRAs, you get a RuntimeError about unexpected keys:

RuntimeError: Hot swapping the adapter did not succeed, unexpected keys found: transformer_blocks.13.norm1.linear.lora_B.weight,

Reproduction

Download two Flux Dev LoRAs (this example uses http://base-weights.weights.com/cm9dm38e4061uon15341k47ss.zip and http://base-weights.weights.com/cm9dnj1840088n214rn9uych4.zip).

Unzip the archives and load the safetensors files into memory.
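A minimal sketch of that step, assuming each unzipped archive contains a single .safetensors file (the paths below are placeholders, not the actual archive contents):

from safetensors.torch import load_file

# hypothetical paths to the unzipped LoRA weight files; substitute your own
first_state_dict = load_file("first_lora.safetensors")
second_state_dict = load_file("second_lora.safetensors")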

import time
import torch
import logging
from diffusers import FluxPipeline

logger = logging.getLogger(__name__)


class DownloadedLora:
    def __init__(self, state_dict):
        self.state_dict = state_dict

    @property
    def model(self):
        state_dict = self.state_dict
        # return a detached clone of the state dict to avoid modifying the original
        new_state_dict = {}
        for k, v in state_dict.items():
            new_state_dict[k] = v.clone().detach()
        return new_state_dict


def test_lora_hotswap():
    logger.info(f"Initializing flux model")

    # todo - compile https://github.com/huggingface/diffusers/pull/9453 when this gets merged
    flux_base_model: FluxPipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    )
    flux_base_model = flux_base_model.to("cuda")
    flux_base_model.enable_lora_hotswap(target_rank=128)

    # wrap the state dicts loaded from the unzipped safetensors above
    first_lora = DownloadedLora(state_dict=first_state_dict)
    second_lora = DownloadedLora(state_dict=second_state_dict)

    # we need to load three LoRAs, as that is the limit of what we support - the
    # adapters are named "1", "2", and "3", and will later be enabled or disabled

    flux_base_model.load_lora_weights(first_lora.model, adapter_name="1")
    flux_base_model.load_lora_weights(second_lora.model, adapter_name="2")
    flux_base_model.load_lora_weights(second_lora.model, adapter_name="3")

    logger.info("Initialized base flux model")
    should_compile = False
    if should_compile:
        flux_base_model.image_encoder = torch.compile(flux_base_model.image_encoder)
        flux_base_model.text_encoder = torch.compile(flux_base_model.text_encoder)
        flux_base_model.text_encoder_2 = torch.compile(flux_base_model.text_encoder_2)
        flux_base_model.vae = torch.compile(flux_base_model.vae)
        flux_base_model.transformer = torch.compile(
            flux_base_model.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True
        )

    for i in range(5):
        start_time = time.time()
        image = flux_base_model("An image of a cat", num_inference_steps=4, guidance_scale=3.0).images[0]
        if i == 0:
            logger.info(f"Warmup: {time.time() - start_time}")
        else:
            logger.info(f"Inference time: {time.time() - start_time}")

        utc_seconds = int(time.time())

        image.save(f"hotswap_{utc_seconds}.png")

        if i == 1:
            logger.info("Hotswapping lora one")
            flux_base_model.load_lora_weights(first_lora.model, adapter_name="1", hotswap=True)
        if i == 2:
            logger.info("Hotswapping lora two")
            flux_base_model.load_lora_weights(second_lora.model, adapter_name="2", hotswap=True)
            flux_base_model.load_lora_weights(first_lora.model, adapter_name="1", hotswap=True)
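
# minimal entry point so the repro runs standalone; the original harness invokes
# test_lora_hotswap() from a main() in testing.py, per the traceback below
if __name__ == "__main__":
    test_lora_hotswap()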

Logs

2025-04-12 04:47:18 | INFO     | Initialized base flux model
100%|██████████| 4/4 [00:01<00:00,  3.64it/s]
2025-04-12 04:47:21 | INFO     | Warmup: 2.4211995601654053
100%|██████████| 4/4 [00:01<00:00,  3.79it/s]
2025-04-12 04:47:23 | INFO     | Inference time: 1.2886595726013184
2025-04-12 04:47:23 | INFO     | Hotswapping lora one
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/team/replay/python/hosted/utils/testing.py", line 708, in <module>
    main()
  File "/home/team/replay/python/hosted/utils/testing.py", line 704, in main
    test_lora_hotswap()
  File "/home/team/replay/python/hosted/utils/testing.py", line 667, in test_lora_hotswap
    flux_base_model.load_lora_weights(first_lora.model, adapter_name="1", hotswap=True)
  File "/home/team/.local/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py", line 1808, in load_lora_weights
    self.load_lora_into_transformer(
  File "/home/team/.local/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py", line 1899, in load_lora_into_transformer
    transformer.load_lora_adapter(
  File "/home/team/.local/lib/python3.11/site-packages/diffusers/loaders/peft.py", line 371, in load_lora_adapter
    hotswap_adapter_from_state_dict(
  File "/home/team/.local/lib/python3.11/site-packages/peft/utils/hotswap.py", line 431, in hotswap_adapter_from_state_dict
    raise RuntimeError(msg)
RuntimeError: Hot swapping the adapter did not succeed, unexpected keys found: transformer_blocks.14.ff.net.0.proj.lora_B.weight, single_transformer_blocks.7.attn.to_v.lora_B.weight, ...

System Info

  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Linux-5.10.0-34-cloud-amd64-x86_64-with-glibc2.31
  • Running on Google Colab?: No
  • Python version: 3.11.11
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.28.1
  • Transformers version: 4.50.3
  • Accelerate version: 1.6.0
  • PEFT version: 0.15.0
  • Bitsandbytes version: 0.45.3
  • Safetensors version: 0.5.3
  • xFormers version: 0.0.29.post3
  • Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@sayakpaul @yiyixuxu
