@@ -226,6 +226,8 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
226
226
diffusers_name = diffusers_name .replace ("k.proj.lora" , "to_k_lora" )
227
227
diffusers_name = diffusers_name .replace ("v.proj.lora" , "to_v_lora" )
228
228
diffusers_name = diffusers_name .replace ("out.proj.lora" , "to_out_lora" )
229
+ diffusers_name = diffusers_name .replace ("text.projection" , "text_projection" )
230
+
229
231
if "self_attn" in diffusers_name :
230
232
if lora_name .startswith (("lora_te_" , "lora_te1_" )):
231
233
te_state_dict [diffusers_name ] = state_dict .pop (key )
@@ -243,6 +245,10 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
243
245
else :
244
246
te2_state_dict [diffusers_name ] = state_dict .pop (key )
245
247
te2_state_dict [diffusers_name .replace (".down." , ".up." )] = state_dict .pop (lora_name_up )
248
+ # OneTrainer specificity
249
+ elif "text_projection" in diffusers_name and lora_name .startswith ("lora_te2_" ):
250
+ te2_state_dict [diffusers_name ] = state_dict .pop (key )
251
+ te2_state_dict [diffusers_name .replace (".down." , ".up." )] = state_dict .pop (lora_name_up )
246
252
247
253
if (is_te_dora_lora or is_te2_dora_lora ) and lora_name .startswith (("lora_te_" , "lora_te1_" , "lora_te2_" )):
248
254
dora_scale_key_to_replace_te = (
@@ -270,7 +276,7 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
270
276
network_alphas .update ({new_name : alpha })
271
277
272
278
if len (state_dict ) > 0 :
273
- raise ValueError (f"The following keys have not been correctly be renamed: \n \n { ', ' .join (state_dict .keys ())} " )
279
+ raise ValueError (f"The following keys have not been correctly renamed: \n \n { ', ' .join (state_dict .keys ())} " )
274
280
275
281
logger .info ("Kohya-style checkpoint detected." )
276
282
unet_state_dict = {f"{ unet_name } .{ module_name } " : params for module_name , params in unet_state_dict .items ()}
0 commit comments