 _SMOLLM_FROM_META = {
     "tok_embeddings.weight": "tok_embeddings.weight",
     "norm.weight": "norm.scale",
-    "output.weight": "output.weight",
     "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
     "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
     "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
@@ -41,10 +40,32 @@ def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     for key, value in state_dict.items():
         new_key = get_mapped_key(key, inverted_mapping_dict)
         converted_state_dict[new_key] = value
+    # SmolLM ties its LM head to the token embeddings, so the Meta-format
+    # checkpoint reuses tok_embeddings.weight as output.weight.
+    converted_state_dict["output.weight"] = converted_state_dict[
+        "tok_embeddings.weight"
+    ]
 
     return converted_state_dict
 
 
+def convert_weights(input_dir: str, output_file: str) -> None:
+    # We don't strictly need the torchtune checkpointer here; we could just
+    # aggregate the checkpoint files ourselves.
+    checkpointer = FullModelHFCheckpointer(
+        checkpoint_dir=input_dir,
+        checkpoint_files=["model.safetensors"],
+        output_dir=".",
+        model_type="LLAMA3",
+    )
+
+    print("Loading checkpoint...")
+    sd = checkpointer.load_checkpoint()
+    print("Converting checkpoint...")
+    sd = smollm_tune_to_meta(sd["model"])
+    print("Saving checkpoint...")
+    torch.save(sd, output_file)
+    print("Done.")
+
+
 def main():
     parser = argparse.ArgumentParser(
         description="Convert SmolLM weights to Meta format."
@@ -57,23 +78,7 @@ def main():
     parser.add_argument("output", type=str, help="Path to the output checkpoint")
 
     args = parser.parse_args()
-
-    # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
-    checkpointer = FullModelHFCheckpointer(
-        checkpoint_dir=args.input_dir,
-        checkpoint_files=["model.safetensors"],
-        output_dir=".",
-        model_type="LLAMA",
-    )
-
-    print("Loading checkpoint...")
-    sd = checkpointer.load_checkpoint()
-
-    print("Converting checkpoint...")
-    sd = smollm_tune_to_meta(sd["model"])
-
-    torch.save(sd, args.output)
-    print(f"Checkpoint saved to {args.output}")
+    convert_weights(args.input_dir, args.output)
 
 
 if __name__ == "__main__":
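
A note on the substantive change: SmolLM ties its LM head to the token-embedding matrix, so the HF checkpoint carries no separate output weight. The converter therefore drops the "output.weight" entry from the mapping and instead fills that slot by copying tok_embeddings.weight. A minimal sanity check on a converted file might look like the sketch below; the smollm_meta.pth path is hypothetical, standing in for whatever output path you pass to the script.

import torch

# Hypothetical output path, e.g. from: python convert_smollm.py <input_dir> smollm_meta.pth
sd = torch.load("smollm_meta.pth", map_location="cpu")

# With tied embeddings, the Meta-format LM head must equal the
# token-embedding matrix exactly.
assert torch.equal(sd["output.weight"], sd["tok_embeddings.weight"])
print("output.weight is tied to tok_embeddings.weight")

Factoring the body of main() into convert_weights() also makes the conversion callable from other Python code rather than only through the CLI.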