
Commit c738f14

update scaling dict
add padding draft update
1 parent d3fbd7b commit c738f14


src/diffusers/loaders/unet.py

Lines changed: 92 additions & 8 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 from collections import defaultdict
+import collections
 from contextlib import nullcontext
 from pathlib import Path
 from typing import Callable, Dict, Union
@@ -56,6 +57,56 @@
 CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
 CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
 
+def pad_lora_weights(state_dict, target_rank):
+    """
+    Pad LoRA weights in a state dict to a target rank while preserving the original behavior.
+
+    Args:
+        state_dict (dict): The state dict containing LoRA weights
+        target_rank (int): The target rank to pad to
+
+    Returns:
+        new_state_dict: A new state dict with padded LoRA weights
+    """
+    new_state_dict = {}
+
+    for key, weight in state_dict.items():
+        if "lora_A" in key or "lora_B" in key:
+            is_conv = weight.dim() == 4
+
+            if "lora_A" in key:
+                original_rank = weight.size(0)
+                if original_rank >= target_rank:
+                    new_state_dict[key] = weight
+                    continue
+
+                if is_conv:
+                    padded = torch.zeros(target_rank, weight.size(1), weight.size(2), weight.size(3),
+                                         device=weight.device, dtype=weight.dtype)
+                    padded[:original_rank, :, :, :] = weight
+                else:
+                    padded = torch.zeros(target_rank, weight.size(1), device=weight.device, dtype=weight.dtype)
+                    padded[:original_rank, :] = weight
+
+            elif "lora_B" in key:
+                original_rank = weight.size(1)
+                if original_rank >= target_rank:
+                    new_state_dict[key] = weight
+                    continue
+
+                if is_conv:
+                    padded = torch.zeros(weight.size(0), target_rank, weight.size(2), weight.size(3),
+                                         device=weight.device, dtype=weight.dtype)
+                    padded[:, :original_rank, :, :] = weight
+                else:
+                    padded = torch.zeros(weight.size(0), target_rank, device=weight.device, dtype=weight.dtype)
+                    padded[:, :original_rank] = weight
+
+            new_state_dict[key] = padded
+        else:
+            new_state_dict[key] = weight
+
+    return new_state_dict
 
 class UNet2DConditionLoadersMixin:
     """
@@ -307,19 +358,32 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter
             )
         elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
             raise ValueError(f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name.")
-
+
+        def get_rank(state_dict):
+            rank = {}
+            for key, val in state_dict.items():
+                if "lora_B" in key:
+                    rank[key] = val.shape[1]
+            return rank
+
+        def get_r(rank_dict):
+            r = list(rank_dict.values())[0]
+            if len(set(rank_dict.values())) > 1:
+                # get the rank occurring the most number of times
+                r = collections.Counter(rank_dict.values()).most_common()[0][0]
+            return r
+
         state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)
+        r = get_r(get_rank(state_dict))
+
+        state_dict = pad_lora_weights(state_dict, 128)
 
         if network_alphas is not None:
             # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
             # `convert_unet_state_dict_to_peft` method.
             network_alphas = convert_unet_state_dict_to_peft(network_alphas)
 
-        rank = {}
-        for key, val in state_dict.items():
-            if "lora_B" in key:
-                rank[key] = val.shape[1]
-
+        rank = get_rank(state_dict)
         lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"]:
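The two nested helpers added here pick a single representative rank from the converted state dict: `get_rank` records the output-side dimension of every `lora_B` weight, and `get_r` takes the most frequent value when layers disagree. A small illustration of that fallback (not part of the commit; the keys and ranks are invented):

```python
import collections

# get_rank(...) would produce a mapping like this from a mixed-rank checkpoint.
rank_dict = {
    "down_blocks.0.lora_B.weight": 8,
    "mid_block.lora_B.weight": 4,
    "up_blocks.0.lora_B.weight": 8,
}

# Ranks differ across layers, so get_r falls back to the most common one.
assert len(set(rank_dict.values())) > 1
r = collections.Counter(rank_dict.values()).most_common()[0][0]
assert r == 8
```

Note that `r` is computed before the state dict is padded to rank 128, so it still reflects the checkpoint's original rank.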
@@ -348,7 +412,7 @@ def _check_hotswap_configs_compatible(config0, config1):
             # values as well, but that's not implemented yet, and it would trigger a re-compilation if the model is compiled.
 
             # TODO: This is a very rough check at the moment and there are probably better ways than to error out
-            config_keys_to_check = ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]
+            config_keys_to_check = ["use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]
             config0 = config0.to_dict()
             config1 = config1.to_dict()
             for key in config_keys_to_check:
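Removing `"lora_alpha"` from `config_keys_to_check` relaxes the hotswap compatibility check: two adapters that differ only in alpha no longer raise, since the alpha difference is absorbed by the scaling update introduced below. A sketch of what now passes (not part of the commit; assumes a recent `peft` release that exposes these `LoraConfig` fields):

```python
from peft import LoraConfig

config0 = LoraConfig(r=8, lora_alpha=8)
config1 = LoraConfig(r=8, lora_alpha=32)

# With "lora_alpha" dropped from the list, only these keys are compared,
# and the two configs above are considered hotswap-compatible.
for key in ["use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]:
    assert config0.to_dict()[key] == config1.to_dict()[key]
```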
@@ -357,6 +421,15 @@ def _check_hotswap_configs_compatible(config0, config1):
                 if val0 != val1:
                     raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}")
 
+        def _update_scaling(model, adapter_name, scaling_factor=None):
+            target_modules = model.peft_config[adapter_name].target_modules
+            for name, lora_module in model.named_modules():
+                if name in target_modules and hasattr(lora_module, "scaling"):
+                    if not isinstance(lora_module.scaling[adapter_name], torch.Tensor):
+                        lora_module.scaling[adapter_name] = torch.tensor(scaling_factor, device=lora_module.weight.device)
+                    else:
+                        lora_module.scaling[adapter_name].fill_(scaling_factor)
+
         def _hotswap_adapter_from_state_dict(model, state_dict, adapter_name):
             """
             Swap out the LoRA weights from the model with the weights from state_dict.
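`_update_scaling` rewrites the per-adapter scaling in place: the first call replaces the plain float that peft stores with a zero-dim tensor, and every later call only `fill_`s that tensor, so the object bound into the layer keeps its identity across hotswaps. A toy illustration of that behaviour (not part of the commit; `ToyLoraLayer` is an invented stand-in for a peft LoRA module):

```python
import torch

class ToyLoraLayer:
    def __init__(self):
        self.scaling = {"default": 0.5}      # peft initializes scaling as a plain float
        self.weight = torch.zeros(4, 4)

def update_scaling(layer, adapter_name, scaling_factor):
    # Same branch structure as _update_scaling above, for a single layer.
    if not isinstance(layer.scaling[adapter_name], torch.Tensor):
        layer.scaling[adapter_name] = torch.tensor(scaling_factor, device=layer.weight.device)
    else:
        layer.scaling[adapter_name].fill_(scaling_factor)

layer = ToyLoraLayer()
update_scaling(layer, "default", 32 / 128)   # float -> tensor
first = layer.scaling["default"]
update_scaling(layer, "default", 16 / 128)   # in-place fill_, no rebinding
assert layer.scaling["default"] is first     # same tensor object, value updated in place
assert torch.isclose(first, torch.tensor(0.125))
```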
@@ -430,18 +503,29 @@ def _hotswap_adapter_from_state_dict(model, state_dict, adapter_name):
             for key, new_val in state_dict.items():
                 # no need to account for potential _orig_mod in key here, as torch handles that
                 old_val = attrgetter(key)(model)
-                old_val.data = new_val.data.to(device=old_val.device)
+                # print(f" dtype: {old_val.data.dtype}/{new_val.data.dtype}, layout: {old_val.data.layout}/{new_val.data.layout}")
+                old_val.data.copy_(new_val.data.to(device=old_val.device))
                 # TODO: wanted to use swap_tensors but this somehow does not work on nn.Parameter
                 # torch.utils.swap_tensors(old_val.data, new_val.data)
 
         if hotswap:
             _check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
+            self.peft_config[adapter_name] = lora_config
+            # update r & scaling
+            self.peft_config[adapter_name].r = r
+            new_scaling_factor = self.peft_config[adapter_name].lora_alpha/self.peft_config[adapter_name].r
+            _update_scaling(self, adapter_name, new_scaling_factor)
+
             _hotswap_adapter_from_state_dict(self, state_dict, adapter_name)
             # the hotswap function raises if there are incompatible keys, so if we reach this point we can set it to None
             incompatible_keys = None
         else:
             inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
             incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
+            # update r & scaling
+            self.peft_config[adapter_name].r = r
+            new_scaling_factor = self.peft_config[adapter_name].lora_alpha/r
+            _update_scaling(self, adapter_name, new_scaling_factor)
 
         if incompatible_keys is not None:
             # check only for unexpected keys
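Taken together: the state dict is padded to a fixed rank of 128, presumably so that LoRAs of different ranks can be hotswapped into the same weight shapes, while the dominant original rank `r` is written back into `peft_config[adapter_name].r` and the scaling is recomputed as `lora_alpha / r`. That keeps the adapter's effective scale tied to the checkpoint's real rank rather than the padded one. A small worked example of the arithmetic (not part of the commit; the numbers are illustrative):

```python
lora_alpha = 32
original_r = 8      # dominant rank returned by get_r(get_rank(state_dict))
padded_rank = 128   # fixed shape produced by pad_lora_weights(state_dict, 128)

# What _update_scaling writes into each layer's scaling dict:
new_scaling_factor = lora_alpha / original_r
assert new_scaling_factor == 4.0

# What peft would use if r were naively taken from the padded shapes:
assert lora_alpha / padded_rank == 0.25
```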