
Commit 4db8f07

Support smollm2
1 parent 7990574 commit 4db8f07

File tree

3 files changed: +28 -19 lines changed

examples/models/llama/export_llama_lib.py
examples/models/smollm2/__init__ → examples/models/smollm2/__init__.py (renamed)
examples/models/smollm2/convert_weights.py

examples/models/llama/export_llama_lib.py

Lines changed: 3 additions & 0 deletions
@@ -105,6 +105,7 @@
 HUGGING_FACE_REPO_IDS = {
     "qwen2_5": "Qwen/Qwen2.5-1.5B",
     "phi_4_mini": "microsoft/Phi-4-mini-instruct",
+    "smollm2": "HuggingFaceTB/SmolLM-135M",
 }


@@ -541,6 +542,8 @@ def export_llama(args) -> str:
         from executorch.examples.models.qwen2_5 import convert_weights
     elif args.model == "phi_4_mini":
         from executorch.examples.models.phi_4_mini import convert_weights
+    elif args.model == "smollm2":
+        from executorch.examples.models.smollm2 import convert_weights
     else:
         raise ValueError(
             f"Converting weights to meta format for {args.model} is not yet supported"

examples/models/smollm2/__init__ renamed to examples/models/smollm2/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,7 +1,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from executorch.example.models.llama.model import Llama2Model
+from executorch.examples.models.llama.model import Llama2Model
+from executorch.examples.models.smollm2.convert_weights import convert_weights


 class SmolLM2Model(Llama2Model):

@@ -11,4 +12,5 @@ def __init__(self, **kwargs):

 __all__ = [
     "SmolLM2Model",
+    "convert_weights",
 ]
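The rename from __init__ to __init__.py is what makes the directory an importable package; with the added re-export, callers get both the model class and the converter from one place:

    from executorch.examples.models.smollm2 import SmolLM2Model, convert_weights

This is exactly the import the new smollm2 branch in export_llama_lib.py relies on.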

examples/models/smollm2/convert_weights.py

Lines changed: 22 additions & 18 deletions
@@ -11,7 +11,6 @@
 _SMOLLM_FROM_META = {
     "tok_embeddings.weight": "tok_embeddings.weight",
     "norm.weight": "norm.scale",
-    "output.weight": "output.weight",
     "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
     "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
     "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",

@@ -41,10 +40,31 @@ def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     for key, value in state_dict.items():
         new_key = get_mapped_key(key, inverted_mapping_dict)
         converted_state_dict[new_key] = value
+    converted_state_dict["output.weight"] = converted_state_dict[
+        "tok_embeddings.weight"
+    ]

     return converted_state_dict


+def convert_weights(input_dir: str, output_file: str) -> None:
+    # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
+    checkpointer = FullModelHFCheckpointer(
+        checkpoint_dir=input_dir,
+        checkpoint_files=["model.safetensors"],
+        output_dir=".",
+        model_type="LLAMA3",
+    )
+
+    print("Loading checkpoint...")
+    sd = checkpointer.load_checkpoint()
+    print("Converting checkpoint...")
+    sd = smollm_tune_to_meta(sd["model"])
+    print("Saving checkpoint...")
+    torch.save(sd, output_file)
+    print("Done.")
+
+
 def main():
     parser = argparse.ArgumentParser(
         description="Convert SmolLM weights to Meta format."

@@ -57,23 +77,7 @@ def main():
     parser.add_argument("output", type=str, help="Path to the output checkpoint")

     args = parser.parse_args()
-
-    # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
-    checkpointer = FullModelHFCheckpointer(
-        checkpoint_dir=args.input_dir,
-        checkpoint_files=["model.safetensors"],
-        output_dir=".",
-        model_type="LLAMA",
-    )
-
-    print("Loading checkpoint...")
-    sd = checkpointer.load_checkpoint()
-
-    print("Converting checkpoint...")
-    sd = smollm_tune_to_meta(sd["model"])
-
-    torch.save(sd, args.output)
-    print(f"Checkpoint saved to {args.output}")
+    convert_weights(args.input_dir, args.output)


 if __name__ == "__main__":
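Two details of the rewritten converter are worth noting. First, output.weight is removed from the _SMOLLM_FROM_META mapping and instead copied from tok_embeddings.weight after the key mapping runs: SmolLM ties its LM head to the token embedding, so the Hugging Face checkpoint stores no separate output-projection tensor and the Meta-format checkpoint must reuse the embedding. A toy illustration of that tying step (shapes made up for the example):

    import torch

    # Stand-in for a converted state dict; a tied checkpoint only stores the embedding.
    converted = {"tok_embeddings.weight": torch.randn(8, 4)}

    # Reuse the embedding matrix as the output projection, as smollm_tune_to_meta does.
    converted["output.weight"] = converted["tok_embeddings.weight"]

    assert converted["output.weight"] is converted["tok_embeddings.weight"]

Second, hoisting the conversion logic out of main() into convert_weights(input_dir, output_file) lets export_llama_lib import and call it directly, while the script stays usable standalone as python convert_weights.py <input_dir> <output>. The checkpointer's model_type also moves from "LLAMA" to "LLAMA3", matching torchtune's ModelType naming.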
