
Commit 50a28f1

sajadn (Sajad Norouzi) authored and Jimmy committed
Fix mixed precision fine-tuning for text-to-image-lora-sdxl example. (huggingface#6751)
* Fix mixed precision fine-tuning for text-to-image-lora-sdxl example.
* Fix text_encoder_two bug.

---------

Co-authored-by: Sajad Norouzi <[email protected]>
1 parent 5d849e7 commit 50a28f1
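
For context on what the fix amounts to in practice: under fp16 mixed precision the frozen SDXL weights sit in `weight_dtype` (fp16), while the trainable LoRA parameters have to be upcast to float32 before the optimizer is built. The snippet below is a minimal sketch of that pattern using toy modules as stand-ins for the UNet and its LoRA layers; only cast_training_params comes from the training script, everything else is illustrative.

import torch

from diffusers.training_utils import cast_training_params

# Toy stand-ins: a frozen "base" layer plus a trainable "adapter" layer, mirroring
# the frozen SDXL UNet plus its LoRA layers in the training script.
model = torch.nn.ModuleDict(
    {"base": torch.nn.Linear(16, 16), "adapter": torch.nn.Linear(16, 16, bias=False)}
)
model["base"].requires_grad_(False)

model.to(dtype=torch.float16)                       # everything starts out in weight_dtype (fp16)
cast_training_params([model], dtype=torch.float32)  # upcasts only the trainable adapter params

# The optimizer then only ever sees fp32 trainable parameters.
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)

print(model["base"].weight.dtype, model["adapter"].weight.dtype)  # torch.float16 torch.float32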

File tree

1 file changed (+41 lines, -21 lines)


examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 41 additions & 21 deletions
@@ -35,7 +35,7 @@
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from torchvision import transforms
 from torchvision.transforms.functional import crop
@@ -51,8 +51,13 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import cast_training_params, compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module

@@ -629,14 +634,6 @@ def main(args):
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)

-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(models, dtype=torch.float32)
-
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
@@ -693,18 +690,34 @@ def load_model_hook(models, input_dir):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")

-        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+        lora_state_dict, _ = LoraLoaderMixin.lora_state_dict(input_dir)
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )

-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
+        if args.train_text_encoder:
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)

-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            cast_training_params(models, dtype=torch.float32)

     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -725,6 +738,13 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        cast_training_params(models, dtype=torch.float32)
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
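
The float32 cast is re-added here, after learning-rate scaling and right before the optimizer is created, so AdamW (or 8-bit Adam) is constructed over fp32 LoRA parameters while the frozen base weights stay in weight_dtype. As a rough mental model, the helper behaves approximately like the loop below (an approximation for illustration, not the verbatim diffusers.training_utils implementation).

import torch


def cast_training_params_sketch(models, dtype=torch.float32):
    """Approximate behaviour of diffusers.training_utils.cast_training_params."""
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            # Upcast only the trainable (requires_grad) parameters; frozen base
            # weights keep whatever half-precision dtype they were cast to.
            if param.requires_grad:
                param.data = param.data.to(dtype)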
