# limitations under the License.
import os
from collections import defaultdict
+ import collections
from contextlib import nullcontext
from pathlib import Path
from typing import Callable, Dict, Union

CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"

+ def pad_lora_weights(state_dict, target_rank):
+     """
+     Pad LoRA weights in a state dict to a target rank while preserving the original behavior.
+
+     Args:
+         state_dict (dict): The state dict containing LoRA weights.
+         target_rank (int): The target rank to pad to.
+
+     Returns:
+         new_state_dict: A new state dict with padded LoRA weights.
+     """
+     new_state_dict = {}
+
+     for key, weight in state_dict.items():
+         if "lora_A" in key or "lora_B" in key:
+             is_conv = weight.dim() == 4
+
+             if "lora_A" in key:
+                 # lora_A carries the rank in its first dimension
+                 original_rank = weight.size(0)
+                 if original_rank >= target_rank:
+                     new_state_dict[key] = weight
+                     continue
+
+                 if is_conv:
+                     padded = torch.zeros(
+                         target_rank, weight.size(1), weight.size(2), weight.size(3),
+                         device=weight.device, dtype=weight.dtype,
+                     )
+                     padded[:original_rank, :, :, :] = weight
+                 else:
+                     padded = torch.zeros(target_rank, weight.size(1), device=weight.device, dtype=weight.dtype)
+                     padded[:original_rank, :] = weight
+
+             elif "lora_B" in key:
+                 # lora_B carries the rank in its second dimension
+                 original_rank = weight.size(1)
+                 if original_rank >= target_rank:
+                     new_state_dict[key] = weight
+                     continue
+
+                 if is_conv:
+                     padded = torch.zeros(
+                         weight.size(0), target_rank, weight.size(2), weight.size(3),
+                         device=weight.device, dtype=weight.dtype,
+                     )
+                     padded[:, :original_rank, :, :] = weight
+                 else:
+                     padded = torch.zeros(weight.size(0), target_rank, device=weight.device, dtype=weight.dtype)
+                     padded[:, :original_rank] = weight
+
+             new_state_dict[key] = padded
+         else:
+             new_state_dict[key] = weight
+
+     return new_state_dict


class UNet2DConditionLoadersMixin:
    """
@@ -307,19 +358,32 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter
            )
        elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
            raise ValueError(f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name.")
-
+
+         def get_rank(state_dict):
+             rank = {}
+             for key, val in state_dict.items():
+                 if "lora_B" in key:
+                     rank[key] = val.shape[1]
+             return rank
+
+         def get_r(rank_dict):
+             r = list(rank_dict.values())[0]
+             if len(set(rank_dict.values())) > 1:
+                 # get the rank occurring the most number of times
+                 r = collections.Counter(rank_dict.values()).most_common()[0][0]
+             return r
+
        state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)
+         r = get_r(get_rank(state_dict))
+
+         state_dict = pad_lora_weights(state_dict, 128)

        if network_alphas is not None:
            # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
            # `convert_unet_state_dict_to_peft` method.
            network_alphas = convert_unet_state_dict_to_peft(network_alphas)

-         rank = {}
-         for key, val in state_dict.items():
-             if "lora_B" in key:
-                 rank[key] = val.shape[1]
-
+         rank = get_rank(state_dict)
        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
        if "use_dora" in lora_config_kwargs:
            if lora_config_kwargs["use_dora"]:
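
For reference, a small illustration (with made-up keys and ranks) of what `get_rank` / `get_r` compute: when the per-layer ranks disagree, `get_r` falls back to the most frequent rank via `collections.Counter`.

import collections

ranks = {"down.lora_B.weight": 4, "mid.lora_B.weight": 4, "up.lora_B.weight": 8}
# most_common()[0][0] is the most frequent value; ties resolve in first-encountered order
r = collections.Counter(ranks.values()).most_common()[0][0]
assert r == 4
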
@@ -348,7 +412,7 @@ def _check_hotswap_configs_compatible(config0, config1):
            # values as well, but that's not implemented yet, and it would trigger a re-compilation if the model is compiled.

            # TODO: This is a very rough check at the moment and there are probably better ways than to error out
-             config_keys_to_check = ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]
+             config_keys_to_check = ["use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]
            config0 = config0.to_dict()
            config1 = config1.to_dict()
            for key in config_keys_to_check:
@@ -357,6 +421,15 @@ def _check_hotswap_configs_compatible(config0, config1):
                if val0 != val1:
                    raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}")

+         def _update_scaling(model, adapter_name, scaling_factor=None):
+             target_modules = model.peft_config[adapter_name].target_modules
+             for name, lora_module in model.named_modules():
+                 if name in target_modules and hasattr(lora_module, "scaling"):
+                     if not isinstance(lora_module.scaling[adapter_name], torch.Tensor):
+                         lora_module.scaling[adapter_name] = torch.tensor(scaling_factor, device=lora_module.weight.device)
+                     else:
+                         lora_module.scaling[adapter_name].fill_(scaling_factor)
+
        def _hotswap_adapter_from_state_dict(model, state_dict, adapter_name):
            """
            Swap out the LoRA weights from the model with the weights from state_dict.
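
A numeric sketch (values are illustrative, not from the patch) of the bookkeeping that `_update_scaling` and the `r` handling rely on: the tensors are allocated at the padded rank, but `scaling = lora_alpha / r` keeps using the original rank, so the effective LoRA update is unchanged.

import torch

lora_alpha, r_orig, r_pad = 8, 4, 128
A = torch.randn(r_orig, 320)                 # lora_A: (rank, in_features)
B = torch.randn(640, r_orig)                 # lora_B: (out_features, rank)
A_pad = torch.zeros(r_pad, 320)
A_pad[:r_orig] = A
B_pad = torch.zeros(640, r_pad)
B_pad[:, :r_orig] = B

scaling = lora_alpha / r_orig                # 2.0, tied to the original rank, not 128
delta_orig = scaling * (B @ A)
delta_padded = scaling * (B_pad @ A_pad)
assert torch.allclose(delta_orig, delta_padded)
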
@@ -430,18 +503,29 @@ def _hotswap_adapter_from_state_dict(model, state_dict, adapter_name):
            for key, new_val in state_dict.items():
                # no need to account for potential _orig_mod in key here, as torch handles that
                old_val = attrgetter(key)(model)
-                 old_val.data = new_val.data.to(device=old_val.device)
+                 # print(f" dtype: {old_val.data.dtype}/{new_val.data.dtype}, layout: {old_val.data.layout}/{new_val.data.layout}")
+                 old_val.data.copy_(new_val.data.to(device=old_val.device))
                # TODO: wanted to use swap_tensors but this somehow does not work on nn.Parameter
                # torch.utils.swap_tensors(old_val.data, new_val.data)

        if hotswap:
            _check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
+             self.peft_config[adapter_name] = lora_config
+             # update r & scaling
+             self.peft_config[adapter_name].r = r
+             new_scaling_factor = self.peft_config[adapter_name].lora_alpha / self.peft_config[adapter_name].r
+             _update_scaling(self, adapter_name, new_scaling_factor)
+
            _hotswap_adapter_from_state_dict(self, state_dict, adapter_name)
            # the hotswap function raises if there are incompatible keys, so if we reach this point we can set it to None
            incompatible_keys = None
        else:
            inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
            incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
+             # update r & scaling
+             self.peft_config[adapter_name].r = r
+             new_scaling_factor = self.peft_config[adapter_name].lora_alpha / r
+             _update_scaling(self, adapter_name, new_scaling_factor)

        if incompatible_keys is not None:
            # check only for unexpected keys
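
Finally, a brief standalone illustration (not part of the patch) of why `_hotswap_adapter_from_state_dict` above switches from rebinding `.data` to an in-place `copy_`: `copy_` writes into the existing parameter storage and keeps its dtype and tensor identity (which matters when the model is compiled, per the TODO comments above), whereas reassigning `.data` rebinds the parameter to a different tensor.

import torch

param = torch.nn.Parameter(torch.zeros(2, 2, dtype=torch.float16))
new_weights = torch.ones(2, 2, dtype=torch.float32)

ptr_before = param.data.data_ptr()
param.data.copy_(new_weights)      # in-place: same storage, dtype stays float16
assert param.data.data_ptr() == ptr_before
assert param.dtype == torch.float16

param.data = new_weights           # rebind: new storage, and the new dtype leaks in
assert param.data.data_ptr() != ptr_before
assert param.dtype == torch.float32
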