|
18 | 18 | import os
|
19 | 19 | import re
|
20 | 20 | from collections import defaultdict, namedtuple
|
21 |
| -from contextlib import contextmanager |
22 | 21 | from functools import lru_cache
|
23 | 22 | from pathlib import Path
|
24 |
| -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union |
| 23 | +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union |
25 | 24 |
|
26 | 25 | from packaging import version
|
27 | 26 |
|
@@ -538,13 +537,15 @@ def _load_sharded_checkpoint(
|
538 | 537 | for shard_file in shard_files:
|
539 | 538 | # Load shard into memory
|
540 | 539 | shard_path = os.path.join(save_directory, shard_file)
|
541 |
| - with _load_shard_into_memory( |
| 540 | + state_dict = load_state_dict_from_file( |
542 | 541 | shard_path,
|
543 |
| - load_fn=load_state_dict_from_file, |
544 |
| - kwargs={"weights_only": weights_only}, |
545 |
| - ) as state_dict: |
546 |
| - # Update model with parameters from this shard |
547 |
| - model.load_state_dict(state_dict, strict=strict) |
| 542 | + map_location="cpu", |
| 543 | + weights_only=weights_only, |
| 544 | + ) |
| 545 | + # Update model with parameters from this shard |
| 546 | + model.load_state_dict(state_dict, strict=strict) |
| 547 | + # Explicitly remove the state dict from memory |
| 548 | + del state_dict |
548 | 549 |
|
549 | 550 | # 4. Return compatibility info
|
550 | 551 | loaded_keys = set(index["weight_map"].keys())
|
@@ -630,7 +631,8 @@ def load_state_dict_from_file(
|
630 | 631 | # Check format of the archive
|
631 | 632 | with safe_open(checkpoint_file, framework="pt") as f: # type: ignore[attr-defined]
|
632 | 633 | metadata = f.metadata()
|
633 |
| - if metadata.get("format") != "pt": |
| 634 | + # see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966 |
| 635 | + if metadata is not None and metadata.get("format") not in ["pt", "mlx"]: |
634 | 636 | raise OSError(
|
635 | 637 | f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
|
636 | 638 | "you save your model with the `save_torch_model` method."
|
@@ -668,30 +670,6 @@ def load_state_dict_from_file(
|
668 | 670 | # HELPERS
|
669 | 671 |
|
670 | 672 |
|
671 |
| -@contextmanager |
672 |
| -def _load_shard_into_memory( |
673 |
| - shard_path: str, |
674 |
| - load_fn: Callable, |
675 |
| - kwargs: Optional[Dict[str, Any]] = None, |
676 |
| -): |
677 |
| - """ |
678 |
| - Context manager to handle loading and cleanup of model shards. |
679 |
| -
|
680 |
| - Args: |
681 |
| - shard_path: Path to the shard file |
682 |
| - load_fn: Function to load the shard (either torch.load or safetensors.load) |
683 |
| -
|
684 |
| - Yields: |
685 |
| - The loaded state dict for this shard |
686 |
| - """ |
687 |
| - try: |
688 |
| - state_dict = load_fn(shard_path, **kwargs) # type: ignore[arg-type] |
689 |
| - yield state_dict |
690 |
| - finally: |
691 |
| - # Explicitly remove the state dict from memory |
692 |
| - del state_dict |
693 |
| - |
694 |
| - |
695 | 673 | def _validate_keys_for_strict_loading(
|
696 | 674 | model: "torch.nn.Module",
|
697 | 675 | loaded_keys: Iterable[str],
|
|
0 commit comments