Skip to content

Commit 124ac3e

Browse files
authored
[LoRA] feat: support non-diffusers wan t2v loras. (#11059)
feat: support non-diffusers wan t2v loras.
1 parent 2f0f281 commit 124ac3e

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,6 +1355,7 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
13551355
original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
13561356

13571357
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
1358+
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
13581359

13591360
for i in range(num_blocks):
13601361
# Self-attention
@@ -1374,13 +1375,15 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
13741375
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
13751376
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
13761377
)
1377-
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
1378-
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
1379-
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
1380-
)
1381-
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
1382-
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
1383-
)
1378+
1379+
if is_i2v_lora:
1380+
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
1381+
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
1382+
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
1383+
)
1384+
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
1385+
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
1386+
)
13841387

13851388
# FFN
13861389
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):

0 commit comments

Comments
 (0)