from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from torchvision.transforms.functional import crop
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
-from diffusers.training_utils import cast_training_params, compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -629,14 +634,6 @@ def main(args):
        text_encoder_one.add_adapter(text_lora_config)
        text_encoder_two.add_adapter(text_lora_config)

-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(models, dtype=torch.float32)
-
    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
@@ -693,18 +690,34 @@ def load_model_hook(models, input_dir):
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

-        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+        lora_state_dict, _ = LoraLoaderMixin.lora_state_dict(input_dir)
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )

-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
+        if args.train_text_encoder:
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)

-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            cast_training_params(models, dtype=torch.float32)

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
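The hunk above replaces the old `LoraLoaderMixin.load_lora_into_unet` path with an explicit PEFT round trip: `lora_state_dict()` returns one flat diffusers-format dict, the `"unet."`-prefixed entries are filtered out, renamed with `convert_unet_state_dict_to_peft`, and injected with `set_peft_model_state_dict`, warning on any unexpected keys. Below is a minimal, self-contained sketch of that filter-and-inject pattern on a toy PEFT model; the module, prefix handling, and tensors are illustrative only, and the real hook additionally converts diffusers-style key names before injecting.

```python
# Sketch only: demonstrates prefix filtering + set_peft_model_state_dict on a toy model.
import torch
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict

base = torch.nn.Sequential(torch.nn.Linear(8, 8))
peft_model = get_peft_model(base, LoraConfig(r=4, target_modules=["0"]))

# Pretend this is the combined dict from LoraLoaderMixin.lora_state_dict(),
# where UNet entries carry a "unet." prefix next to text-encoder entries.
combined = {f"unet.{k}": v for k, v in get_peft_model_state_dict(peft_model).items()}

# Keep only the UNet entries and strip the prefix, as in the dict comprehension above.
unet_state_dict = {k.replace("unet.", ""): v for k, v in combined.items() if k.startswith("unet.")}

# Inject into the "default" adapter and surface any keys the model did not expect.
incompatible_keys = set_peft_model_state_dict(peft_model, unet_state_dict, adapter_name="default")
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
    print(f"unexpected keys: {unexpected_keys}")
```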
@@ -725,6 +738,13 @@ def load_model_hook(models, input_dir):
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        cast_training_params(models, dtype=torch.float32)
+
    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
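Both new fp16 branches lean on `cast_training_params`: after the adapters are attached (or re-loaded in `weight_dtype` by the hook above), only the parameters that still require grad, i.e. the LoRA weights, are upcast to float32 so the optimizer works on fp32 copies while the frozen base weights stay in fp16. A rough, illustrative equivalent of that behavior (not the diffusers implementation itself):

```python
import torch

def cast_trainable_params_to_fp32(models):
    # Upcast only trainable parameters (the LoRA weights);
    # frozen fp16 base weights are left untouched.
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.data.to(torch.float32)
```

Calling this on `[unet]` (plus the two text encoders when `--train_text_encoder` is set) mirrors the `cast_training_params(models, dtype=torch.float32)` calls added in the diff.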