Skip to content

Download checkpoints from HuggingFace #9538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from executorch.devtools.backend_debug import print_delegation_info

from executorch.devtools.etrecord import generate_etrecord
from executorch.examples.models.llama.hf_download import (
download_and_convert_hf_checkpoint,
)
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass

from executorch.extension.llm.export.builder import DType, LLMEdgeManager
Expand Down Expand Up @@ -99,6 +102,11 @@
"smollm2",
]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
HUGGING_FACE_REPO_IDS = {
"qwen2_5": "Qwen/Qwen2.5-1.5B",
"phi_4_mini": "microsoft/Phi-4-mini-instruct",
"smollm2": "HuggingFaceTB/SmolLM-135M",
}


class WeightType(Enum):
Expand Down Expand Up @@ -526,6 +534,22 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:


def export_llama(args) -> str:
# If a checkpoint isn't provided for an HF OSS model, download and convert the
# weights first.
if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS:
repo_id = HUGGING_FACE_REPO_IDS[args.model]
if args.model == "qwen2_5":
from executorch.examples.models.qwen2_5 import convert_weights
elif args.model == "phi_4_mini":
from executorch.examples.models.phi_4_mini import convert_weights
elif args.model == "smollm2":
from executorch.examples.models.smollm2 import convert_weights
else:
raise ValueError(
f"Converting weights to meta format for {args.model} is not yet supported"
)
args.checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights)

if args.profile_path is not None:
try:
from executorch.util.python_profiler import CProfilerFlameGraph
Expand Down
50 changes: 50 additions & 0 deletions examples/models/llama/hf_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path
from typing import Callable

from huggingface_hub import snapshot_download


def download_and_convert_hf_checkpoint(
    repo_id: str, convert_weights: Callable[[str, str], None]
) -> str:
    """
    Download a HuggingFace checkpoint and convert it to Meta format.

    The converted checkpoint is stored under ``~/.cache/meta_checkpoints`` and is
    reused, without downloading or converting again, whenever the file already
    exists.

    Args:
        repo_id: Id of the HuggingFace repo, e.g. "Qwen/Qwen2.5-1.5B".
        convert_weights: Weight conversion function taking in the path to the
            downloaded HuggingFace files and the desired output path.

    Returns:
        The output path (as a string) of the Meta checkpoint converted from
        HuggingFace.
    """

    # Build cache path.
    cache_subdir = "meta_checkpoints"
    cache_dir = Path.home() / ".cache" / cache_subdir
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Use repo name to name the converted file.
    model_name = repo_id.replace("/", "_")
    converted_file = cache_dir / f"{model_name}.pth"
    # Both the declared return type and the `convert_weights` callback contract
    # are `str`, so hand out a plain string rather than the Path object.
    converted_path = str(converted_file)

    if converted_file.exists():
        print(f"✔ Using cached converted model: {converted_path}")
        return converted_path

    # 1. Download weights from Hugging Face.
    print("⬇ Downloading and converting checkpoint...")
    checkpoint_path = snapshot_download(
        repo_id=repo_id,
    )

    # 2. Convert weights to Meta format.
    convert_weights(checkpoint_path, converted_path)
    return converted_path
2 changes: 1 addition & 1 deletion examples/models/llama/install_requirements.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# Install tokenizers for hf .json tokenizer.
# Install snakeviz for cProfile flamegraph
# Install lm-eval for Model Evaluation with lm-evaluation-harness.
pip install tiktoken sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
pip install huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile

# Call the install helper for further setup
python examples/models/llama/install_requirement_helper.py
2 changes: 2 additions & 0 deletions examples/models/phi_4_mini/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.llama.model import Llama2Model
from executorch.examples.models.phi_4_mini.convert_weights import convert_weights


class Phi4MiniModel(Llama2Model):
Expand All @@ -11,4 +12,5 @@ def __init__(self, **kwargs):

__all__ = [
"Phi4MiniModel",
"convert_weights",
]
37 changes: 20 additions & 17 deletions examples/models/phi_4_mini/convert_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,10 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
return converted_state_dict


def main():
parser = argparse.ArgumentParser(
description="Convert Phi-4-mini weights to Meta format."
)
parser.add_argument(
"input_dir",
type=str,
help="Path to directory containing checkpoint files",
)
parser.add_argument("output", type=str, help="Path to the output checkpoint")

args = parser.parse_args()

