[LoRA] text encoder: read the ranks for all the attn modules #8324
Conversation
* In addition to out_proj, read the ranks of adapters for q_proj, k_proj, and v_proj
* Allow missing adapters (UNet already supports this)
Thanks for your contributions! Could you also add a test to ensure the robustness of the changes?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Clean, thanks!
@sayakpaul I may need some guidance for writing the tests. Testing this change requires a specially resized LoRA.
Alternatively, the test could drop layers and ranks in the state_dict, but the output will be nonsensical without orthogonalization.
Sure.
So, we could add a utility to create one. WDYT?
I added a test that takes inspiration from A
The branch was force-pushed from 8dae060 to a74b84c.
tests/lora/utils.py (outdated diff)
# Discard half of the adapters.
rng = np.random.default_rng(0)
key2adapters = {k: k.rsplit(".", 2)[0] for k in state_dict.keys()}
adapters = list(set(key2adapters.values()))
adapters = set(rng.choice(adapters, size=len(adapters) // 2, replace=False))
state_dict = {k: state_dict[k] for k, adapter in key2adapters.items() if adapter in adapters}
I think we should keep this behavior truly deterministic, i.e., don't rely on rng.choice() and just manually remove the adapters. I'd prefer this, as it's slightly better for testing.
I also think that the test should:
- Run a single inference pass after the LoRAs have been added to the text encoder. Then we compare the outputs to the non-LoRA case.
- Run another inference pass with adapters discarded and compare the outputs to the original LoRA outputs.
I updated the test, now removing all adapters from text_model.encoder.layers.4.
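For illustration, the deterministic variant boils down to filtering the state dict by a fixed layer prefix instead of sampling with an RNG. This is only a sketch with placeholder names; the actual test in tests/lora/utils.py may differ.

def discard_layer_adapters(state_dict, layer_prefix="text_model.encoder.layers.4."):
    # Drop every LoRA key that belongs to one fixed text-encoder layer, so the
    # partial LoRA used by the test is identical on every run (no RNG needed).
    return {k: v for k, v in state_dict.items() if layer_prefix not in k}

The outputs produced with this partial state dict can then be compared against the full-LoRA and no-LoRA outputs, as suggested above.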
Thanks for the tests. I left some comments.
Just a minor comment. Rest looks really good to me. Thanks for taking care of it!
Thanks so much for the iterations! Will merge once the CI is green.
What is keeping this from being merged?
That will be on me. I forgot to merge it. Let the CI run once again and I will merge.
Ah, looks like we need to run a couple of quality-related formatting fixes. Keeping it open and will take care of it once I am back at my keyboard.
The failing test is completely unrelated, and sorry for the delay on my end. Thank you for your amazing work!
* [LoRA] text encoder: read the ranks for all the attn modules
* In addition to out_proj, read the ranks of adapters for q_proj, k_proj, and v_proj
* Allow missing adapters (UNet already supports this)
* ruff format loaders.lora
* [LoRA] add tests for partial text encoders LoRAs
* [LoRA] update test_simple_inference_with_partial_text_lora to be deterministic
* [LoRA] comment justifying test_simple_inference_with_partial_text_lora
* style

Co-authored-by: Sayak Paul <[email protected]>
What does this PR do?
The enumeration of LoRA adapters for gathering their ranks in LoraLoaderMixin.load_lora_into_text_encoder is modified in order to:
* In addition to out_proj, read the ranks of adapters for q_proj, k_proj, and v_proj
* Allow missing adapters (UNet already supports this)
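A rough sketch of what the modified enumeration amounts to (illustrative helper name and arguments, not the exact code in LoraLoaderMixin.load_lora_into_text_encoder):

def gather_text_encoder_lora_ranks(attn_module_names, lora_state_dict):
    # attn_module_names: prefixes of the text-encoder attention blocks,
    # e.g. "text_model.encoder.layers.0.self_attn".
    rank = {}
    for name in attn_module_names:
        # Read the rank of every attention projection, not only out_proj.
        for proj in ("out_proj", "q_proj", "k_proj", "v_proj"):
            rank_key = f"{name}.{proj}.lora_B.weight"
            # Tolerate missing adapters instead of raising a KeyError,
            # mirroring what the UNet loader already does.
            if rank_key not in lora_state_dict:
                continue
            # lora_B has shape (out_features, r), so its second dim is the rank.
            rank[rank_key] = lora_state_dict[rank_key].shape[1]
    return rank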
Motivation
* Some resized LoRAs have a different rank for each adapter module. Currently, only the ranks of the out_proj attention modules are gathered in the rank dictionary, and using a LoraConfig with missing rank_pattern entries results in errors.
* UNet supports missing adapter modules, but the text encoder doesn't, so loading such a LoRA fails with an error.
* encoder.layers.11.*.lora_B.weight is often all zeroes when the layer is skipped during fine-tuning; in this case it might be desirable to not store these weights in the LoRA safetensors.
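For context, a LoRA of this shape could be described with a peft LoraConfig along these lines, with per-module ranks and some layers skipped entirely (a hypothetical example; the concrete ranks, modules, and layer indices are made up):

from peft import LoraConfig

# Hypothetical config: per-module ranks via rank_pattern, and only the first
# few layers adapted via layers_to_transform, which leaves the later layers
# (e.g. encoder.layers.11) without any adapters in the saved file.
config = LoraConfig(
    r=8,
    lora_alpha=8,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
    rank_pattern={"q_proj": 4, "k_proj": 4},
    layers_to_transform=[0, 1, 2, 3],
)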
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sayakpaul is the author of the lines affected by this PR.