Skip to content

Commit ca3f674

Browse files
remove context manager when loading shards and handle mlx weights (#2709)
1 parent 4b0b179 commit ca3f674

File tree

1 file changed

+11
-33
lines changed

1 file changed

+11
-33
lines changed

src/huggingface_hub/serialization/_torch.py

Lines changed: 11 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@
1818
import os
1919
import re
2020
from collections import defaultdict, namedtuple
21-
from contextlib import contextmanager
2221
from functools import lru_cache
2322
from pathlib import Path
24-
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
23+
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
2524

2625
from packaging import version
2726

@@ -538,13 +537,15 @@ def _load_sharded_checkpoint(
538537
for shard_file in shard_files:
539538
# Load shard into memory
540539
shard_path = os.path.join(save_directory, shard_file)
541-
with _load_shard_into_memory(
540+
state_dict = load_state_dict_from_file(
542541
shard_path,
543-
load_fn=load_state_dict_from_file,
544-
kwargs={"weights_only": weights_only},
545-
) as state_dict:
546-
# Update model with parameters from this shard
547-
model.load_state_dict(state_dict, strict=strict)
542+
map_location="cpu",
543+
weights_only=weights_only,
544+
)
545+
# Update model with parameters from this shard
546+
model.load_state_dict(state_dict, strict=strict)
547+
# Explicitly remove the state dict from memory
548+
del state_dict
548549

549550
# 4. Return compatibility info
550551
loaded_keys = set(index["weight_map"].keys())
@@ -630,7 +631,8 @@ def load_state_dict_from_file(
630631
# Check format of the archive
631632
with safe_open(checkpoint_file, framework="pt") as f: # type: ignore[attr-defined]
632633
metadata = f.metadata()
633-
if metadata.get("format") != "pt":
634+
# see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966
635+
if metadata is not None and metadata.get("format") not in ["pt", "mlx"]:
634636
raise OSError(
635637
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
636638
"you save your model with the `save_torch_model` method."
@@ -668,30 +670,6 @@ def load_state_dict_from_file(
668670
# HELPERS
669671

670672

671-
@contextmanager
672-
def _load_shard_into_memory(
673-
shard_path: str,
674-
load_fn: Callable,
675-
kwargs: Optional[Dict[str, Any]] = None,
676-
):
677-
"""
678-
Context manager to handle loading and cleanup of model shards.
679-
680-
Args:
681-
shard_path: Path to the shard file
682-
load_fn: Function to load the shard (either torch.load or safetensors.load)
683-
684-
Yields:
685-
The loaded state dict for this shard
686-
"""
687-
try:
688-
state_dict = load_fn(shard_path, **kwargs) # type: ignore[arg-type]
689-
yield state_dict
690-
finally:
691-
# Explicitly remove the state dict from memory
692-
del state_dict
693-
694-
695673
def _validate_keys_for_strict_loading(
696674
model: "torch.nn.Module",
697675
loaded_keys: Iterable[str],

0 commit comments

Comments (0)