
[LoRA] text encoder: read the ranks for all the attn modules #8324


Merged — 13 commits merged into huggingface:main from lora_te_read_all_ranks on Jun 18, 2024

Conversation

@elias-gaeros (Contributor) commented May 30, 2024

What does this PR do?

The enumeration of LoRA adapters for gathering their ranks in LoraLoaderMixin.load_lora_into_text_encoder is modified in order to:

  • In addition to out_proj, read the ranks of adapters for q_proj, k_proj, and v_proj

  • Allow missing adapters (UNet already supports this)

Motivation

Some resized LoRAs have a different rank for each adapter module. Currently, only the ranks of the out_proj attention modules are gathered into the rank dictionary. Using a LoraConfig with missing rank_pattern entries results in the following errors:

RuntimeError: Error(s) in loading state_dict for CLIPTextModel:
        size mismatch for text_model.encoder.layers.0.self_attn.k_proj.lora_A.default_0.weight: copying a param with shape torch.Size([8, 768]) from checkpoint, the shape in current model is torch.Size([3, 768]).
        size mismatch for text_model.encoder.layers.0.self_attn.k_proj.lora_B.default_0.weight: copying a param with shape torch.Size([768, 8]) from checkpoint, the shape in current model is torch.Size([768, 3]).

UNet supports missing adapter modules, but the text encoder doesn't, resulting in:

  File "/home/elias/repos/diffusers/src/diffusers/loaders/lora.py", line 1363, in load_lora_weights
    self.load_lora_into_text_encoder(
  File "/home/elias/repos/diffusers/src/diffusers/loaders/lora.py", line 572, in load_lora_into_text_encoder
    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
                     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^
KeyError: 'text_model.encoder.layers.11.self_attn.out_proj.lora_B.weight'

encoder.layers.11.*.lora_B.weight is often all zeroes when the layer is skipped during fine-tuning; in that case it might be desirable not to store those weights in the LoRA safetensors.
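
For illustration, a minimal sketch of the rank-gathering behavior described above: read the rank from every lora_B weight that is actually present, instead of assuming an out_proj adapter exists in every attention block (names are illustrative, not the exact diffusers code):

```python
# Illustrative sketch (not the exact diffusers code): gather a rank for every
# adapter module that is actually present in the text encoder state dict.
def gather_text_encoder_ranks(text_encoder_lora_state_dict):
    rank = {}
    for name, weight in text_encoder_lora_state_dict.items():
        # Each q_proj/k_proj/v_proj/out_proj adapter stores a lora_B weight of
        # shape [out_features, rank]; missing modules simply contribute no entry.
        if name.endswith("lora_B.weight"):
            rank[name] = weight.shape[1]
    return rank
```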

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you read our philosophy doc (important for complex PRs)?
  • Was this discussed/approved via a GitHub issue or the forum? I couldn't find an open issue or forum thread about these problems.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul is the author of the lines affected by this PR.

 * In addition to out_proj, read the ranks of adapters for q_proj, k_proj, and v_proj

 * Allow missing adapters (UNet already supports this)
@yiyixuxu yiyixuxu requested a review from sayakpaul May 30, 2024 23:08
@sayakpaul sayakpaul requested a review from younesbelkada May 30, 2024 23:15
@sayakpaul (Member)

Thanks for your contributions!

Could you also add a test to ensure the robustness of the changes?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@younesbelkada (Contributor) left a review comment

Clean, thanks!

@elias-gaeros (Contributor, Author) commented May 31, 2024

> Thanks for your contributions!
>
> Could you also add a test to ensure the robustness of the changes?

@sayakpaul I may need some guidance for writing the tests. Testing this change requires a specially resized LoRA.
What about:

Alternatively the test could drop layers and ranks in the state_dict, but the output will be nonsensical without orthogonalization.

@sayakpaul (Member)

Sure.

> @sayakpaul I may need some guidance for writing the tests. Testing this change requires a specially resized LoRA.

So, we could add a utility to create a peft LoRA config that tests this use case specifically. And then add an actual test case for it. We could add the utility and the test case to this file.

WDYT?

@elias-gaeros (Contributor, Author)

> So, we could add a utility to create a peft LoRA config that tests this use case specifically. And then add an actual test case for it. We could add the utility and the test case to this file.
>
> WDYT?

I added a test that takes inspiration from test_simple_inference_with_text_unet_lora_save_load, since it's one of the few tests that actually use LoraLoaderMixin.load_lora_into_text_encoder.

A LoraConfig is used for assigning different ranks to each projection layer. After initializing the adapters, the state_dict is extracted and half of the layers are discarded before reloading them using load_lora_weights.
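
For reference, a config along these lines can be expressed with peft's rank_pattern; the ranks and values below are illustrative, not necessarily the ones used in the test:

```python
from peft import LoraConfig

# Illustrative only: assign a different rank to each attention projection of the
# CLIP text encoder via rank_pattern; modules not listed fall back to r.
text_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
    rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
    init_lora_weights=False,
)
```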

@elias-gaeros force-pushed the lora_te_read_all_ranks branch from 8dae060 to a74b84c on June 6, 2024 14:42
Comment on lines 441 to 446
# Discard half of the adapters.
rng = np.random.default_rng(0)
key2adapters = {k: k.rsplit(".", 2)[0] for k in state_dict.keys()}
adapters = list(set(key2adapters.values()))
adapters = set(rng.choice(adapters, size=len(adapters) // 2, replace=False))
state_dict = {k: state_dict[k] for k, adapter in key2adapters.items() if adapter in adapters}
Member

I think we should keep this behavior truly deterministic, i.e., don't rely on rng.choice() and just manually remove the adapters. I think that would be slightly better for testing.

@sayakpaul (Member) commented Jun 6, 2024

I also think that the test should (see the sketch after this list):

  • Run a single inference pass after the LoRAs have been added to the text encoder. Then we compare the outputs to the non-LoRA case.
  • Run another inference pass with adapters discarded and compare the outputs to the original LoRA outputs.
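
A rough sketch of the comparisons being suggested (the helper and names are placeholders, not the actual test code):

```python
import numpy as np

# Placeholder helper illustrating the suggested checks: the arguments are numpy
# image batches produced by the pipeline before/after loading the LoRAs.
def outputs_differ(images_a, images_b, atol=1e-3):
    return not np.allclose(images_a, images_b, atol=atol)

# Intended usage in the test (names are placeholders):
#   assert outputs_differ(images_no_lora, images_full_lora)      # LoRA changes the output
#   assert outputs_differ(images_full_lora, images_partial_lora) # dropping adapters changes it
```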

@elias-gaeros (Contributor, Author)

I updated the test, now removing all adapters from text_model.encoder.layers.4.
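
Roughly along these lines (an illustrative sketch, not necessarily the exact code in the PR):

```python
def drop_layer_adapters(state_dict, layer_prefix="text_model.encoder.layers.4"):
    # Remove every LoRA weight belonging to the given encoder layer, so the
    # "partial LoRA" case is exercised deterministically (no rng.choice involved).
    return {k: v for k, v in state_dict.items() if layer_prefix not in k}
```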

@sayakpaul (Member) left a review comment

Thanks for the tests. I left some comments.

@sayakpaul (Member) left a review comment

Just a minor comment. Rest looks really good to me. Thanks for taking care of it!

@sayakpaul (Member)

Thanks so much for the iterations! Will merge once the CI is green.

@elias-gaeros (Contributor, Author)

What is keeping this from being merged?

@sayakpaul (Member)

That will be on me. I forgot to merge it. Let the CI run once again and I will merge.

@sayakpaul (Member)

Ah, looks like I need to run a couple of quality-related formatting fixes. Keeping it open; will take care of it once I am back at my keyboard.

@sayakpaul (Member)

I am unable to reproduce the LoRA-related test failures locally :/ Seems like some NumPy version mismatch bug. I can quickly open a PR to pin the NumPy version to below 2 so that we can unblock the PRs. @yiyixuxu @DN6 WDYT?

@sayakpaul sayakpaul merged commit 298ce67 into huggingface:main Jun 18, 2024
14 of 15 checks passed
@sayakpaul (Member) commented Jun 18, 2024

The failing test is completely unrelated and sorry for the delay on my end. Thank you for your amazing work!

sayakpaul added a commit that referenced this pull request Dec 23, 2024
* [LoRA] text encoder: read the ranks for all the attn modules

 * In addition to out_proj, read the ranks of adapters for q_proj, k_proj, and v_proj

 * Allow missing adapters (UNet already supports this)

* ruff format loaders.lora

* [LoRA] add tests for partial text encoders LoRAs

* [LoRA] update test_simple_inference_with_partial_text_lora to be deterministic

* [LoRA] comment justifying test_simple_inference_with_partial_text_lora

* style

---------

Co-authored-by: Sayak Paul <[email protected]>