def convert_weights(input_dir: str, output_file: str) -> None:
# Using the TorchTune checkpointer is not strictly necessary; we could aggregate the checkpoint files ourselves.
checkpointer = FullModelHFCheckpointer(
checkpoint_dir=args.input_dir,
checkpoint_dir=input_dir,
checkpoint_files=[
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
Expand All @@ -76,12 +65,26 @@ def main():

print("Loading checkpoint...")
sd = checkpointer.load_checkpoint()

print("Converting checkpoint...")
sd = phi_4_tune_to_meta(sd["model"])
print("Saving checkpoint...")
torch.save(sd, output_file)
print("Done.")

torch.save(sd, args.output)
print(f"Checkpoint saved to {args.output}")

def main():
    """Parse the command-line arguments and run the Phi-4-mini weight conversion."""
    cli = argparse.ArgumentParser(
        description="Convert Phi-4-mini weights to Meta format."
    )
    # Both arguments are required string positionals, declared in this order.
    positionals = (
        ("input_dir", "Path to directory containing checkpoint files"),
        ("output", "Path to the output checkpoint"),
    )
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, type=str, help=help_text)

    parsed = cli.parse_args()
    convert_weights(parsed.input_dir, parsed.output)


if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions examples/models/qwen2_5/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.llama.model import Llama2Model
from executorch.examples.models.qwen2_5.convert_weights import convert_weights


class Qwen2_5Model(Llama2Model):
Expand All @@ -11,4 +12,5 @@ def __init__(self, **kwargs):

__all__ = [
"Qwen2_5Model",
"convert_weights",
]
36 changes: 19 additions & 17 deletions examples/models/qwen2_5/convert_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,35 +53,37 @@ def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.
return converted_state_dict


def main():
parser = argparse.ArgumentParser(
description="Convert Qwen2 weights to Meta format."
)
parser.add_argument(
"input_dir",
type=str,
help="Path to directory containing checkpoint files",
)
parser.add_argument("output", type=str, help="Path to the output checkpoint")

args = parser.parse_args()

def convert_weights(input_dir: str, output_file: str) -> None:
# Using the TorchTune checkpointer is not strictly necessary; we could aggregate the checkpoint files ourselves.
checkpointer = FullModelHFCheckpointer(
checkpoint_dir=args.input_dir,
checkpoint_dir=input_dir,
checkpoint_files=["model.safetensors"],
output_dir=".",
model_type="QWEN2",
)

print("Loading checkpoint...")
sd = checkpointer.load_checkpoint()

print("Converting checkpoint...")
sd = qwen_2_tune_to_meta(sd["model"])
print("Saving checkpoint...")
torch.save(sd, output_file)
print("Done.")

torch.save(sd, args.output)
print(f"Checkpoint saved to {args.output}")

def main():
    """Parse the command-line arguments and run the Qwen2 weight conversion."""
    cli = argparse.ArgumentParser(description="Convert Qwen2 weights to Meta format.")
    # Both arguments are required string positionals, declared in this order.
    positionals = (
        ("input_dir", "Path to directory containing checkpoint files"),
        ("output", "Path to the output checkpoint"),
    )
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, type=str, help=help_text)

    parsed = cli.parse_args()
    convert_weights(parsed.input_dir, parsed.output)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.example.models.llama.model import Llama2Model
from executorch.examples.models.llama.model import Llama2Model
from executorch.examples.models.smollm2.convert_weights import convert_weights


class SmolLM2Model(Llama2Model):
Expand All @@ -11,4 +12,5 @@ def __init__(self, **kwargs):

__all__ = [
"SmolLM2Model",
"convert_weights",
]
40 changes: 22 additions & 18 deletions examples/models/smollm2/convert_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
_SMOLLM_FROM_META = {
"tok_embeddings.weight": "tok_embeddings.weight",
"norm.weight": "norm.scale",
"output.weight": "output.weight",
"layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
"layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
"layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
Expand Down Expand Up @@ -41,10 +40,31 @@ def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.
for key, value in state_dict.items():
new_key = get_mapped_key(key, inverted_mapping_dict)
converted_state_dict[new_key] = value
converted_state_dict["output.weight"] = converted_state_dict[
"tok_embeddings.weight"
]

return converted_state_dict


def convert_weights(input_dir: str, output_file: str) -> None:
    """
    Load a SmolLM HuggingFace checkpoint, remap it to Meta format and save it.

    Args:
        input_dir: Directory containing the checkpoint files.
        output_file: Path the converted state dict is saved to.
    """
    # The TorchTune checkpointer is a convenience rather than a requirement;
    # the checkpoint files could also be aggregated by hand.
    checkpointer_kwargs = {
        "checkpoint_dir": input_dir,
        "checkpoint_files": ["model.safetensors"],
        "output_dir": ".",
        "model_type": "LLAMA3",
    }
    hf_checkpointer = FullModelHFCheckpointer(**checkpointer_kwargs)

    print("Loading checkpoint...")
    loaded = hf_checkpointer.load_checkpoint()
    print("Converting checkpoint...")
    meta_state_dict = smollm_tune_to_meta(loaded["model"])
    print("Saving checkpoint...")
    torch.save(meta_state_dict, output_file)
    print("Done.")


def main():
parser = argparse.ArgumentParser(
description="Convert SmolLM weights to Meta format."
Expand All @@ -57,23 +77,7 @@ def main():
parser.add_argument("output", type=str, help="Path to the output checkpoint")

args = parser.parse_args()

# Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
checkpointer = FullModelHFCheckpointer(
checkpoint_dir=args.input_dir,
checkpoint_files=["model.safetensors"],
output_dir=".",
model_type="LLAMA",
)

print("Loading checkpoint...")
sd = checkpointer.load_checkpoint()

print("Converting checkpoint...")
sd = smollm_tune_to_meta(sd["model"])

torch.save(sd, args.output)
print(f"Checkpoint saved to {args.output}")
convert_weights(args.input_dir, args.output)


if __name__ == "__main__":
Expand Down
Loading