Skip to content

Commit e6df8ed

Browse files
authored
[LoRA] attempt at fixing onetrainer lora. (#8242)
* attempt at fixing onetrainer lora. * fix
1 parent 80cfaeb commit e6df8ed

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
226226
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
227227
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
228228
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
229+
diffusers_name = diffusers_name.replace("text.projection", "text_projection")
230+
229231
if "self_attn" in diffusers_name:
230232
if lora_name.startswith(("lora_te_", "lora_te1_")):
231233
te_state_dict[diffusers_name] = state_dict.pop(key)
@@ -243,6 +245,10 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
243245
else:
244246
te2_state_dict[diffusers_name] = state_dict.pop(key)
245247
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
248+
# OneTrainer specificity
249+
elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
250+
te2_state_dict[diffusers_name] = state_dict.pop(key)
251+
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
246252

247253
if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
248254
dora_scale_key_to_replace_te = (
@@ -270,7 +276,7 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
270276
network_alphas.update({new_name: alpha})
271277

272278
if len(state_dict) > 0:
273-
raise ValueError(f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}")
279+
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
274280

275281
logger.info("Kohya-style checkpoint detected.")
276282
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}

src/diffusers/utils/state_dict_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ class StateDictType(enum.Enum):
6262
".out_proj.lora_linear_layer.down": ".out_proj.lora_A",
6363
".lora_linear_layer.up": ".lora_B",
6464
".lora_linear_layer.down": ".lora_A",
65+
"text_projection.lora.down.weight": "text_projection.lora_A.weight",
66+
"text_projection.lora.up.weight": "text_projection.lora_B.weight",
6567
}
6668

6769
DIFFUSERS_OLD_TO_PEFT = {

0 commit comments

Comments
 (0)