@@ -1791,7 +1791,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
 
             file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
             file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
-            unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0")
+            unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
 
             if do_compile:
                 unet = torch.compile(unet, mode="reduce-overhead")
@@ -1801,7 +1801,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
             assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
 
             # hotswap the 2nd adapter
-            unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True)
+            unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
 
             # we need to call forward to potentially trigger recompilation
             with torch.inference_mode():
@@ -1812,7 +1812,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
             name = "does-not-exist"
             msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
             with self.assertRaisesRegex(ValueError, msg):
-                unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True)
+                unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_model(self, rank0, rank1):
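
A minimal sketch of the hotswap flow this diff exercises, assuming diffusers with PEFT support and the `load_lora_adapter(..., hotswap=True, prefix=None)` API shown above. The base checkpoint path and the two LoRA file paths are placeholders, not values from the test.

```python
import torch
from diffusers import UNet2DConditionModel

# Hypothetical base checkpoint; substitute a real one.
unet = UNet2DConditionModel.from_pretrained("path/to/base-checkpoint", subfolder="unet")

# Load the first adapter. prefix=None disables prefix-based key filtering,
# matching the test above, where the LoRA weights were saved directly from the UNet.
unet.load_lora_adapter(
    "lora0/pytorch_lora_weights.safetensors", adapter_name="adapter0", prefix=None
)

# Compile once with the first adapter in place.
unet = torch.compile(unet, mode="reduce-overhead")

# Hotswap: replace adapter0's weights in place with a second checkpoint,
# keeping the compiled model. As the test notes, a forward pass may still
# trigger recompilation.
unet.load_lora_adapter(
    "lora1/pytorch_lora_weights.safetensors",
    adapter_name="adapter0",
    hotswap=True,
    prefix=None,
)
```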