Skip to content

[LoRA] text encoder: read the ranks for all the attn modules #8324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions src/diffusers/loaders/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,17 +566,18 @@ def load_lora_into_text_encoder(
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)

for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_B.weight"
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
rank_key_fc2 = f"{name}.fc2.lora_B.weight"

rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

if network_alphas is not None:
alpha_keys = [
Expand Down
59 changes: 59 additions & 0 deletions tests/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,65 @@ def test_simple_inference_with_text_lora_save_load(self):
"Loading from saved checkpoints should give same results.",
)

def test_simple_inference_with_partial_text_lora(self):
"""
Tests a simple inference with lora attached on the text encoder
with different ranks and some adapters removed
and makes sure it works as expected
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, _ = self.get_dummy_components(scheduler_cls)
text_lora_config = LoraConfig(
r=4,
rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
lora_alpha=4,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
init_lora_weights=False,
use_dora=False,
)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))

pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
state_dict = {
f"text_encoder.{module_name}": param
for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
}

if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
state_dict.update(
{
f"text_encoder.{module_name}": param
for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
}
)

# Discard half of the adapters.
rng = np.random.default_rng(0)
key2adapters = {k: k.rsplit(".", 2)[0] for k in state_dict.keys()}
adapters = list(set(key2adapters.values()))
adapters = set(rng.choice(adapters, size=len(adapters) // 2, replace=False))
state_dict = {k: state_dict[k] for k, adapter in key2adapters.items() if adapter in adapters}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should keep this behavior truly deterministic, i.e., don't rely on rng.choice() and just manually remove the adapters. I prefer this to be a slightly better for testing.

Copy link
Member

@sayakpaul sayakpaul Jun 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also think that the test should:

  • Run a single inference pass after the LoRAs have been added to the text encoder. Then we compare the outputs to the non-LoRA case.
  • Run another inference pass with adapters discarded and compare the outputs to the original LoRA outputs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the test, now removing all adapters from text_model.encoder.layers.4.


# Unload lora and load it back using the pipe.load_lora_weights machinery
pipe.unload_lora_weights()
pipe.load_lora_weights(state_dict)

output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)

def test_simple_inference_save_pretrained(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
Expand Down
Loading