@@ -70,6 +70,9 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]):
         self.mapping = dict(enumerate(state_dict.keys()))
         self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
 
+        # .processor for unet; .k_proj, .q_proj, .v_proj, and .out_proj for text encoder
+        self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
+
         # we add a hook to state_dict() and load_state_dict() so that the
         # naming fits with `unet.attn_processors`
         def map_to(module, state_dict, *args, **kwargs):
@@ -81,10 +84,19 @@ def map_to(module, state_dict, *args, **kwargs):
 
             return new_state_dict
 
+        def remap_key(key, state_dict):
+            for k in self.split_keys:
+                if k in key:
+                    return key.split(k)[0] + k
+
+            raise ValueError(
+                f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
+            )
+
         def map_from(module, state_dict, *args, **kwargs):
             all_keys = list(state_dict.keys())
             for key in all_keys:
-                replace_key = key.split(".processor")[0] + ".processor"
+                replace_key = remap_key(key, state_dict)
                 new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                 state_dict[new_key] = state_dict[key]
                 del state_dict[key]
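
For illustration only (not part of the diff): a minimal standalone sketch of the key remapping introduced above. The example key strings are assumed shapes of UNet and text-encoder LoRA keys, not strings taken from the library.

# Standalone sketch of the remap_key logic, using the same split markers as the diff.
split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]

def remap_key(key, split_keys):
    # Keep everything up to and including the first split marker found in the key.
    for k in split_keys:
        if k in key:
            return key.split(k)[0] + k
    raise ValueError(f"{key} has to have one of {split_keys}.")

# A UNet attention-processor key splits at ".processor" (key name is illustrative):
remap_key("down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.up.weight", split_keys)
# -> "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"

# A text-encoder key splits at the projection name (key name is illustrative):
remap_key("text_model.encoder.layers.0.self_attn.q_proj.lora_linear_layer.up.weight", split_keys)
# -> "text_model.encoder.layers.0.self_attn.q_proj"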
@@ -898,6 +910,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
             self._modify_text_encoder(attn_procs_text_encoder)
 
+            # save lora attn procs of text encoder so that it can be easily retrieved
+            self._text_encoder_lora_attn_procs = attn_procs_text_encoder
+
         # Otherwise, we're dealing with the old format. This means the `state_dict` should only
         # contain the module names of the `unet` as its keys WITHOUT any prefix.
         elif not all(
@@ -907,6 +922,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
             warnings.warn(warn_message)
 
+    @property
+    def text_encoder_lora_attn_procs(self):
+        if hasattr(self, "_text_encoder_lora_attn_procs"):
+            return self._text_encoder_lora_attn_procs
+        return
+
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
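
For context, an illustrative usage sketch of the property added above (not part of the diff); the pipeline class, model id, and LoRA path are placeholders.

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("path/to/lora")  # placeholder path or hub id

# With this change, the text-encoder LoRA attention processors loaded by
# load_lora_weights can be retrieved afterwards; the property returns None if
# no text-encoder LoRA was loaded.
text_encoder_procs = pipe.text_encoder_lora_attn_procs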
@@ -1110,7 +1131,7 @@ def _load_text_encoder_attn_procs(
     def save_lora_weights(
         self,
         save_directory: Union[str, os.PathLike],
-        unet_lora_layers: Dict[str, torch.nn.Module] = None,
+        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
         is_main_process: bool = True,
         weight_name: str = None,
@@ -1123,13 +1144,14 @@ def save_lora_weights(
         Arguments:
             save_directory (`str` or `os.PathLike`):
                 Directory to which to save. Will be created if it doesn't exist.
-            unet_lora_layers (`Dict[str, torch.nn.Module`]):
+            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the UNet. Specifying this helps to make the
-                serialization process easier and cleaner.
-            text_encoder_lora_layers (`Dict[str, torch.nn.Module`]):
+                serialization process easier and cleaner. Values can be either LoRA `torch.nn.Module` layers or
+                torch weights.
+            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder`. Since the `text_encoder` comes from
                 `transformers`, we cannot rejig it. That is why we have to explicitly pass the text encoder LoRA state
-                dict.
+                dict. Values can be either LoRA `torch.nn.Module` layers or torch weights.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful when in distributed training like
                 TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
@@ -1157,15 +1179,22 @@ def save_function(weights, filename):
         # Create a flat dictionary.
         state_dict = {}
         if unet_lora_layers is not None:
-            unet_lora_state_dict = {
-                f"{self.unet_name}.{module_name}": param
-                for module_name, param in unet_lora_layers.state_dict().items()
-            }
+            weights = (
+                unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
+            )
+
+            unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
             state_dict.update(unet_lora_state_dict)
+
         if text_encoder_lora_layers is not None:
+            weights = (
+                text_encoder_lora_layers.state_dict()
+                if isinstance(text_encoder_lora_layers, torch.nn.Module)
+                else text_encoder_lora_layers
+            )
+
             text_encoder_lora_state_dict = {
-                f"{self.text_encoder_name}.{module_name}": param
-                for module_name, param in text_encoder_lora_layers.state_dict().items()
+                f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
             }
             state_dict.update(text_encoder_lora_state_dict)
 
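
Illustrative usage (not part of the diff), reusing `pipe` from the earlier sketch: after this change, `save_lora_weights` should accept plain tensor dictionaries as well as `torch.nn.Module` containers. The key names and output directory below are hypothetical.

import torch

# Hypothetical flat LoRA weight dicts; real keys come from the trained LoRA layers.
unet_lora_tensors = {
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.up.weight": torch.zeros(4, 4),
}
text_encoder_lora_tensors = {
    "text_model.encoder.layers.0.self_attn.q_proj.lora_linear_layer.up.weight": torch.zeros(4, 4),
}

pipe.save_lora_weights(
    save_directory="lora-checkpoint",  # placeholder directory
    unet_lora_layers=unet_lora_tensors,
    text_encoder_lora_layers=text_encoder_lora_tensors,
)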