Commit 09c5c81

Fix various bugs with LoRA Dreambooth and Dreambooth script (huggingface#3353)
* Improve checkpointing lora
* fix more
* Improve doc string
* Update src/diffusers/loaders.py
* make stytle
* Apply suggestions from code review
* Update src/diffusers/loaders.py
* Apply suggestions from code review
* Apply suggestions from code review
* better
* Fix all
* Fix multi-GPU dreambooth
* Apply suggestions from code review

  Co-authored-by: Pedro Cuenca <[email protected]>
* Fix all
* make style

---------

Co-authored-by: Pedro Cuenca <[email protected]>
1 parent (a151a6a), commit 09c5c81

File tree: 1 file changed, +41 / -12 lines

loaders.py (41 additions & 12 deletions)
@@ -70,6 +70,9 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]):
         self.mapping = dict(enumerate(state_dict.keys()))
         self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

+        # .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder
+        self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
+
         # we add a hook to state_dict() and load_state_dict() so that the
         # naming fits with `unet.attn_processors`
         def map_to(module, state_dict, *args, **kwargs):
@@ -81,10 +84,19 @@ def map_to(module, state_dict, *args, **kwargs):

             return new_state_dict

+        def remap_key(key, state_dict):
+            for k in self.split_keys:
+                if k in key:
+                    return key.split(k)[0] + k
+
+            raise ValueError(
+                f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
+            )
+
         def map_from(module, state_dict, *args, **kwargs):
             all_keys = list(state_dict.keys())
             for key in all_keys:
-                replace_key = key.split(".processor")[0] + ".processor"
+                replace_key = remap_key(key, state_dict)
                 new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                 state_dict[new_key] = state_dict[key]
                 del state_dict[key]
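
The two hunks above generalize the key remapping in the state-dict hooks: previously every key was assumed to contain ".processor", which holds for UNet attention processors but not for the text encoder LoRA keys that instead contain ".k_proj", ".q_proj", ".v_proj" or ".out_proj". Below is a minimal, self-contained sketch of that logic; the example key strings are made up for illustration and are not taken from the actual state dicts.

# Standalone sketch of the remapping added above (not the actual diffusers class).
split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]

def remap_key(key, state_dict):
    # Truncate the key right after the first matching marker, so a flat
    # state-dict key maps back to the name of its processor / projection.
    for k in split_keys:
        if k in key:
            return key.split(k)[0] + k
    raise ValueError(
        f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. "
        f"{key} has to have one of {split_keys}."
    )

# UNet-style key (the old ".processor"-only split already handled this):
print(remap_key("mid_block.attentions.0.processor.to_k_lora.down.weight", {}))
# -> mid_block.attentions.0.processor

# Text-encoder-style key (needs the new markers, since ".processor" is absent):
print(remap_key("text_model.encoder.layers.0.self_attn.q_proj.lora_linear_layer.up.weight", {}))
# -> text_model.encoder.layers.0.self_attn.q_proj
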
@@ -898,6 +910,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
                 self._modify_text_encoder(attn_procs_text_encoder)

+                # save lora attn procs of text encoder so that it can be easily retrieved
+                self._text_encoder_lora_attn_procs = attn_procs_text_encoder
+
         # Otherwise, we're dealing with the old format. This means the `state_dict` should only
         # contain the module names of the `unet` as its keys WITHOUT any prefix.
         elif not all(
@@ -907,6 +922,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
            warnings.warn(warn_message)

+    @property
+    def text_encoder_lora_attn_procs(self):
+        if hasattr(self, "_text_encoder_lora_attn_procs"):
+            return self._text_encoder_lora_attn_procs
+        return
+
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
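
Together, the stored `_text_encoder_lora_attn_procs` attribute and the new `text_encoder_lora_attn_procs` property make the text encoder attention processors retrievable after loading. A hedged usage sketch, assuming a Stable Diffusion pipeline that inherits `LoraLoaderMixin`; the model ID and LoRA path below are placeholders, not values from this commit.

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("path/to/lora_weights")  # placeholder: directory or Hub repo with LoRA weights

# New in this commit: the stored text encoder attention processors,
# or None when the loaded weights contained no text encoder LoRA layers.
procs = pipe.text_encoder_lora_attn_procs
if procs is not None:
    print(list(procs.keys()))
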
@@ -1110,7 +1131,7 @@ def _load_text_encoder_attn_procs(
     def save_lora_weights(
         self,
         save_directory: Union[str, os.PathLike],
-        unet_lora_layers: Dict[str, torch.nn.Module] = None,
+        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
         is_main_process: bool = True,
         weight_name: str = None,
@@ -1123,13 +1144,14 @@ def save_lora_weights(
         Arguments:
             save_directory (`str` or `os.PathLike`):
                 Directory to which to save. Will be created if it doesn't exist.
-            unet_lora_layers (`Dict[str, torch.nn.Module`]):
+            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the UNet. Specifying this helps to make the
-                serialization process easier and cleaner.
-            text_encoder_lora_layers (`Dict[str, torch.nn.Module`]):
+                serialization process easier and cleaner. Values can be both LoRA torch.nn.Modules layers or torch
+                weights.
+            text_encoder_lora_layers (`Dict[str, torch.nn.Module] or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder`. Since the `text_encoder` comes from
                 `transformers`, we cannot rejig it. That is why we have to explicitly pass the text encoder LoRA state
-                dict.
+                dict. Values can be both LoRA torch.nn.Modules layers or torch weights.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful when in distributed training like
                 TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
@@ -1157,15 +1179,22 @@ def save_function(weights, filename):
         # Create a flat dictionary.
         state_dict = {}
         if unet_lora_layers is not None:
-            unet_lora_state_dict = {
-                f"{self.unet_name}.{module_name}": param
-                for module_name, param in unet_lora_layers.state_dict().items()
-            }
+            weights = (
+                unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
+            )
+
+            unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
             state_dict.update(unet_lora_state_dict)
+
         if text_encoder_lora_layers is not None:
+            weights = (
+                text_encoder_lora_layers.state_dict()
+                if isinstance(text_encoder_lora_layers, torch.nn.Module)
+                else text_encoder_lora_layers
+            )
+
             text_encoder_lora_state_dict = {
-                f"{self.text_encoder_name}.{module_name}": param
-                for module_name, param in text_encoder_lora_layers.state_dict().items()
+                f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
             }
             state_dict.update(text_encoder_lora_state_dict)
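
To illustrate what the relaxed types buy, here is a small self-contained sketch of the branching that `save_lora_weights` now performs: each argument may be a `torch.nn.Module` container (as before) or a plain dictionary of tensors. The helper name `to_weights` and the stand-in objects are made up for illustration.

import torch

def to_weights(lora_layers):
    # Mirrors the new branching: call .state_dict() on modules, pass dicts through unchanged.
    return lora_layers.state_dict() if isinstance(lora_layers, torch.nn.Module) else lora_layers

module_style = torch.nn.Linear(4, 4, bias=False)            # stands in for an AttnProcsLayers module
tensor_style = {"to_q_lora.up.weight": torch.zeros(4, 4)}   # already a flat dict of tensors

print(list(to_weights(module_style).keys()))  # ['weight']
print(list(to_weights(tensor_style).keys()))  # ['to_q_lora.up.weight']

Either form ends up in the same flat `state_dict` with the `unet.` / `text_encoder.` prefixes before it is written out, so callers (such as a training script) are no longer forced to wrap raw tensors in a module first.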
