Commit b75f8d9
[Serialization] support loading torch state dict from disk (#2687)
* add first version of state dict loading helpers
* Rename function
* Update documentation
* Update documentation
* Fix typo
* change titles
* remove file
* fix docstrings
* fix test for torch<=2.1.0
* changes post-review
* fix importing
* fix static imports
* fix documentation
* add requires decorator to the test
* Add mmap parameter
* fix Windows path escaping issue in regex match
* pass device when loading safetensors
1 parent 51b866f · commit b75f8d9

5 files changed (+656 −17 lines)


docs/source/en/package_reference/serialization.md (+19 −3)
@@ -4,19 +4,22 @@ rendered properly in your Markdown viewer.

# Serialization

-`huggingface_hub` contains helpers to help ML libraries serialize models weights in a standardized way. This part of the lib is still under development and will be improved in future releases. The goal is to harmonize how weights are serialized on the Hub, both to remove code duplication across libraries and to foster conventions on the Hub.
+`huggingface_hub` provides helpers to save and load ML model weights in a standardized way. This part of the library is still under development and will be improved in future releases. The goal is to harmonize how weights are saved and loaded across the Hub, both to remove code duplication across libraries and to establish consistent conventions.

-## Save torch state dict
+## Saving

The main helper of the `serialization` module takes a torch `nn.Module` as input and saves it to disk. It handles the logic to save shared tensors (see [safetensors explanation](https://huggingface.co/docs/safetensors/torch_shared_tensors)) as well as logic to split the state dictionary into shards, using [`split_torch_state_dict_into_shards`] under the hood. At the moment, only `torch` framework is supported.

If you want to save a state dictionary (e.g. a mapping between layer names and related tensors) instead of a `nn.Module`, you can use [`save_torch_state_dict`] which provides the same features. This is useful for example if you want to apply custom logic to the state dict before saving it.

+### save_torch_model
+
[[autodoc]] huggingface_hub.save_torch_model

+### save_torch_state_dict
+
[[autodoc]] huggingface_hub.save_torch_state_dict

-## Split state dict into shards

The `serialization` module also contains low-level helpers to split a state dictionary into several shards, while creating a proper index in the process. These helpers are available for `torch` and `tensorflow` tensors and are designed to be easily extended to any other ML frameworks.
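For reference, a minimal sketch of how the saving helpers documented above are used (the model and output directory here are placeholders, and only the basic positional arguments are shown):

```python
import torch
from huggingface_hub import save_torch_model, save_torch_state_dict

# Placeholder model, used only for illustration.
model = torch.nn.Linear(4, 2)

# Save a full nn.Module: shared tensors are handled and large
# checkpoints are split into shards with an index file.
save_torch_model(model, "path/to/folder")

# Or save a plain state dict, e.g. after applying custom filtering
# or renaming logic to it beforehand.
state_dict = model.state_dict()
save_torch_state_dict(state_dict, "path/to/folder")
```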

@@ -34,6 +37,19 @@ This is the underlying factory from which each framework-specific helper is deri

[[autodoc]] huggingface_hub.split_state_dict_into_shards_factory

+## Loading
+
+The loading helpers support both single-file and sharded checkpoints in either safetensors or pickle format. [`load_torch_model`] takes a `nn.Module` and a checkpoint path (either a single file or a directory) as input and load the weights into the model.
+
+### load_torch_model
+
+[[autodoc]] huggingface_hub.load_torch_model
+
+### load_state_dict_from_file
+
+[[autodoc]] huggingface_hub.load_state_dict_from_file
+
## Helpers

### get_torch_storage_id
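For reference, a minimal sketch of how the new loading helpers described above are meant to be used (the model architecture and checkpoint paths are placeholders; optional arguments such as the `mmap` and device-related parameters mentioned in the commit message are omitted):

```python
import torch
from huggingface_hub import load_torch_model, load_state_dict_from_file

# Placeholder architecture; it must match the saved weights.
model = torch.nn.Linear(4, 2)

# Load weights into an existing nn.Module from a single checkpoint
# file or a sharded checkpoint directory.
load_torch_model(model, "path/to/folder")

# Or load a raw state dict from one checkpoint file (safetensors or
# pickle) and apply it manually.
state_dict = load_state_dict_from_file("path/to/folder/model.safetensors")
model.load_state_dict(state_dict)
```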

src/huggingface_hub/__init__.py (+4)
@@ -461,6 +461,8 @@
"get_tf_storage_size",
"get_torch_storage_id",
"get_torch_storage_size",
+"load_state_dict_from_file",
+"load_torch_model",
"save_torch_model",
"save_torch_state_dict",
"split_state_dict_into_shards_factory",

@@ -987,6 +989,8 @@ def __dir__():
get_tf_storage_size, # noqa: F401
get_torch_storage_id, # noqa: F401
get_torch_storage_size, # noqa: F401
+load_state_dict_from_file, # noqa: F401
+load_torch_model, # noqa: F401
save_torch_model, # noqa: F401
save_torch_state_dict, # noqa: F401
split_state_dict_into_shards_factory, # noqa: F401

src/huggingface_hub/serialization/__init__.py (+2)
@@ -19,6 +19,8 @@
from ._torch import (
get_torch_storage_id,
get_torch_storage_size,
+load_state_dict_from_file,
+load_torch_model,
save_torch_model,
save_torch_state_dict,
split_torch_state_dict_into_shards,
