Skip to content

Commit 766bbdc

Browse files
authored
Download checkpoints from HuggingFace (#9538)
### Summary If no checkpoint is specified during `export_llama`, download the checkpoint from HuggingFace if it is an OSS model. Closes #8872 ### Test plan Manual export
1 parent c890809 commit 766bbdc

File tree

9 files changed

+143
-54
lines changed

9 files changed

+143
-54
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
from executorch.devtools.backend_debug import print_delegation_info
2929

3030
from executorch.devtools.etrecord import generate_etrecord
31+
from executorch.examples.models.llama.hf_download import (
32+
download_and_convert_hf_checkpoint,
33+
)
3134
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
3235

3336
from executorch.extension.llm.export.builder import DType, LLMEdgeManager
@@ -99,6 +102,11 @@
99102
"smollm2",
100103
]
101104
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
105+
HUGGING_FACE_REPO_IDS = {
106+
"qwen2_5": "Qwen/Qwen2.5-1.5B",
107+
"phi_4_mini": "microsoft/Phi-4-mini-instruct",
108+
"smollm2": "HuggingFaceTB/SmolLM-135M",
109+
}
102110

103111

104112
class WeightType(Enum):
@@ -526,6 +534,22 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
526534

527535

528536
def export_llama(args) -> str:
537+
# If a checkpoint isn't provided for an HF OSS model, download and convert the
538+
# weights first.
539+
if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS:
540+
repo_id = HUGGING_FACE_REPO_IDS[args.model]
541+
if args.model == "qwen2_5":
542+
from executorch.examples.models.qwen2_5 import convert_weights
543+
elif args.model == "phi_4_mini":
544+
from executorch.examples.models.phi_4_mini import convert_weights
545+
elif args.model == "smollm2":
546+
from executorch.examples.models.smollm2 import convert_weights
547+
else:
548+
raise ValueError(
549+
f"Converting weights to meta format for {args.model} is not yet supported"
550+
)
551+
args.checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights)
552+
529553
if args.profile_path is not None:
530554
try:
531555
from executorch.util.python_profiler import CProfilerFlameGraph

examples/models/llama/hf_download.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
from pathlib import Path
9+
from typing import Callable
10+
11+
from huggingface_hub import snapshot_download
12+
13+
14+
def download_and_convert_hf_checkpoint(
15+
repo_id: str, convert_weights: Callable[[str, str], None]
16+
) -> str:
17+
"""
18+
Downloads and converts to Meta format a HuggingFace checkpoint.
19+
20+
Args:
21+
repo_id: Id of the HuggingFace repo, e.g. "Qwen/Qwen2.5-1.5B".
22+
convert_weights: Weight conversion function taking in path to the downloaded HuggingFace
23+
files and the desired output path.
24+
25+
Returns:
26+
The output path of the Meta checkpoint converted from HuggingFace.
27+
"""
28+
29+
# Build cache path.
30+
cache_subdir = "meta_checkpoints"
31+
cache_dir = Path.home() / ".cache" / cache_subdir
32+
cache_dir.mkdir(parents=True, exist_ok=True)
33+
34+
# Use repo name to name the converted file.
35+
model_name = repo_id.replace("/", "_")
36+
converted_path = cache_dir / f"{model_name}.pth"
37+
38+
if converted_path.exists():
39+
print(f"✔ Using cached converted model: {converted_path}")
40+
return converted_path
41+
42+
# 1. Download weights from Hugging Face.
43+
print("⬇ Downloading and converting checkpoint...")
44+
checkpoint_path = snapshot_download(
45+
repo_id=repo_id,
46+
)
47+
48+
# 2. Convert weights to Meta format.
49+
convert_weights(checkpoint_path, converted_path)
50+
return converted_path

examples/models/llama/install_requirements.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# Install tokenizers for hf .json tokenizer.
1111
# Install snakeviz for cProfile flamegraph
1212
# Install lm-eval for Model Evaluation with lm-evalution-harness.
13-
pip install tiktoken sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
13+
pip install huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
1414

