
Commit 298ce67
[LoRA] text encoder: read the ranks for all the attn modules (#8324)
* [LoRA] text encoder: read the ranks for all the attn modules
* In addition to out_proj, read the ranks of adapters for q_proj, k_proj, and v_proj
* Allow missing adapters (UNet already supports this)
* ruff format loaders.lora
* [LoRA] add tests for partial text encoders LoRAs
* [LoRA] update test_simple_inference_with_partial_text_lora to be deterministic
* [LoRA] comment justifying test_simple_inference_with_partial_text_lora
* style

Co-authored-by: Sayak Paul <[email protected]>

1 parent d2e7a19, commit 298ce67
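Background for the change: PEFT's `rank_pattern` lets every targeted module carry its own LoRA rank, and the rank of each adapter is recoverable as the second dimension of its `lora_B.weight` (for an `nn.Linear`, `lora_B.weight` has shape `(out_features, rank)`). The sketch below, which is not part of this commit, illustrates that; `TinyAttention` is a made-up stand-in for CLIP's attention block, and it assumes a recent `peft` release.

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict


class TinyAttention(nn.Module):
    # Stand-in for a CLIP attention block: just the four projections.
    def __init__(self, dim=8):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)


config = LoraConfig(
    r=4,  # default rank; out_proj falls back to this
    rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},  # per-module overrides
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
model = get_peft_model(TinyAttention(), config)

for key, param in get_peft_model_state_dict(model).items():
    if "lora_B" in key:
        # lora_B.weight is (out_features, rank), so shape[1] is the rank:
        # q_proj -> 1, k_proj -> 2, v_proj -> 3, out_proj -> 4
        print(key, "-> rank", param.shape[1])
```

Reading `shape[1]` of every present `lora_B` is therefore enough to rebuild the per-module ranks at load time, which is exactly what the loader change below does instead of reading only `out_proj`.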

2 files changed (+75, -11 lines)

src/diffusers/loaders/lora.py (12 additions, 11 deletions)

```diff
@@ -462,17 +462,18 @@ def load_lora_into_text_encoder(
             text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
 
             for name, _ in text_encoder_attn_modules(text_encoder):
-                rank_key = f"{name}.out_proj.lora_B.weight"
-                rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
-            patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
-            if patch_mlp:
-                for name, _ in text_encoder_mlp_modules(text_encoder):
-                    rank_key_fc1 = f"{name}.fc1.lora_B.weight"
-                    rank_key_fc2 = f"{name}.fc2.lora_B.weight"
-
-                    rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
-                    rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
+                for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+                    rank_key = f"{name}.{module}.lora_B.weight"
+                    if rank_key not in text_encoder_lora_state_dict:
+                        continue
+                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+            for name, _ in text_encoder_mlp_modules(text_encoder):
+                for module in ("fc1", "fc2"):
+                    rank_key = f"{name}.{module}.lora_B.weight"
+                    if rank_key not in text_encoder_lora_state_dict:
+                        continue
+                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
 
             if network_alphas is not None:
                 alpha_keys = [
```
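To see what the new loop changes in isolation, here is a re-creation of the rank-collection logic run on a hand-built toy state dict; the module names and shapes are invented, and only the loop body mirrors the diff. It shows the two new behaviors: per-module ranks for all four attention projections, and silently skipping modules whose adapters are absent (the UNet loader already tolerated such gaps).

```python
import torch

# Toy state dict; names are invented, shapes follow PEFT's convention
# that lora_B.weight is (out_features, rank).
state_dict = {
    "layers.0.self_attn.q_proj.lora_B.weight": torch.zeros(8, 1),
    "layers.0.self_attn.k_proj.lora_B.weight": torch.zeros(8, 2),
    "layers.0.self_attn.v_proj.lora_B.weight": torch.zeros(8, 3),
    "layers.0.self_attn.out_proj.lora_B.weight": torch.zeros(8, 4),
    # layers.1 carries no adapters at all: a "partial" LoRA.
}

rank = {}
for name in ("layers.0.self_attn", "layers.1.self_attn"):
    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
        rank_key = f"{name}.{module}.lora_B.weight"
        if rank_key not in state_dict:
            # New behavior: a missing adapter is skipped, not a KeyError.
            continue
        rank[rank_key] = state_dict[rank_key].shape[1]

print(rank)
# Each present projection reports its own rank (4, 1, 2, 3) and layers.1
# contributes nothing. The old code read only out_proj, assumed its rank
# applied everywhere, and raised KeyError when a layer's adapters were missing.
```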

tests/lora/utils.py (63 additions, 0 deletions)

```diff
@@ -395,6 +395,69 @@ def test_simple_inference_with_text_lora_save_load(self):
                 "Loading from saved checkpoints should give same results.",
             )
 
+    def test_simple_inference_with_partial_text_lora(self):
+        """
+        Tests a simple inference with lora attached on the text encoder
+        with different ranks and some adapters removed
+        and makes sure it works as expected
+        """
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, _, _ = self.get_dummy_components(scheduler_cls)
+            # Verify `LoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
+            text_lora_config = LoraConfig(
+                r=4,
+                rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
+                lora_alpha=4,
+                target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+                init_lora_weights=False,
+                use_dora=False,
+            )
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+
+            pipe.text_encoder.add_adapter(text_lora_config)
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+            # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
+            # supports missing layers (PR#8324).
+            state_dict = {
+                f"text_encoder.{module_name}": param
+                for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
+                if "text_model.encoder.layers.4" not in module_name
+            }
+
+            if self.has_two_text_encoders:
+                pipe.text_encoder_2.add_adapter(text_lora_config)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+                state_dict.update(
+                    {
+                        f"text_encoder_2.{module_name}": param
+                        for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
+                        if "text_model.encoder.layers.4" not in module_name
+                    }
+                )
+
+            output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            self.assertTrue(
+                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+            )
+
+            # Unload lora and load it back using the pipe.load_lora_weights machinery
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(state_dict)
+
+            output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            self.assertTrue(
+                not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
+                "Removing adapters should change the output",
+            )
+
     def test_simple_inference_save_pretrained(self):
         """
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
```
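For users, the net effect is that `pipe.load_lora_weights` now accepts text-encoder LoRA state dicts that cover only a subset of layers and mix ranks across projections, as the test above exercises through the internal pipeline fixtures. A hypothetical end-to-end sketch of the same flow on a real pipeline follows; the model id, prompt, and layer filter are illustrative and not from the commit, and it assumes the text encoder already carries a trained PEFT adapter.

```python
import torch
from diffusers import StableDiffusionPipeline
from peft import get_peft_model_state_dict

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Keep only part of the text-encoder adapter: drop everything in encoder layer 4.
partial_state_dict = {
    f"text_encoder.{key}": value
    for key, value in get_peft_model_state_dict(pipe.text_encoder).items()
    if "text_model.encoder.layers.4" not in key
}

pipe.unload_lora_weights()
# Before this commit, the missing out_proj keys for layer 4 raised a KeyError here.
pipe.load_lora_weights(partial_state_dict)
image = pipe("a photo of a corgi", generator=torch.manual_seed(0)).images[0]
```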
