Commit a74b84c

[LoRA] add tests for partial text encoders LoRAs
1 parent 57f3fa7 commit a74b84c

File tree

1 file changed (+59, -0)


tests/lora/utils.py

Lines changed: 59 additions & 0 deletions
@@ -395,6 +395,65 @@ def test_simple_inference_with_text_lora_save_load(self):
                 "Loading from saved checkpoints should give same results.",
             )
 
+    def test_simple_inference_with_partial_text_lora(self):
+        """
+        Tests a simple inference with LoRA attached to the text encoder,
+        with different ranks per module and some adapters removed,
+        and makes sure it works as expected.
+        """
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, _, _ = self.get_dummy_components(scheduler_cls)
+            text_lora_config = LoraConfig(
+                r=4,
+                rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
+                lora_alpha=4,
+                target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+                init_lora_weights=False,
+                use_dora=False,
+            )
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+
+            pipe.text_encoder.add_adapter(text_lora_config)
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+            state_dict = {
+                f"text_encoder.{module_name}": param
+                for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
+            }
+
+            if self.has_two_text_encoders:
+                pipe.text_encoder_2.add_adapter(text_lora_config)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+                state_dict.update(
+                    {
+                        f"text_encoder_2.{module_name}": param
+                        for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
+                    }
+                )
+
+            # Discard half of the adapters, grouped by their parent module.
+            rng = np.random.default_rng(0)
+            key2adapters = {k: k.rsplit(".", 2)[0] for k in state_dict.keys()}
+            adapters = list(set(key2adapters.values()))
+            adapters = set(rng.choice(adapters, size=len(adapters) // 2, replace=False))
+            state_dict = {k: state_dict[k] for k, adapter in key2adapters.items() if adapter in adapters}
+
+            # Unload lora and load it back using the pipe.load_lora_weights machinery.
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(state_dict)
+
+            output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            self.assertTrue(
+                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+            )
+
     def test_simple_inference_save_pretrained(self):
         """
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
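
Two pieces of the new test are worth calling out: `rank_pattern`, which lets a single `LoraConfig` assign a different LoRA rank to each targeted module, and the `k.rsplit(".", 2)[0]` trick, which strips the trailing `lora_A.weight` / `lora_B.weight` components off a peft state-dict key to recover the parent module, so both tensors of a module's adapter are kept or dropped together. A minimal standalone sketch of both ideas, assuming `peft` and `transformers` are installed (the tiny `CLIPTextConfig` values below are made up for illustration and are not the test's real dummy components):

    # Sketch only: illustrates rank_pattern and key grouping, not the test itself.
    from peft import LoraConfig
    from peft.utils import get_peft_model_state_dict
    from transformers import CLIPTextConfig, CLIPTextModel

    # Made-up tiny config so the example runs quickly.
    text_encoder = CLIPTextModel(
        CLIPTextConfig(
            vocab_size=1000, hidden_size=32, intermediate_size=64,
            num_hidden_layers=2, num_attention_heads=4,
        )
    )

    config = LoraConfig(
        r=4,  # default rank for any targeted module not matched below
        rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},  # per-module rank overrides
        lora_alpha=4,
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        init_lora_weights=False,  # random (non-zero) init, so the adapter changes outputs
    )
    text_encoder.add_adapter(config)  # transformers' PeftAdapterMixin injects the LoRA layers

    state_dict = get_peft_model_state_dict(text_encoder)
    # Keys look like "...self_attn.q_proj.lora_A.weight"; stripping the last two
    # dotted components groups each A/B pair under its parent module.
    modules = sorted({k.rsplit(".", 2)[0] for k in state_dict})
    print(f"{len(state_dict)} LoRA tensors across {len(modules)} modules")

Because the test builds its config with `init_lora_weights=False`, the LoRA matrices start from random rather than zero weights, so even the surviving half of the adapters should shift the pipeline output; that is what the final `np.allclose` assertion checks after the partial state dict is reloaded with `pipe.load_lora_weights`.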
