
Commit 115c77d

Fix LoRA loading call to add prefix=None
See: huggingface#10187 (comment)

1 parent: 425cb39

File tree: 1 file changed (+3, -3)


tests/models/test_modeling_common.py (3 additions, 3 deletions)

@@ -1791,7 +1791,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
 
             file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
             file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
-            unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0")
+            unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
 
             if do_compile:
                 unet = torch.compile(unet, mode="reduce-overhead")
@@ -1801,7 +1801,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
             assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
 
             # hotswap the 2nd adapter
-            unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True)
+            unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
 
             # we need to call forward to potentially trigger recompilation
             with torch.inference_mode():
@@ -1812,7 +1812,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
             name = "does-not-exist"
             msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
             with self.assertRaisesRegex(ValueError, msg):
-                unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True)
+                unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_model(self, rank0, rank1):
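For context, a minimal sketch of the call pattern this commit fixes, using only the load_lora_adapter arguments exercised in the test above; the repo id and file paths here are illustrative, not from the commit:

from diffusers import UNet2DConditionModel

# Illustrative checkpoint; any diffusers model exposing load_lora_adapter
# follows the same pattern.
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)

# The test saves the LoRA state dict directly from the UNet, so its keys carry
# no component prefix such as "unet.". Passing prefix=None skips prefix-based
# key filtering; per the linked discussion, a non-None prefix would filter the
# keys out and nothing would be loaded.
unet.load_lora_adapter("0/pytorch_lora_weights.safetensors", adapter_name="adapter0", prefix=None)

# hotswap=True swaps the weights of the already-loaded adapter in place, which
# lets a torch.compile'd model avoid recompilation.
unet.load_lora_adapter("1/pytorch_lora_weights.safetensors", adapter_name="adapter0", hotswap=True, prefix=None)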
