add resize_token_embedding functionality #2542
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2542
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 4 Cancelled Jobs as of commit 73546c1 with merge base 137eec3.
NEW FAILURES - The following jobs have failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Attention: Patch coverage is
@@ Coverage Diff @@
##              main    #2542      +/-   ##
==========================================
- Coverage    65.78%   64.09%    -1.70%
==========================================
  Files          396      396
  Lines        23764    23796      +32
==========================================
- Hits         15634    15252     -382
- Misses        8130     8544     +414
Thanks for adding this and for the thorough testing! Left a bunch of comments but this is looking pretty good. Lmk if anything I suggested is unclear
utils.log_rank_zero(
    log,
    f"Tokenizer vocab size ({self._tokenizer.vocab_size}) and "
    f"embedding size ({self._model.tok_embeddings.num_embeddings}).",
)
I think we don't necessarily need to log this in general. Instead I would recommend logging inside of the resize_token_embeddings function only in the case that the tokenizer vocab size differs from the model's vocab size.
log.info(
    f"Tokenizer vocab size ({self._tokenizer.vocab_size}) differs from model "
    f"embedding size ({self._model.tok_embeddings.num_embeddings}). Resizing model embeddings..."
)
Similar here.. I would actually just log this directly in the utility (and same for the corresponding log in the distributed recipe)
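A minimal sketch of what that suggestion could look like inside the utility, assuming the logger and attribute names used in the snippets above (not the PR's actual code):

import logging

import torch.nn as nn

log = logging.getLogger(__name__)


def resize_token_embeddings(model: nn.Module, num_embeddings: int) -> None:
    old_num_embeddings = model.tok_embeddings.num_embeddings
    if old_num_embeddings != num_embeddings:
        # Only log when the sizes actually differ, per the review suggestion
        log.info(
            f"Model embedding size ({old_num_embeddings}) differs from requested "
            f"size ({num_embeddings}). Resizing model embeddings..."
        )
    # ...resizing logic would follow here...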
recipes/full_finetune_distributed.py
self._model.tok_embeddings.unshard()
if not isinstance(self._model.output, TiedLinear):
    self._model.output.unshard()
Crazy idea, but I wonder whether it would be helpful to define a context manager for things like this. E.g. if I have an operation I know how to apply to a vanilla tensor, but maybe the input is a DTensor, maybe it's not. Then given an optional device mesh, we either apply the op to a vanilla tensor, or for a DTensor, we unshard, apply the op, then reshard to match original. @weifengpy thoughts?
If I understand correctly, the context manager would essentially handle the loop of d_tensor.full_tensor() --> user op --> redistribute_tensor(tensor, mesh)?
Yeah exactly. Doesn't have to be done in this PR anyways, just an idea for a future quality-of-life improvement
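As a rough illustration of the idea (names and structure here are assumptions, not an actual torchtune utility), the unshard-apply-reshard flow could look something like this:

from typing import Callable

import torch
# On older torch versions these live in torch.distributed._tensor
from torch.distributed.tensor import DTensor, distribute_tensor


def apply_to_maybe_dtensor(
    op: Callable[[torch.Tensor], torch.Tensor], tensor: torch.Tensor
) -> torch.Tensor:
    """Apply ``op`` to a vanilla tensor; if given a DTensor, gather it first
    and redistribute the result back to the original mesh/placements."""
    if isinstance(tensor, DTensor):
        mesh, placements = tensor.device_mesh, tensor.placements
        full = tensor.full_tensor()  # gather shards into a vanilla tensor
        return distribute_tensor(op(full), mesh, placements)
    return op(tensor)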
recipes/full_finetune_distributed.py
# Unshard the FSDP wrapped token embeddings and output projection.
# tok_embeddings and output are wrapped by fully_shard, so we can call .unshard()
# directly on those layers only
self._model.tok_embeddings.unshard()
It's possible that self._model is not an instance of TransformerDecoder. E.g. for multimodal models we provide both EarlyFusionModel and DeepFusionModel, so we should at least be able to handle both of these gracefully. One option is to just do all this directly in the resize_token_embeddings function, otherwise we start to bloat the recipe code with lots of if/else checks.
Got it - I'll move all of this within the function and handle the various types of models that can be passed in.
torchtune/modules/common_utils.py
Resizes the token embeddings and the final output projection layer of a TransformerDecoder model.

This function modifies the model in-place.
Would also explain (somewhere) in the docstring that you're initializing new positions using the mean of the old embeddings
Added - I think in a follow-up we can add an "init_strategy" optional arg to the function and abstract the resize op into its own helper as we add additional init strategies (+ user-defined ones).
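For reference, a minimal standalone sketch of the mean-initialization idea (an illustration, not the PR's exact implementation):

import torch
import torch.nn as nn


def mean_init_resize(old: nn.Embedding, num_embeddings: int) -> nn.Embedding:
    """Return a larger embedding whose new rows are the mean of the old rows.

    Assumes num_embeddings >= old.num_embeddings.
    """
    new = nn.Embedding(num_embeddings, old.embedding_dim)
    with torch.no_grad():
        # copy existing rows, then fill the new rows with the mean embedding
        new.weight[: old.num_embeddings] = old.weight
        new.weight[old.num_embeddings :] = old.weight.mean(dim=0)
    return new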
torchtune/modules/common_utils.py
output_layer = model.output
if isinstance(output_layer, TiedLinear):
    if output_layer.tied_module is old_embeddings:
Dumb q: is this just checking equality of the underlying data_ptr?
This is checking whether the object points to the same memory as tok_embeddings' nn.Embedding; the underlying data_ptr() isn't checked. A better check would be to use data_ptr() for the case where output and tok_embeddings are separate nn.Embeddings.
However, if output_layer is tied, we should always replace the tied_module with the new tok_embeddings - I feel like this check is not needed?
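A quick illustration of the object-identity vs. data_ptr() comparison being discussed (purely for clarity, not from the PR):

import torch.nn as nn

emb = nn.Embedding(10, 4)
out = nn.Linear(4, 10, bias=False)
out.weight = emb.weight  # tie the weights

# identity check on the parameter object vs. check on the underlying storage
same_object = out.weight is emb.weight
same_storage = out.weight.data_ptr() == emb.weight.data_ptr()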
torchtune/modules/common_utils.py
@@ -442,3 +448,131 @@ def delete_kv_caches(model: nn.Module):
        if hasattr(module, "kv_cache") and callable(module.kv_cache):
            module.cache_enabled = False
            module.kv_cache = None


def resize_token_embeddings(model: nn.Module, num_embeddings: int) -> None:
This currently wouldn't work with fusion models or models that use the FusionEmbedding layer. For a fusion model you want to check for the "decoder" and grab that instead of model.
To support resizing FusionEmbeddings you would want to resize the embedding parameter and num_embeddings. It would also be reasonable to throw an error and not support resizing FusionEmbeddings, as they're already a different technique for resizing embeddings, but I don't think it would be too much harder to support them.
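A hedged sketch of the dispatch being suggested (the "decoder" attribute name follows the comment above; the helper itself is an assumption, not the PR's code):

import torch.nn as nn


def get_resize_target(model: nn.Module) -> nn.Module:
    """For fusion models, grab the wrapped decoder; otherwise use the model itself."""
    return getattr(model, "decoder", model)


# Inside resize_token_embeddings one could then either handle FusionEmbedding
# explicitly or raise, e.g.:
# if isinstance(target.tok_embeddings, FusionEmbedding):
#     raise ValueError("Resizing FusionEmbedding is not supported yet")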
Ah yeah this is my fault. I was thinking that the usage of FusionEmbedding means that the model itself is a fusion model, but forgot that FusionEmbedding can be the embedding layer for a TransformerDecoder. In that case I agree, we should make a slight change here. Will leave a diff of what I think makes sense here.
I've added FusionEmbedding resizing to the follow-up tasks. I agree from looking at the module that we can extend the function to resize the fusion_embedding instead with ease.
This PR introduces a resize_token_embeddings method which allows users to add additional tokens to be trained as part of fine-tuning. It supports both single-device and multi-device setups. The function can be easily extended to support additional initialization strategies.
For the distributed case, we unshard() the token embedding and output layer before passing them into resize_token_embeddings(). This way we can use the same logic as the single-device case, since the parameters are gathered onto each rank. Once complete, we fully_shard() the new token embedding and output layers using the same mesh used for the initial sharding. This is the recommended way to resize FSDP2-sharded parameters. Thanks to @weifengpy for his guidance on the FSDP2 sharding mechanism.
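A hedged sketch of that distributed flow, pulled together from the snippets above (the exact fully_shard import path, the TiedLinear import, and the mesh handle are assumptions about the recipe setup):

from torch.distributed._composable.fsdp import fully_shard  # import path may vary by torch version
from torch.distributed.device_mesh import DeviceMesh
import torch.nn as nn

from torchtune.modules import TiedLinear  # assumed import path
from torchtune.modules.common_utils import resize_token_embeddings  # added in this PR


def resize_sharded_embeddings(model: nn.Module, vocab_size: int, mesh: DeviceMesh) -> None:
    """Unshard, resize with the single-device utility, then re-shard on the same mesh."""
    model.tok_embeddings.unshard()                 # gather FSDP2 shards onto each rank
    if not isinstance(model.output, TiedLinear):
        model.output.unshard()

    resize_token_embeddings(model, vocab_size)     # same code path as single device

    fully_shard(model.tok_embeddings, mesh=mesh)   # re-shard with the original mesh
    if not isinstance(model.output, TiedLinear):
        fully_shard(model.output, mesh=mesh)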
The usage is:
On single-device and distributed setups, the utility can be called directly on the model (see the sketch below).
When using a recipe (right now enabled for full_finetune_distributed and full_finetune_single_device), add resize_token_embeddings = True in the YAML config.
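A hedged usage sketch, based on the signature shown in the diff above (the import path is assumed, and the placeholders stand in for the recipe's model/tokenizer setup):

from torchtune.modules.common_utils import resize_token_embeddings  # path assumed from this PR

model = ...      # any TransformerDecoder-based model, e.g. from a torchtune builder
tokenizer = ...  # tokenizer whose vocabulary was extended with new special tokens

if tokenizer.vocab_size != model.tok_embeddings.num_embeddings:
    # Resizes tok_embeddings (and the untied output projection) in place;
    # new rows are initialized to the mean of the existing embeddings.
    resize_token_embeddings(model, tokenizer.vocab_size)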
Follow up tasks:
Context
What is the purpose of this PR?
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- install pre-commit hooks (pre-commit install)
- run unit tests (pytest tests)
- run recipe tests (pytest tests -m integration_test)
Unit Testing:
Unit tests cover both single-device and distributed cases; we check the new nn.Embedding size as well as the mean values of the new tokens to ensure they were initialized correctly. The E2E case covers a training run; we can see training completes successfully with the resized model sharded across 8 devices.
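A rough sketch of the kind of check described above (the stand-in model, the import path, and whether the utility accepts a plain nn.Module are assumptions; the real tests live in the PR):

import torch
import torch.nn as nn

from torchtune.modules.common_utils import resize_token_embeddings  # path assumed from this PR


class TinyModel(nn.Module):
    """Minimal stand-in exposing the attributes the utility operates on."""

    def __init__(self, vocab_size: int = 100, dim: int = 8):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)


def test_resize_token_embeddings_mean_init():
    torch.manual_seed(0)
    model = TinyModel()
    old_weight = model.tok_embeddings.weight.detach().clone()

    resize_token_embeddings(model, num_embeddings=110)

    assert model.tok_embeddings.num_embeddings == 110
    # New rows should be initialized to the mean of the original embeddings
    torch.testing.assert_close(
        model.tok_embeddings.weight[100:],
        old_weight.mean(dim=0).expand(10, -1),
    )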
Without resizing, we can see the special token init values on the Llama 3 8B base model are very small. Taking a look at token <|start_header_id|> = 128006:
(screenshot of embedding values before resizing)
With resizing, we can see the special tokens are initialized to the mean of the 128k existing tokens on the Llama 3 8B base:
(screenshot of embedding values after resizing)
E2E Test
Resizing the last 256 tokens in the Llama 3 tokenizer to the mean embedding of the previous 120k tokens, plus an additional 4 tokens so the entire layer is resized. We don't see a significant difference in loss (run for 2 epochs, 20 steps each), but we show that with the resized embedding in the distributed setup we're able to finish training with at least equivalent loss.
vs. baseline:
(screenshot of loss curves)
Running,
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.