Commit d3fbd7b

[WIP][LoRA] Implement hot-swapping of LoRA
This PR adds the possibility to hot-swap LoRA adapters. It is WIP.

Description

As of now, users can already load multiple LoRA adapters. They can offload existing adapters or unload them (i.e. delete them). However, they cannot yet "hot-swap" adapters, i.e. substitute the weights of one LoRA adapter with those of another without creating a separate adapter. On its own, hot-swapping may not appear especially useful, but when the model is compiled it is necessary to prevent recompilation. See #9279 for more context.

Caveats

To hot-swap one LoRA adapter for another, the two adapters should target exactly the same layers and their "hyper-parameters" should be identical. For instance, the LoRA alpha has to be the same: since we keep the alpha from the first adapter, the LoRA scaling would otherwise be incorrect for the second adapter. Theoretically, we could override the scaling dict with the alpha values derived from the second adapter's config, but changing the dict would trigger a guard for recompilation, defeating the main purpose of the feature.

I also found that compilation flags can have an impact on whether this works or not. E.g. when passing "reduce-overhead", there will be errors of the type:

> input name: arg861_1. data pointer changed from 139647332027392 to 139647331054592

I don't know enough about compilation to determine whether this is problematic or not.

Current state

This is obviously still WIP, intended to collect feedback and discuss which direction to take. If this PR turns out to be useful, the hot-swapping functions will be added to PEFT itself and can be imported here (or a separate copy can be kept in diffusers to avoid requiring a minimum PEFT version for this feature). Moreover, more tests need to be added to better cover this feature, although we don't necessarily need tests for the hot-swapping functionality itself, since those will be added to PEFT. Furthermore, as of now this is only implemented for the unet; other pipeline components have yet to implement this feature. Finally, it should be properly documented.

I would like to collect feedback on the current state of the PR before putting more time into finalizing it.
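For orientation, a rough usage sketch of what the feature aims to enable. The base model and LoRA repo IDs below are placeholders, only the `hotswap` argument comes from this PR, and the exact interaction with `torch.compile` is still being worked out:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("some/base-model", torch_dtype=torch.float16).to("cuda")
pipe.unet = torch.compile(pipe.unet)

# load the first adapter as usual; the first call compiles the unet
pipe.load_lora_weights("user/lora-one", adapter_name="default_0")
image1 = pipe("a prompt").images[0]

# hot-swap: replace the weights of the existing adapter in place instead of adding a new
# adapter, so the compiled unet should not need to recompile
pipe.load_lora_weights("user/lora-two", adapter_name="default_0", hotswap=True)
image2 = pipe("a prompt").images[0]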
1 parent 8fcfb2a commit d3fbd7b

File tree

3 files changed: +161 −8 lines

src/diffusers/loaders/lora_pipeline.py

Lines changed: 14 additions & 3 deletions
@@ -63,7 +63,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
     text_encoder_name = TEXT_ENCODER_NAME

     def load_lora_weights(
-        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name=None,
+        hotswap: bool = False,
+        **kwargs,
     ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
@@ -88,6 +92,7 @@ def load_lora_weights(
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
+           hotswap TODO
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
@@ -109,6 +114,7 @@ def load_lora_weights(
            unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
            adapter_name=adapter_name,
            _pipeline=self,
+           hotswap=hotswap,
        )
        self.load_lora_into_text_encoder(
            state_dict,
@@ -232,7 +238,7 @@ def lora_state_dict(
        return state_dict, network_alphas

    @classmethod
-   def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
+   def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, hotswap: bool = False):
        """
        This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -250,6 +256,7 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
+           hotswap TODO
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
@@ -263,7 +270,11 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None
        # Load the layers corresponding to UNet.
        logger.info(f"Loading {cls.unet_name}.")
        unet.load_attn_procs(
-           state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+           state_dict,
+           network_alphas=network_alphas,
+           adapter_name=adapter_name,
+           _pipeline=_pipeline,
+           hotswap=hotswap,
        )

    @classmethod

src/diffusers/loaders/unet.py

Lines changed: 109 additions & 5 deletions
@@ -66,7 +66,7 @@ class UNet2DConditionLoadersMixin:
    unet_name = UNET_NAME

    @validate_hf_hub_args
-   def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+   def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], hotswap: bool = False, **kwargs):
        r"""
        Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
        defined in
@@ -115,6 +115,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                `default_{i}` where i is the total number of adapters being loaded.
            weight_name (`str`, *optional*, defaults to None):
                Name of the serialized state dict file.
+           hotswap TODO

        Example:

@@ -209,6 +210,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                network_alphas=network_alphas,
                adapter_name=adapter_name,
                _pipeline=_pipeline,
+               hotswap=hotswap,
            )
        else:
            raise ValueError(
@@ -268,7 +270,7 @@ def _process_custom_diffusion(self, state_dict):

        return attn_processors

-   def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
+   def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, hotswap: bool = False):
        # This method does the following things:
        # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
        #    format. For legacy format no filtering is applied.
@@ -299,10 +301,12 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter
        state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict

        if len(state_dict_to_be_used) > 0:
-           if adapter_name in getattr(self, "peft_config", {}):
+           if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
                raise ValueError(
                    f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
                )
+           elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
+               raise ValueError(f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name.")

            state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)

@@ -336,8 +340,108 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter
            # otherwise loading LoRA weights will lead to an error
            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)

-           inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
-           incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
+
+           def _check_hotswap_configs_compatible(config0, config1):
+               # To hot-swap two adapters, their configs must be compatible. Otherwise, the results could be false. E.g. if they
+               # use different alpha values, after hot-swapping, the alphas from the first adapter would still be used with the
+               # weights from the 2nd adapter, which would result in incorrect behavior. There is probably a way to swap these
+               # values as well, but that's not implemented yet, and it would trigger a re-compilation if the model is compiled.
+
+               # TODO: This is a very rough check at the moment and there are probably better ways than to error out
+               config_keys_to_check = ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]
+               config0 = config0.to_dict()
+               config1 = config1.to_dict()
+               for key in config_keys_to_check:
+                   val0 = config0[key]
+                   val1 = config1[key]
+                   if val0 != val1:
+                       raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}")
+
+           def _hotswap_adapter_from_state_dict(model, state_dict, adapter_name):
+               """
+               Swap out the LoRA weights from the model with the weights from state_dict.
+
+               It is assumed that the existing adapter and the new adapter are compatible.
+
+               Args:
+                   model: nn.Module
+                       The model with the loaded adapter.
+                   state_dict: dict[str, torch.Tensor]
+                       The state dict of the new adapter, which needs to be compatible (targeting same modules etc.).
+                   adapter_name: Optional[str]
+                       The name of the adapter that should be hot-swapped.
+
+               Raises:
+                   RuntimeError
+                       If the old and the new adapter are not compatible, a RuntimeError is raised.
+               """
+               from operator import attrgetter
+
+               #######################
+               # INSERT ADAPTER NAME #
+               #######################
+
+               remapped_state_dict = {}
+               expected_str = adapter_name + "."
+               for key, val in state_dict.items():
+                   if expected_str not in key:
+                       prefix, _, suffix = key.rpartition(".")
+                       key = f"{prefix}.{adapter_name}.{suffix}"
+                   remapped_state_dict[key] = val
+               state_dict = remapped_state_dict
+
+               ####################
+               # CHECK STATE_DICT #
+               ####################
+
+               # Ensure that all the keys of the new adapter correspond exactly to the keys of the old adapter, otherwise
+               # hot-swapping is not possible
+               parameter_prefix = "lora_"  # hard-coded for now
+               is_compiled = hasattr(model, "_orig_mod")
+               # TODO: there is probably a more precise way to identify the adapter keys
+               missing_keys = {k for k in model.state_dict() if (parameter_prefix in k) and (adapter_name in k)}
+               unexpected_keys = set()
+
+               # first: dry run, not swapping anything
+               for key, new_val in state_dict.items():
+                   try:
+                       old_val = attrgetter(key)(model)
+                   except AttributeError:
+                       unexpected_keys.add(key)
+                       continue
+
+                   if is_compiled:
+                       missing_keys.remove("_orig_mod." + key)
+                   else:
+                       missing_keys.remove(key)
+
+               if missing_keys or unexpected_keys:
+                   msg = "Hot swapping the adapter did not succeed."
+                   if missing_keys:
+                       msg += f" Missing keys: {', '.join(sorted(missing_keys))}."
+                   if unexpected_keys:
+                       msg += f" Unexpected keys: {', '.join(sorted(unexpected_keys))}."
+                   raise RuntimeError(msg)
+
+               ###################
+               # ACTUAL SWAPPING #
+               ###################
+
+               for key, new_val in state_dict.items():
+                   # no need to account for potential _orig_mod in key here, as torch handles that
+                   old_val = attrgetter(key)(model)
+                   old_val.data = new_val.data.to(device=old_val.device)
+                   # TODO: wanted to use swap_tensors but this somehow does not work on nn.Parameter
+                   # torch.utils.swap_tensors(old_val.data, new_val.data)
+
+           if hotswap:
+               _check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
+               _hotswap_adapter_from_state_dict(self, state_dict, adapter_name)
+               # the hotswap function raises if there are incompatible keys, so if we reach this point we can set it to None
+               incompatible_keys = None
+           else:
+               inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
+               incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)

            if incompatible_keys is not None:
                # check only for unexpected keys
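As an aside, the "insert adapter name" step above remaps plain checkpoint keys onto the parameter names that PEFT registers for the already-loaded adapter. A small standalone sketch of that remapping; the key names below are made up for illustration:

adapter_name = "default_0"
state_dict = {
    "down_blocks.0.attentions.0.to_q.lora_A.weight": ...,
    "down_blocks.0.attentions.0.to_q.lora_B.weight": ...,
}

remapped_state_dict = {}
for key, val in state_dict.items():
    if adapter_name + "." not in key:
        # "....lora_A.weight" -> "....lora_A.default_0.weight"
        prefix, _, suffix = key.rpartition(".")
        key = f"{prefix}.{adapter_name}.{suffix}"
    remapped_state_dict[key] = val

# remapped_state_dict now has keys such as
#   down_blocks.0.attentions.0.to_q.lora_A.default_0.weight
# which can be matched 1:1 against the existing adapter's parameters before swapping .data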

tests/pipelines/test_pipelines.py

Lines changed: 38 additions & 0 deletions
@@ -18,6 +18,7 @@
 import os
 import random
 import shutil
+import subprocess
 import sys
 import tempfile
 import traceback
@@ -2014,3 +2015,40 @@ def test_ddpm_ddim_equality_batched(self):

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_images - ddim_images).max() < 1e-1
+
+
+class TestLoraHotSwapping:
+    def test_hotswapping_peft_config_incompatible_raises(self):
+        # TODO
+        pass
+
+    def test_hotswapping_no_existing_adapter_raises(self):
+        # TODO
+        pass
+
+    def test_hotswapping_works(self):
+        # TODO
+        pass
+
+    def test_hotswapping_compiled_model_does_not_trigger_recompilation(self):
+        # TODO: kinda slow, should it get a slow marker?
+        env = {"TORCH_LOGS": "guards,recompiles"}
+        here = os.path.dirname(__file__)
+        file_name = os.path.join(here, "run_compiled_model_hotswap.py")
+
+        process = subprocess.Popen(
+            [sys.executable, file_name],
+            env=env,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE
+        )
+
+        # Communicate will read the output and error streams, preventing deadlock
+        stdout, stderr = process.communicate()
+        exit_code = process.returncode
+
+        # sanity check:
+        assert exit_code == 0
+
+        # check that the recompilation message is not present
+        assert "__recompiles" not in stderr.decode()
