# limitations under the License.

import re
+from typing import List

import torch

-from ..utils import is_peft_version, logging
+from ..utils import is_peft_version, logging, state_dict_all_zero


logger = logging.get_logger(__name__)


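+# Reorders a (shift, scale) modulation weight into diffusers' (scale, shift) layout along dim 0;
+# used below when mapping `final_layer.adaLN_modulation` weights onto `norm_out.linear`.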
+def swap_scale_shift(weight):
+    shift, scale = weight.chunk(2, dim=0)
+    new_weight = torch.cat([scale, shift], dim=0)
+    return new_weight
+
+
def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
    # 1. get all state_dict_keys
    all_keys = list(state_dict.keys())
@@ -313,6 +320,7 @@ def _convert_text_encoder_lora_key(key, lora_name):
    # Be aware that this is the new diffusers convention and the rest of the code might
    # not utilize it yet.
    diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+
    return diffusers_name


@@ -331,8 +339,7 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):


# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
-# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
-# All credits go to `kohya-ss`.
+# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
def _convert_kohya_flux_lora_to_diffusers(state_dict):
    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
        if sds_key + ".lora_down.weight" not in sds_sd:
@@ -341,7 +348,8 @@ def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):

        # scale weight by alpha and dim
        rank = down_weight.shape[0]
-        alpha = sds_sd.pop(sds_key + ".alpha").item()  # alpha is scalar
+        default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()  # alpha is scalar
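+        # If the checkpoint stores no alpha for this key, alpha falls back to the rank, so the scale below is 1.0 (a no-op).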
        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here

        # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
@@ -362,7 +370,10 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        sd_lora_rank = down_weight.shape[0]

        # scale weight by alpha and dim
-        alpha = sds_sd.pop(sds_key + ".alpha")
+        default_alpha = torch.tensor(
+            sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
+        )
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
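+        # Same fallback as above: a missing alpha defaults to the rank, which makes the scale 1.0.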
        scale = alpha / sd_lora_rank

        # calculate scale_down and scale_up
@@ -516,10 +527,103 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
                f"transformer.single_transformer_blocks.{i}.norm.linear",
            )

+        # TODO: alphas.
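+        # Maps remaining sd-scripts-style keys in `source` onto diffusers-style keys in `ait_sd`
+        # for both LoRA matrices, applying `transform` (e.g. `swap_scale_shift`) when provided.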
+        def assign_remaining_weights(assignments, source):
+            for lora_key in ["lora_A", "lora_B"]:
+                orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+                for target_fmt, source_fmt, transform in assignments:
+                    target_key = target_fmt.format(lora_key=lora_key)
+                    source_key = source_fmt.format(orig_lora_key=orig_lora_key)
+                    value = source.pop(source_key)
+                    if transform:
+                        value = transform(value)
+                    ait_sd[target_key] = value
+
+        if any("guidance_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    (
+                        "time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
+                        "lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                    (
+                        "time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
+                        "lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                ],
+                sds_sd,
+            )
+
+        if any("img_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    ("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
+                ],
+                sds_sd,
+            )
+
+        if any("txt_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    ("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
+                ],
+                sds_sd,
+            )
+
+        if any("time_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    (
+                        "time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
+                        "lora_unet_time_in_in_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                    (
+                        "time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
+                        "lora_unet_time_in_out_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                ],
+                sds_sd,
+            )
+
+        if any("vector_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    (
+                        "time_text_embed.text_embedder.linear_1.{lora_key}.weight",
+                        "lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                    (
+                        "time_text_embed.text_embedder.linear_2.{lora_key}.weight",
+                        "lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                ],
+                sds_sd,
+            )
+
+        if any("final_layer" in k for k in sds_sd):
+            # Notice the swap in processing for "final_layer".
+            assign_remaining_weights(
+                [
+                    (
+                        "norm_out.linear.{lora_key}.weight",
+                        "lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight",
+                        swap_scale_shift,
+                    ),
+                    ("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None),
+                ],
+                sds_sd,
+            )
+
        remaining_keys = list(sds_sd.keys())
        te_state_dict = {}
        if remaining_keys:
-            if not all(k.startswith("lora_te") for k in remaining_keys):
+            if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
                raise ValueError(f"Incompatible keys detected: \n\n{', '.join(remaining_keys)}")
            for key in remaining_keys:
                if not key.endswith("lora_down.weight"):
@@ -680,10 +784,98 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
    if has_peft_state_dict:
        state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
        return state_dict
+
    # Another weird one.
    has_mixture = any(
        k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
    )
+
+    # ComfyUI.
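+    # ComfyUI checkpoints prefix transformer keys with "diffusion_model." and CLIP-L keys with
+    # "text_encoders.clip_l.transformer."; remap them to the kohya-style prefixes handled below.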
+    if not has_mixture:
+        state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
+        state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
+
+        has_position_embedding = any("position_embedding" in k for k in state_dict)
+        if has_position_embedding:
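+            # `state_dict_all_zero` checks whether the params matching the given pattern are all zeros;
+            # all-zero LoRA params are ineffective and get purged below.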
+            zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
+            if zero_status_pe:
+                logger.info(
+                    "The `position_embedding` LoRA params are all zeros, which makes them ineffective. "
+                    "So, we will purge them out of the current state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "The state_dict has position_embedding LoRA params and we currently do not support them. "
+                    "Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}
+
+        has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
+        if has_t5xxl:
+            zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
+            if zero_status_t5:
+                logger.info(
+                    "The `t5xxl` LoRA params are all zeros, which makes them ineffective. "
+                    "So, we will purge them out of the current state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out. "
+                    "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
+
+        has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
+        if has_diffb:
+            zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
+            if zero_status_diff_b:
+                logger.info(
+                    "The `diff_b` LoRA params are all zeros, which makes them ineffective. "
+                    "So, we will purge them out of the current state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "`diff_b` keys found in the state dict, which are currently unsupported. "
+                    "So, we will filter out those keys. Open an issue if this is a problem - "
+                    "https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if ".diff_b" not in k}
+
+        has_norm_diff = any(".norm" in k and ".diff" in k for k in state_dict)
+        if has_norm_diff:
+            zero_status_diff = state_dict_all_zero(state_dict, ".diff")
+            if zero_status_diff:
+                logger.info(
+                    "The `diff` LoRA params are all zeros, which makes them ineffective. "
+                    "So, we will purge them out of the current state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "Normalization diff keys found in the state dict, which are currently unsupported. "
+                    "So, we will filter out those keys. Open an issue if this is a problem - "
+                    "https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if ".norm" not in k and ".diff" not in k}
+
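+        # Normalize the remaining keys: turn the "."s into "_"s up to the lora_down/lora_up/alpha part
+        # (see `_custom_replace` below), and keep only unet/text-encoder LoRA keys.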
+        limit_substrings = ["lora_down", "lora_up"]
+        if any("alpha" in k for k in state_dict):
+            limit_substrings.append("alpha")
+
+        state_dict = {
+            _custom_replace(k, limit_substrings): v
+            for k, v in state_dict.items()
+            if k.startswith(("lora_unet_", "lora_te_"))
+        }
+
+        if any("text_projection" in k for k in state_dict):
+            logger.info(
+                "`text_projection` keys found in the `state_dict`, which are unexpected. "
+                "So, we will filter out those keys. Open an issue if this is a problem - "
+                "https://github.com/huggingface/diffusers/issues/new."
+            )
+            state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}
+
    if has_mixture:
        return _convert_mixture_state_dict_to_diffusers(state_dict)

@@ -798,6 +990,26 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
    return new_state_dict


+def _custom_replace(key: str, substrings: List[str]) -> str:
+    # Replaces the "."s with "_"s up to the first occurrence of any of the `substrings`.
+    # Example:
+    # lora_unet.foo.bar.lora_A.weight -> lora_unet_foo_bar.lora_A.weight
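+    # lora_unet.foo.bar.alpha -> lora_unet_foo_bar.alpha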
+    pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
+
+    match = re.search(pattern, key)
+    if match:
+        start_sub = match.start()
+        if start_sub > 0 and key[start_sub - 1] == ".":
+            boundary = start_sub - 1
+        else:
+            boundary = start_sub
+        left = key[:boundary].replace(".", "_")
+        right = key[boundary:]
+        return left + right
+    else:
+        return key.replace(".", "_")
+
+
def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
    converted_state_dict = {}
    original_state_dict_keys = list(original_state_dict.keys())
@@ -806,11 +1018,6 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
    inner_dim = 3072
    mlp_ratio = 4.0

-    def swap_scale_shift(weight):
-        shift, scale = weight.chunk(2, dim=0)
-        new_weight = torch.cat([scale, shift], dim=0)
-        return new_weight
-
    for lora_key in ["lora_A", "lora_B"]:
        ## time_text_embed.timestep_embedder <- time_in
        converted_state_dict[