1515
# Call the install helper for further setup
1616
python examples/models/llama/install_requirement_helper.py

examples/models/phi_4_mini/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# LICENSE file in the root directory of this source tree.
33

44
from executorch.examples.models.llama.model import Llama2Model
5+
from executorch.examples.models.phi_4_mini.convert_weights import convert_weights
56

67

78
class Phi4MiniModel(Llama2Model):
@@ -11,4 +12,5 @@ def __init__(self, **kwargs):
1112

1213
__all__ = [
1314
"Phi4MiniModel",
15+
"convert_weights",
1416
]

examples/models/phi_4_mini/convert_weights.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,21 +51,10 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
5151
return converted_state_dict
5252

5353

54-
def main():
55-
parser = argparse.ArgumentParser(
56-
description="Convert Phi-4-mini weights to Meta format."
57-
)
58-
parser.add_argument(
59-
"input_dir",
60-
type=str,
61-
help="Path to directory containing checkpoint files",
62-
)
63-
parser.add_argument("output", type=str, help="Path to the output checkpoint")
64-
65-
args = parser.parse_args()
66-
54+
def convert_weights(input_dir: str, output_file: str) -> None:
55+
# Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
6756
checkpointer = FullModelHFCheckpointer(
68-
checkpoint_dir=args.input_dir,
57+
checkpoint_dir=input_dir,
6958
checkpoint_files=[
7059
"model-00001-of-00002.safetensors",
7160
"model-00002-of-00002.safetensors",
@@ -76,12 +65,26 @@ def main():
7665

7766
print("Loading checkpoint...")
7867
sd = checkpointer.load_checkpoint()
79-
8068
print("Converting checkpoint...")
8169
sd = phi_4_tune_to_meta(sd["model"])
70+
print("Saving checkpoint...")
71+
torch.save(sd, output_file)
72+
print("Done.")
8273

83-
torch.save(sd, args.output)
84-
print(f"Checkpoint saved to {args.output}")
74+
75+
def main():
76+
parser = argparse.ArgumentParser(
77+
description="Convert Phi-4-mini weights to Meta format."
78+
)
79+
parser.add_argument(
80+
"input_dir",
81+
type=str,
82+
help="Path to directory containing checkpoint files",
83+
)
84+
parser.add_argument("output", type=str, help="Path to the output checkpoint")
85+
86+
args = parser.parse_args()
87+
convert_weights(args.input_dir, args.output)
8588

8689

8790
if __name__ == "__main__":

examples/models/qwen2_5/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# LICENSE file in the root directory of this source tree.
33

44
from executorch.examples.models.llama.model import Llama2Model
5+
from executorch.examples.models.qwen2_5.convert_weights import convert_weights
56

67

78
class Qwen2_5Model(Llama2Model):
@@ -11,4 +12,5 @@ def __init__(self, **kwargs):
1112

1213
__all__ = [
1314
"Qwen2_5Model",
15+
"convert_weights",
1416
]

examples/models/qwen2_5/convert_weights.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,35 +53,37 @@ def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.
5353
return converted_state_dict
5454

5555

56-
def main():
57-
parser = argparse.ArgumentParser(
58-
description="Convert Qwen2 weights to Meta format."
59-
)
60-
parser.add_argument(
61-
"input_dir",
62-
type=str,
63-
help="Path to directory containing checkpoint files",
64-
)
65-
parser.add_argument("output", type=str, help="Path to the output checkpoint")
66-
67-
args = parser.parse_args()
68-
56+
def convert_weights(input_dir: str, output_file: str) -> None:
6957
# Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
7058
checkpointer = FullModelHFCheckpointer(
71-
checkpoint_dir=args.input_dir,
59+
checkpoint_dir=input_dir,
7260
checkpoint_files=["model.safetensors"],
7361
output_dir=".",
7462
model_type="QWEN2",
7563
)
7664

7765
print("Loading checkpoint...")
7866
sd = checkpointer.load_checkpoint()
79-
8067
print("Converting checkpoint...")
8168
sd = qwen_2_tune_to_meta(sd["model"])
69+
print("Saving checkpoint...")
70+
torch.save(sd, output_file)
71+
print("Done.")
8272

83-
torch.save(sd, args.output)
84-
print(f"Checkpoint saved to {args.output}")
73+
74+
def main():
75+
parser = argparse.ArgumentParser(
76+
description="Convert Qwen2 weights to Meta format."
77+
)
78+
parser.add_argument(
79+
"input_dir",
80+
type=str,
81+
help="Path to directory containing checkpoint files",
82+
)
83+
parser.add_argument("output", type=str, help="Path to the output checkpoint")
84+
85+
args = parser.parse_args()
86+
convert_weights(args.input_dir, args.output)
8587

8688

8789
if __name__ == "__main__":

examples/models/smollm2/__init__ renamed to examples/models/smollm2/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# This source code is licensed under the BSD-style license found in the
22
# LICENSE file in the root directory of this source tree.
33

4-
from executorch.example.models.llama.model import Llama2Model
4+
from executorch.examples.models.llama.model import Llama2Model
5+
from executorch.examples.models.smollm2.convert_weights import convert_weights
56

67

78
class SmolLM2Model(Llama2Model):
@@ -11,4 +12,5 @@ def __init__(self, **kwargs):
1112

1213
__all__ = [
1314
"SmolLM2Model",
15+
"convert_weights",
1416
]

examples/models/smollm2/convert_weights.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
_SMOLLM_FROM_META = {
1212
"tok_embeddings.weight": "tok_embeddings.weight",
1313
"norm.weight": "norm.scale",
14-
"output.weight": "output.weight",
1514
"layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
1615
"layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
1716
"layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
@@ -41,10 +40,31 @@ def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.
4140
for key, value in state_dict.items():
4241
new_key = get_mapped_key(key, inverted_mapping_dict)
4342
converted_state_dict[new_key] = value
43+
converted_state_dict["output.weight"] = converted_state_dict[
44+
"tok_embeddings.weight"
45+
]
4446

4547
return converted_state_dict
4648

4749

50+
def convert_weights(input_dir: str, output_file: str) -> None:
51+
# Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
52+
checkpointer = FullModelHFCheckpointer(
53+
checkpoint_dir=input_dir,
54+
checkpoint_files=["model.safetensors"],
55+
output_dir=".",
56+
model_type="LLAMA3",
57+
)
58+
59+
print("Loading checkpoint...")
60+
sd = checkpointer.load_checkpoint()
61+
print("Converting checkpoint...")
62+
sd = smollm_tune_to_meta(sd["model"])
63+
print("Saving checkpoint...")
64+
torch.save(sd, output_file)
65+
print("Done.")
66+
67+
4868
def main():
4969
parser = argparse.ArgumentParser(
5070
description="Convert SmolLM weights to Meta format."
@@ -57,23 +77,7 @@ def main():
5777
parser.add_argument("output", type=str, help="Path to the output checkpoint")
5878

5979
args = parser.parse_args()
60-
61-
# Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
62-
checkpointer = FullModelHFCheckpointer(
63-
checkpoint_dir=args.input_dir,
64-
checkpoint_files=["model.safetensors"],
65-
output_dir=".",
66-
model_type="LLAMA",
67-
)
68-
69-
print("Loading checkpoint...")
70-
sd = checkpointer.load_checkpoint()
71-
72-
print("Converting checkpoint...")
73-
sd = smollm_tune_to_meta(sd["model"])
74-
75-
torch.save(sd, args.output)
76-
print(f"Checkpoint saved to {args.output}")
80+
convert_weights(args.input_dir, args.output)
7781

7882

7983
if __name__ == "__main__":

0 commit comments

Comments
 (0)