add resize_token_embedding functionality #2542


Closed · wants to merge 0 commits

Conversation

@iamzainhuda (Contributor) commented Apr 1, 2025

This PR introduces a resize_token_embeddings method which allows users to add additional tokens to be trained as part of fine-tuning. It supports both single-device and multi-device setups. The function can easily be extended to support additional initialization strategies.

For the distributed case, we unshard() the token embedding and output layer before passing them into resize_token_embeddings(). This way we can reuse the same logic as the single-device case, since the parameters are gathered onto each rank. Once complete, we fully_shard() the new token embedding and output layers using the same mesh as the initial sharding. This is the recommended way to resize FSDP2-sharded parameters. Thanks to @weifengpy for his guidance on the FSDP2 sharding mechanism.
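To make that flow concrete, here is a minimal sketch of the distributed path, assuming a TransformerDecoder-style model whose tok_embeddings and output modules were each wrapped with fully_shard; the helper name, mesh argument, and import paths are illustrative rather than the exact code in this PR.

import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # older torch exposes this under torch.distributed._composable.fsdp
from torchtune.modules import resize_token_embeddings  # assumed import path for the utility added here

def resize_sharded_token_embeddings(model: nn.Module, num_embeddings: int, mesh) -> None:
    # Gather the full parameters onto each rank so the single-device logic applies.
    model.tok_embeddings.unshard()
    model.output.unshard()

    # Reuse the single-device resize (copy old rows, mean-init the new ones).
    resize_token_embeddings(model, num_embeddings)

    # Re-shard the resized layers with the same device mesh used for the initial sharding.
    fully_shard(model.tok_embeddings, mesh=mesh)
    fully_shard(model.output, mesh=mesh)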

Usage, on single-device and distributed setups:

model = setup_model(...)
tokenizer = setup_tokenizer(...)
resize_token_embeddings(model, tokenizer.vocab_size)

When using a recipe (currently enabled for full_finetune_distributed and full_finetune_single_device), add resize_token_embeddings: True to the YAML config. Example:

...
max_steps_per_epoch: null
gradient_accumulation_steps: 1  # Use to increase effective batch size
clip_grad_norm: null
compile: False  # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1
resize_token_embeddings: True
...

Follow up tasks:

  • additional embedding weight init methods
  • loading a checkpoint from a resized model requires the model to be created with the new vocab size
  • FusionEmbedding resizing as part of TransformerDecoder

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

  • Added a resize_token_embeddings() utility that allows users to flexibly change the size of the token embeddings and the corresponding output projection.

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

Unit Testing:

Unit tests cover both the single-device and distributed cases: we check the new nn.Embedding size as well as the mean values of the new token embeddings to ensure they were initialized correctly. The E2E case covers a full training run; training completes successfully with the resized model sharded across 8 devices.
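As a rough illustration of the single-device check (not the actual test code in this PR; the import path is assumed and the real utility may expect a full TransformerDecoder rather than a bare module):

import torch
import torch.nn as nn
from torchtune.modules import resize_token_embeddings  # assumed import path

def test_resize_token_embeddings_single_device():
    # Tiny stand-in model with an embedding and an untied output projection.
    vocab_size, dim, new_vocab_size = 8, 4, 12
    model = nn.Module()
    model.tok_embeddings = nn.Embedding(vocab_size, dim)
    model.output = nn.Linear(dim, vocab_size, bias=False)
    old_mean = model.tok_embeddings.weight.mean(dim=0).detach().clone()

    resize_token_embeddings(model, new_vocab_size)

    # The new nn.Embedding and output projection match the requested size.
    assert model.tok_embeddings.num_embeddings == new_vocab_size
    assert model.output.out_features == new_vocab_size

    # Newly added rows are initialized to the mean of the original embeddings.
    new_rows = model.tok_embeddings.weight[vocab_size:]
    torch.testing.assert_close(new_rows, old_mean.expand_as(new_rows))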

Without resizing, we can see the special token init values on the Llama 3 8B base model are very small. Taking a look at token <|start_header_id|> = 128006:
image

With resizing, we can see the special tokens are initialized to the mean of the 128k existing tokens on the Llama 3 8B base model:
image

E2E Test

Resized the last 256 tokens in the Llama 3 tokenizer to the mean embedding of the previous 120k tokens, and added an additional 4 tokens so the entire layer is resized. We don't see a significant difference in loss (run for 2 epochs, 20 steps each), but we show that with the resized embedding in the distributed setup we're able to finish training with at least equivalent loss.

image

vs. baseline
image

Running,

tune run --nproc_per_node 8 full_finetune_distributed --config recipes/configs/llama3/8B_full.yaml max_steps_per_epoch=20 epochs=2 checkpointer.checkpoint_dir=/tmp/Meta-Llama-3-8B

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Apr 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2542

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 4 Cancelled Jobs

As of commit 73546c1 with merge base 137eec3 (image):

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 1, 2025

codecov-commenter commented Apr 23, 2025

Codecov Report

Attention: Patch coverage is 7.14286% with 13 lines in your changes missing coverage. Please review.

Project coverage is 64.09%. Comparing base (f3e4747) to head (1a47859).
Report is 8 commits behind head on main.

Files with missing lines Patch % Lines
torchtune/modules/transformer.py 7.14% 13 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2542      +/-   ##
==========================================
- Coverage   65.78%   64.09%   -1.70%     
==========================================
  Files         396      396              
  Lines       23764    23796      +32     
==========================================
- Hits        15634    15252     -382     
- Misses       8130     8544     +414     

☔ View full report in Codecov by Sentry.

@iamzainhuda iamzainhuda changed the title [prototype][not for review] add resize_embedding method [not for review] add resize_embedding method May 5, 2025
@iamzainhuda iamzainhuda changed the title [not for review] add resize_embedding method [not for review] add resize_token_embedding functionality May 5, 2025
@iamzainhuda iamzainhuda changed the title [not for review] add resize_token_embedding functionality add resize_token_embedding functionality May 6, 2025
@ebsmothers (Contributor) left a comment

Thanks for adding this and for the thorough testing! Left a bunch of comments but this is looking pretty good. Lmk if anything I suggested is unclear

Comment on lines 281 to 285
utils.log_rank_zero(
log,
f"Tokenizer vocab size ({self._tokenizer.vocab_size}) is and "
f"embedding size ({self._model.tok_embeddings.num_embeddings}).",
)
Contributor:

I think we don't necessarily need to log this in general. Instead I would recommend logging inside of the resize_token_embeddings function only in the case that the tokenizer vocab size differs from the model's vocab size

Comment on lines 287 to 290
log.info(
f"Tokenizer vocab size ({self._tokenizer.vocab_size}) differs from model "
f"embedding size ({self._model.tok_embeddings.num_embeddings}). Resizing model embeddings..."
)
Contributor:

Similar here.. I would actually just log this directly in the utility (and same for the corresponding log in the distributed recipe)

Comment on lines 330 to 331
self._model.tok_embeddings.unshard()
if not isinstance(self._model.output, TiedLinear): self._model.output.unshard()
Contributor:

Crazy idea, but I wonder whether it would be helpful to define a context manager for things like this. E.g. if I have an operation I know how to apply to a vanilla tensor, but maybe the input is a DTensor, maybe it's not. Then given an optional device mesh, we either apply the op to a vanilla tensor, or for a DTensor, we unshard, apply the op, then reshard to match original. @weifengpy thoughts?

Contributor (Author):

If I understand correctly, the context manager would essentially handle the loop of d_tensor.full_tensor() --> user op --> redistribute_tensor(tensor, mesh)?

Contributor:

Yeah exactly. Doesn't have to be done in this PR anyways, just an idea for a future quality-of-life improvement
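For reference, one possible shape for that helper (purely illustrative; no such context manager exists in torchtune today):

from contextlib import contextmanager

import torch.nn as nn
from torch.distributed.fsdp import FSDPModule  # torch >= 2.6

@contextmanager
def unsharded(module: nn.Module):
    # If the module was wrapped with fully_shard, gather its full parameters,
    # let the caller apply an op written for vanilla tensors, then reshard.
    # For a plain module this is a no-op, so both cases share one code path.
    is_sharded = isinstance(module, FSDPModule)
    if is_sharded:
        module.unshard()
    try:
        yield module
    finally:
        if is_sharded:
            module.reshard()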

# Unshard the FSDP wrapped token embeddings and output projection
# tok_embeddings and output are wrapped by fully_shard, we can call .unshard()
# directly on those layers only
self._model.tok_embeddings.unshard()
Contributor:

It's possible that self._model is not an instance of TransformerDecoder. E.g. for multimodal models we provide both EarlyFusionModel and DeepFusionModel, so we should at least be able to handle both of these gracefully. One option is to just do all this directly in the resize_token_embeddings function, otherwise we start to bloat the recipe code with lots of if/else checks.

Contributor (Author):

Got it - I'll move all of this within the function and handle the various types of models that can be passed in

Comment on lines 454 to 456
Resizes the token embeddings and the final output projection layer of a TransformerDecoder model.

This function modifies the model in-place.
Contributor:

Would also explain (somewhere) in the docstring that you're initializing new positions using the mean of the old embeddings

Contributor (Author) @iamzainhuda, May 6, 2025:

Added - I think in a follow-up we can add an "init_strategy" optional arg to the function and abstract the resize op into its own helper as we add additional init strategies (+ user-defined ones)
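For readers skimming the thread, the mean-initialization step boils down to something like the following single-device sketch (illustrative only; the utility in this PR also resizes the output projection and handles tied weights):

import torch
import torch.nn as nn

@torch.no_grad()
def _resize_embedding_mean_init(old: nn.Embedding, num_embeddings: int) -> nn.Embedding:
    # Build a new embedding, copy over the existing rows, and initialize any
    # added rows to the mean of the old embedding weights.
    new = nn.Embedding(
        num_embeddings,
        old.embedding_dim,
        device=old.weight.device,
        dtype=old.weight.dtype,
    )
    num_to_copy = min(old.num_embeddings, num_embeddings)
    new.weight[:num_to_copy] = old.weight[:num_to_copy]
    if num_embeddings > old.num_embeddings:
        new.weight[old.num_embeddings:] = old.weight.mean(dim=0)
    return new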


output_layer = model.output
if isinstance(output_layer, TiedLinear):
if output_layer.tied_module is old_embeddings:
Contributor:

Dumb q: is this just checking equality of the underlying data_ptr?

Contributor (Author):

This is checking whether the object points to the same module in memory as tok_embeddings' nn.Embedding; the underlying data_ptr() isn't checked. A better check would be to use data_ptr() instead, to cover the case where output and tok_embeddings are separate nn.Embeddings.

However, if output_layer is tied, we should always replace the tied_module with the new tok_embeddings - I feel like this check is not needed?
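A tiny illustration of the distinction being discussed (object identity vs. shared storage), using throwaway modules:

import torch.nn as nn

emb = nn.Embedding(8, 4)
alias = emb  # a second name for the same module object

# `is` compares object identity: true only when both names refer to one module.
assert alias is emb

# data_ptr() compares the underlying storage, so it also matches when a
# separate module is constructed to share (tie) the same weight tensor.
output = nn.Linear(4, 8, bias=False)
output.weight = emb.weight  # weight tying
assert output.weight.data_ptr() == emb.weight.data_ptr()
assert output is not emb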

@@ -442,3 +448,131 @@ def delete_kv_caches(model: nn.Module):
if hasattr(module, "kv_cache") and callable(module.kv_cache):
module.cache_enabled = False
module.kv_cache = None


def resize_token_embeddings(model: nn.Module, num_embeddings: int) -> None:
Contributor:

This currently wouldn't work with fusion models or models that use the FusionEmbedding layer. For a fusion model you want to check for the "decoder" and grab that instead of model.

To support resizing FusionEmbeddings you would want to resize the embedding parameter and num_embeddings. It would also be reasonable to throw an error and not support resizing FusionEmbeddings, as they're already a different technique for resizing embeddings, but I don't think it would be too much harder to support them.

Contributor:

Ah yeah this is my fault. I was thinking that the usage of FusionEmbedding means that the model itself is a fusion model, but forgot that FusionEmbedding can be the embedding layer for a TransformerDecoder. In that case I agree, we should make a slight change here. Will leave a diff of what I think makes sense here

Contributor (Author):

I've added FusionEmbedding resizing to the follow-up tasks. I agree - from looking at the module, we can extend the function to resize the fusion_embedding with ease
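A sketch of how the fusion-model handling could look (the helper name is illustrative, and whether to raise or to resize the FusionEmbedding is exactly the follow-up decision discussed above):

import torch.nn as nn
from torchtune.modules import TransformerDecoder
from torchtune.modules.model_fusion import DeepFusionModel, EarlyFusionModel, FusionEmbedding

def _resolve_decoder(model: nn.Module) -> TransformerDecoder:
    # For fusion models, resize the wrapped decoder rather than the wrapper.
    if isinstance(model, (DeepFusionModel, EarlyFusionModel)):
        model = model.decoder
    # FusionEmbedding resizing is a follow-up task; fail loudly for now.
    if isinstance(model.tok_embeddings, FusionEmbedding):
        raise NotImplementedError(
            "resize_token_embeddings does not yet support FusionEmbedding layers"
        )
    return model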
