
Commit f922530

Fix various bugs with LoRA Dreambooth and Dreambooth script (#3353)
* Improve checkpointing lora
* fix more
* Improve doc string
* Update src/diffusers/loaders.py
* make style
* Apply suggestions from code review
* Update src/diffusers/loaders.py
* Apply suggestions from code review
* Apply suggestions from code review
* better
* Fix all
* Fix multi-GPU dreambooth
* Apply suggestions from code review
  Co-authored-by: Pedro Cuenca <[email protected]>
* Fix all
* make style
* make style

---------

Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 58c6f9c commit f922530

File tree (3 files changed, +135 −53 lines):

* examples/dreambooth/train_dreambooth.py
* examples/dreambooth/train_dreambooth_lora.py
* src/diffusers/loaders.py

examples/dreambooth/train_dreambooth.py

Lines changed: 26 additions & 29 deletions
@@ -22,7 +22,6 @@
 import warnings
 from pathlib import Path

-import accelerate
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -733,36 +732,34 @@ def main(args):
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )

-    # `accelerate` 0.16.0 will have better support for customized saving
-    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
-        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
-        def save_model_hook(models, weights, output_dir):
-            for model in models:
-                sub_dir = "unet" if type(model) == type(unet) else "text_encoder"
-                model.save_pretrained(os.path.join(output_dir, sub_dir))
-
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
-
-        def load_model_hook(models, input_dir):
-            while len(models) > 0:
-                # pop models so that they are not loaded again
-                model = models.pop()
-
-                if type(model) == type(text_encoder):
-                    # load transformers style into model
-                    load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
-                    model.config = load_model.config
-                else:
-                    # load diffusers style into model
-                    load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
-                    model.register_to_config(**load_model.config)
+    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+    def save_model_hook(models, weights, output_dir):
+        for model in models:
+            sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
+            model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+            # make sure to pop weight so that corresponding model is not saved again
+            weights.pop()
+
+    def load_model_hook(models, input_dir):
+        while len(models) > 0:
+            # pop models so that they are not loaded again
+            model = models.pop()
+
+            if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                # load transformers style into model
+                load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
+                model.config = load_model.config
+            else:
+                # load diffusers style into model
+                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+                model.register_to_config(**load_model.config)

-                model.load_state_dict(load_model.state_dict())
-                del load_model
+            model.load_state_dict(load_model.state_dict())
+            del load_model

-        accelerator.register_save_state_pre_hook(save_model_hook)
-        accelerator.register_load_state_pre_hook(load_model_hook)
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)

     vae.requires_grad_(False)
     if not args.train_text_encoder:
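
For context, the hooks above rely on accelerate's `register_save_state_pre_hook` / `register_load_state_pre_hook` API, which lets `accelerator.save_state(...)` and `accelerator.load_state(...)` delegate model (de)serialization to custom callbacks; the explicit version guard is dropped because this support is assumed to be available (>= 0.16.0, per the removed comment). Below is a minimal, self-contained sketch of that API (not from this diff), using a toy model and placeholder paths:

import os
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(4, 4))

def save_model_hook(models, weights, output_dir):
    # called by `accelerator.save_state(output_dir)` with every prepared model
    os.makedirs(output_dir, exist_ok=True)
    for i, m in enumerate(models):
        torch.save(m.state_dict(), os.path.join(output_dir, f"model_{i}.pt"))
        weights.pop()  # pop the weights so accelerate does not also save them in its default format

def load_model_hook(models, input_dir):
    # called by `accelerator.load_state(input_dir)`; popping a model skips the default loading
    while len(models) > 0:
        m = models.pop()
        m.load_state_dict(torch.load(os.path.join(input_dir, "model_0.pt")))

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

accelerator.save_state("ckpt")  # model saving now goes through save_model_hook
accelerator.load_state("ckpt")  # ...and loading goes through load_model_hook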

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 68 additions & 12 deletions
@@ -834,7 +834,6 @@ def main(args):

     unet.set_attn_processor(unet_lora_attn_procs)
     unet_lora_layers = AttnProcsLayers(unet.attn_processors)
-    accelerator.register_for_checkpointing(unet_lora_layers)

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
@@ -853,9 +852,68 @@ def main(args):
         )
         temp_pipeline._modify_text_encoder(text_lora_attn_procs)
         text_encoder = temp_pipeline.text_encoder
-        accelerator.register_for_checkpointing(text_encoder_lora_layers)
         del temp_pipeline

+    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+    def save_model_hook(models, weights, output_dir):
+        # there are only two options here. Either there are just the unet attn processor layers
+        # or there are the unet and text encoder attn layers
+        unet_lora_layers_to_save = None
+        text_encoder_lora_layers_to_save = None
+
+        if args.train_text_encoder:
+            text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
+        unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()
+
+        for model in models:
+            state_dict = model.state_dict()
+
+            if (
+                text_encoder_lora_layers is not None
+                and text_encoder_keys is not None
+                and state_dict.keys() == text_encoder_keys
+            ):
+                # text encoder
+                text_encoder_lora_layers_to_save = state_dict
+            elif state_dict.keys() == unet_keys:
+                # unet
+                unet_lora_layers_to_save = state_dict
+
+            # make sure to pop weight so that corresponding model is not saved again
+            weights.pop()
+
+        LoraLoaderMixin.save_lora_weights(
+            output_dir,
+            unet_lora_layers=unet_lora_layers_to_save,
+            text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+        )
+
+    def load_model_hook(models, input_dir):
+        # Note we DON'T pass the unet and text encoder here on purpose
+        # so that we don't accidentally override the LoRA layers of
+        # unet_lora_layers and text_encoder_lora_layers which are stored in `models`
+        # with new torch.nn.Modules / weights. We simply use the pipeline class as
+        # an easy way to load the lora checkpoints
+        temp_pipeline = DiffusionPipeline.from_pretrained(
+            args.pretrained_model_name_or_path,
+            revision=args.revision,
+            torch_dtype=weight_dtype,
+        )
+        temp_pipeline.load_lora_weights(input_dir)
+
+        # load lora weights into models
+        models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
+        if len(models) > 1:
+            models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
+
+        # delete temporary pipeline and pop models
+        del temp_pipeline
+        for _ in range(len(models)):
+            models.pop()
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
@@ -1130,17 +1188,10 @@ def compute_text_embeddings(prompt):
                 progress_bar.update(1)
                 global_step += 1

-                if global_step % args.checkpointing_steps == 0:
-                    if accelerator.is_main_process:
+                if accelerator.is_main_process:
+                    if global_step % args.checkpointing_steps == 0:
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                        # We combine the text encoder and UNet LoRA parameters with a simple
-                        # custom logic. `accelerator.save_state()` won't know that. So,
-                        # use `LoraLoaderMixin.save_lora_weights()`.
-                        LoraLoaderMixin.save_lora_weights(
-                            save_directory=save_path,
-                            unet_lora_layers=unet_lora_layers,
-                            text_encoder_lora_layers=text_encoder_lora_layers,
-                        )
+                        accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")

             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -1217,8 +1268,12 @@ def compute_text_embeddings(prompt):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         unet = unet.to(torch.float32)
+        unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
+
         if text_encoder is not None:
             text_encoder = text_encoder.to(torch.float32)
+            text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
+
         LoraLoaderMixin.save_lora_weights(
             save_directory=args.output_dir,
             unet_lora_layers=unet_lora_layers,
@@ -1250,6 +1305,7 @@ def compute_text_embeddings(prompt):
         pipeline.load_lora_weights(args.output_dir)

         # run inference
+        images = []
         if args.validation_prompt and args.num_validation_images > 0:
             generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
             images = [
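
For reference, a rough inference sketch (not from this diff) of consuming LoRA weights written by `LoraLoaderMixin.save_lora_weights`, whether the final weights in `args.output_dir` or one of the `checkpoint-<step>` folders the new save hook is meant to produce via `accelerator.save_state(save_path)`. The model id, paths, and prompt are placeholders; resuming training itself goes through the script's existing `--resume_from_checkpoint` flag, which calls `accelerator.load_state` and therefore the `load_model_hook` above.

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# directory written by the script: args.output_dir, or e.g. "output/checkpoint-500"
pipe.load_lora_weights("lora-dreambooth-output")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("sks_dog.png")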

src/diffusers/loaders.py

Lines changed: 41 additions & 12 deletions
@@ -70,6 +70,9 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]):
         self.mapping = dict(enumerate(state_dict.keys()))
         self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

+        # .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder
+        self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
+
         # we add a hook to state_dict() and load_state_dict() so that the
         # naming fits with `unet.attn_processors`
         def map_to(module, state_dict, *args, **kwargs):
@@ -81,10 +84,19 @@ def map_to(module, state_dict, *args, **kwargs):

            return new_state_dict

+        def remap_key(key, state_dict):
+            for k in self.split_keys:
+                if k in key:
+                    return key.split(k)[0] + k
+
+            raise ValueError(
+                f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
+            )
+
        def map_from(module, state_dict, *args, **kwargs):
            all_keys = list(state_dict.keys())
            for key in all_keys:
-                replace_key = key.split(".processor")[0] + ".processor"
+                replace_key = remap_key(key, state_dict)
                new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
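
A toy illustration (not from this diff) of what the new `remap_key` helper changes: UNet LoRA keys are still collapsed at `.processor` as before, while the monkey-patched text-encoder keys are collapsed at their projection names. The key strings below are made-up shapes for demonstration only:

split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]

def remap_key(key):
    for k in split_keys:
        if k in key:
            return key.split(k)[0] + k
    raise ValueError(f"{key} has to contain one of {split_keys}.")

# unet-style key (previous behaviour, unchanged)
print(remap_key("down_blocks.0.attentions.0.processor.to_q_lora.up.weight"))
# -> down_blocks.0.attentions.0.processor

# text-encoder-style key (newly supported)
print(remap_key("text_model.encoder.layers.0.self_attn.k_proj.lora.down.weight"))
# -> text_model.encoder.layers.0.self_attn.k_proj
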
@@ -898,6 +910,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
             self._modify_text_encoder(attn_procs_text_encoder)

+            # save lora attn procs of text encoder so that it can be easily retrieved
+            self._text_encoder_lora_attn_procs = attn_procs_text_encoder
+
         # Otherwise, we're dealing with the old format. This means the `state_dict` should only
         # contain the module names of the `unet` as its keys WITHOUT any prefix.
         elif not all(
@@ -907,6 +922,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
             warnings.warn(warn_message)

+    @property
+    def text_encoder_lora_attn_procs(self):
+        if hasattr(self, "_text_encoder_lora_attn_procs"):
+            return self._text_encoder_lora_attn_procs
+        return
+
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
@@ -1110,7 +1131,7 @@ def _load_text_encoder_attn_procs(
     def save_lora_weights(
         self,
         save_directory: Union[str, os.PathLike],
-        unet_lora_layers: Dict[str, torch.nn.Module] = None,
+        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
         is_main_process: bool = True,
         weight_name: str = None,
@@ -1123,13 +1144,14 @@ def save_lora_weights(
         Arguments:
             save_directory (`str` or `os.PathLike`):
                 Directory to which to save. Will be created if it doesn't exist.
-            unet_lora_layers (`Dict[str, torch.nn.Module`]):
+            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the UNet. Specifying this helps to make the
-                serialization process easier and cleaner.
-            text_encoder_lora_layers (`Dict[str, torch.nn.Module`]):
+                serialization process easier and cleaner. Values can either be LoRA torch.nn.Module layers or plain
+                torch weights.
+            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder`. Since the `text_encoder` comes from
                 `transformers`, we cannot rejig it. That is why we have to explicitly pass the text encoder LoRA state
-                dict.
+                dict. Values can either be LoRA torch.nn.Module layers or plain torch weights.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful when in distributed training like
                 TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
@@ -1157,15 +1179,22 @@ def save_function(weights, filename):
         # Create a flat dictionary.
         state_dict = {}
         if unet_lora_layers is not None:
-            unet_lora_state_dict = {
-                f"{self.unet_name}.{module_name}": param
-                for module_name, param in unet_lora_layers.state_dict().items()
-            }
+            weights = (
+                unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
+            )
+
+            unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
             state_dict.update(unet_lora_state_dict)
+
         if text_encoder_lora_layers is not None:
+            weights = (
+                text_encoder_lora_layers.state_dict()
+                if isinstance(text_encoder_lora_layers, torch.nn.Module)
+                else text_encoder_lora_layers
+            )
+
             text_encoder_lora_state_dict = {
-                f"{self.text_encoder_name}.{module_name}": param
-                for module_name, param in text_encoder_lora_layers.state_dict().items()
+                f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
             }
             state_dict.update(text_encoder_lora_state_dict)
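
Since `save_lora_weights` now accepts either `AttnProcsLayers` modules or plain tensor dictionaries, the training hook can hand it raw `model.state_dict()` outputs directly. Below is a hedged sketch (not from this diff) of both call styles on a pipeline that inherits `LoraLoaderMixin`; the model id, output directory, and the LoRA setup loop (written in the spirit of the training script) are illustrative assumptions:

from diffusers import StableDiffusionPipeline
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
unet = pipe.unet

# attach fresh (untrained) LoRA attention processors so there is something to save
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
unet_lora_layers = AttnProcsLayers(unet.attn_processors)

# 1) previous behaviour: pass the nn.Module wrapper
pipe.save_lora_weights("lora-out", unet_lora_layers=unet_lora_layers)
# 2) new: pass a plain {name: tensor} dict, e.g. what the save hook collects via model.state_dict()
pipe.save_lora_weights("lora-out", unet_lora_layers=unet_lora_layers.state_dict())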
