Skip to content

Commit 07b4267

Browse files
wsxiaoysggerganov
authored andcommitted
llama : add support for StarCoder model architectures (ggml-org#3187)
* add placeholder of starcoder in gguf / llama.cpp * support convert starcoder weights to gguf * convert MQA to MHA * fix ffn_down name * add LLM_ARCH_STARCODER to llama.cpp * set head_count_kv = 1 * load starcoder weight * add max_position_embeddings * set n_positions to max_positioin_embeddings * properly load all starcoder params * fix head count kv * fix comments * fix vram calculation for starcoder * store mqa directly * add input embeddings handling * add TBD * working in cpu, metal buggy * cleanup useless code * metal : fix out-of-bounds access in soft_max kernels * llama : make starcoder graph build more consistent with others * refactor: cleanup comments a bit * add other starcoder models: 3B, 7B, 15B * support-mqa-directly * fix: remove max_position_embeddings, use n_train_ctx * Update llama.cpp Co-authored-by: Georgi Gerganov <[email protected]> * Update llama.cpp Co-authored-by: Georgi Gerganov <[email protected]> * Apply suggestions from code review Co-authored-by: Georgi Gerganov <[email protected]> * fix: switch to space from tab --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent f5f9296 commit 07b4267

File tree

3 files changed

+637
-21
lines changed

3 files changed

+637
-21
lines changed

convert-starcoder-hf-to-gguf.py

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
#!/usr/bin/env python3
2+
# HF starcoder --> gguf conversion
3+
4+
from __future__ import annotations
5+
6+
import argparse
7+
import json
8+
import os
9+
import struct
10+
import sys
11+
from pathlib import Path
12+
from typing import Any
13+
14+
import numpy as np
15+
import torch
16+
from transformers import AutoTokenizer # type: ignore[import]
17+
18+
if 'NO_LOCAL_GGUF' not in os.environ:
19+
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
20+
import gguf
21+
22+
23+
def bytes_to_unicode():
24+
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
25+
"""
26+
Returns list of utf-8 byte and a corresponding list of unicode strings.
27+
The reversible bpe codes work on unicode strings.
28+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
29+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
30+
This is a significant percentage of your normal, say, 32K bpe vocab.
31+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
32+
And avoids mapping to whitespace/control characters the bpe code barfs on.
33+
"""
34+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
35+
cs = bs[:]
36+
n = 0
37+
for b in range(2**8):
38+
if b not in bs:
39+
bs.append(b)
40+
cs.append(2**8+n)
41+
n += 1
42+
return dict(zip(bs, (chr(n) for n in cs)))
43+
44+
45+
def count_model_parts(dir_model: Path) -> int:
46+
num_parts = 0
47+
for filename in os.listdir(dir_model):
48+
if filename.startswith("pytorch_model-"):
49+
num_parts += 1
50+
51+
if num_parts > 0:
52+
print("gguf: found " + str(num_parts) + " model parts")
53+
return num_parts
54+
55+
56+
def parse_args() -> argparse.Namespace:
57+
parser = argparse.ArgumentParser(description="Convert a StarCoder model to a GGML compatible file")
58+
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
59+
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
60+
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
61+
parser.add_argument("ftype", type=int, help="output format - use 0 for float32, 1 for float16", choices=[0, 1], default = 1)
62+
return parser.parse_args()
63+
64+
args = parse_args()
65+
66+
dir_model = args.model
67+
ftype = args.ftype
68+
if not dir_model.is_dir():
69+
print(f'Error: {args.model} is not a directory', file = sys.stderr)
70+
sys.exit(1)
71+
72+
# possible tensor data types
73+
# ftype == 0 -> float32
74+
# ftype == 1 -> float16
75+
76+
# map from ftype to string
77+
ftype_str = ["f32", "f16"]
78+
79+
if args.outfile is not None:
80+
fname_out = args.outfile
81+
else:
82+
# output in the same directory as the model by default
83+
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
84+
85+
print("gguf: loading model "+dir_model.name)
86+
87+
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
88+
hparams = json.load(f)
89+
90+
if hparams["architectures"][0] != "GPTBigCodeForCausalLM":
91+
print("Model architecture not supported: " + hparams["architectures"][0])
92+
93+
sys.exit(1)
94+
95+
# get number of model parts
96+
num_parts = count_model_parts(dir_model)
97+
98+
ARCH=gguf.MODEL_ARCH.STARCODER
99+
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
100+
101+
print("gguf: get model metadata")
102+
103+
block_count = hparams["n_layer"]
104+
105+
gguf_writer.add_name("StarCoder")
106+
gguf_writer.add_context_length(hparams["n_positions"])
107+
gguf_writer.add_embedding_length(hparams["n_embd"])
108+
gguf_writer.add_feed_forward_length(4 * hparams["n_embd"])
109+
gguf_writer.add_block_count(block_count)
110+
gguf_writer.add_head_count(hparams["n_head"])
111+
gguf_writer.add_head_count_kv(1)
112+
gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
113+
gguf_writer.add_file_type(ftype)
114+
115+
# TOKENIZATION
116+
117+
print("gguf: get tokenizer metadata")
118+
119+
tokens: list[bytearray] = []
120+
scores: list[float] = []
121+
toktypes: list[int] = []
122+
123+
tokenizer_json_file = dir_model / 'tokenizer.json'
124+
if not tokenizer_json_file.is_file():
125+
print(f'Error: Missing {tokenizer_json_file}', file = sys.stderr)
126+
sys.exit(1)
127+
128+
# gpt2 tokenizer
129+
gguf_writer.add_tokenizer_model("gpt2")
130+
131+
with open(tokenizer_json_file, "r", encoding="utf-8") as f:
132+
tokenizer_json = json.load(f)
133+
134+
print("gguf: get gpt2 tokenizer vocab")
135+
136+
# The number of tokens in tokenizer.json can differ from the expected vocab size.
137+
# This causes downstream issues with mismatched tensor sizes when running the inference
138+
vocab_size = hparams["vocab_size"] if "vocab_size" in hparams else len(tokenizer_json["model"]["vocab"])
139+
140+
# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
141+
tokenizer = AutoTokenizer.from_pretrained(dir_model)
142+
143+
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
144+
byte_encoder = bytes_to_unicode()
145+
byte_decoder = {v: k for k, v in byte_encoder.items()}
146+
147+
for i in range(vocab_size):
148+
if i in reverse_vocab:
149+
try:
150+
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
151+
except KeyError:
152+
text = bytearray()
153+
for c in reverse_vocab[i]:
154+
if ord(c) < 256: # single byte character
155+
text.append(byte_decoder[ord(c)])
156+
else: # multibyte special token character
157+
text.extend(c.encode('utf-8'))
158+
else:
159+
print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
160+
pad_token = f"[PAD{i}]".encode("utf8")
161+
text = bytearray(pad_token)
162+
163+
tokens.append(text)
164+
scores.append(0.0) # dymmy
165+
toktypes.append(gguf.TokenType.NORMAL) # dummy
166+
167+
gguf_writer.add_token_list(tokens)
168+
gguf_writer.add_token_scores(scores)
169+
gguf_writer.add_token_types(toktypes)
170+
171+
special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
172+
special_vocab.add_to_gguf(gguf_writer)
173+
174+
# TENSORS
175+
176+
tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
177+
178+
# params for qkv transform
179+
n_head = hparams["n_head"]
180+
n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1
181+
182+
head_dim = hparams["n_embd"] // n_head
183+
184+
# tensor info
185+
print("gguf: get tensor metadata")
186+
187+
if num_parts == 0:
188+
part_names = iter(("pytorch_model.bin",))
189+
else:
190+
part_names = (
191+
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
192+
)
193+
194+
for part_name in part_names:
195+
if args.vocab_only:
196+
break
197+
print("gguf: loading model part '" + part_name + "'")
198+
model_part = torch.load(dir_model / part_name, map_location="cpu")
199+
200+
for name in model_part.keys():
201+
data = model_part[name]
202+
203+
old_dtype = data.dtype
204+
205+
# convert any unsupported data types to float32
206+
if data.dtype != torch.float16 and data.dtype != torch.float32:
207+
data = data.to(torch.float32)
208+
209+
data = data.squeeze().numpy()
210+
211+
# map tensor names
212+
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
213+
if new_name is None:
214+
print("Can not map tensor '" + name + "'")
215+
sys.exit()
216+
217+
n_dims = len(data.shape)
218+
data_dtype = data.dtype
219+
220+
# if f32 desired, convert any float16 to float32
221+
if ftype == 0 and data_dtype == np.float16:
222+
data = data.astype(np.float32)
223+
224+
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
225+
if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
226+
data = data.astype(np.float32)
227+
228+
# if f16 desired, convert any float32 2-dim weight tensors to float16
229+
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
230+
data = data.astype(np.float16)
231+
232+
print(name, "=>", new_name + ", shape = " + str(data.shape) + ", " + str(old_dtype) + " --> " + str(data.dtype))
233+
234+
gguf_writer.add_tensor(new_name, data)
235+
236+
237+
print("gguf: write header")
238+
gguf_writer.write_header_to_file()
239+
print("gguf: write metadata")
240+
gguf_writer.write_kv_data_to_file()
241+
if not args.vocab_only:
242+
print("gguf: write tensors")
243+
gguf_writer.write_tensors_to_file()
244+
245+
gguf_writer.close()
246+
247+
print(f"gguf: model successfully exported to '{fname_out}'")
248+
print("")

gguf-py/gguf/gguf.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,14 @@
7777

7878

7979
class MODEL_ARCH(IntEnum):
80-
LLAMA : int = auto()
81-
FALCON : int = auto()
82-
BAICHUAN:int = auto()
83-
GPT2 : int = auto()
84-
GPTJ : int = auto()
85-
GPTNEOX: int = auto()
86-
MPT : int = auto()
80+
LLAMA : int = auto()
81+
FALCON : int = auto()
82+
BAICHUAN : int = auto()
83+
GPT2 : int = auto()
84+
GPTJ : int = auto()
85+
GPTNEOX : int = auto()
86+
MPT : int = auto()
87+
STARCODER : int = auto()
8788

8889

8990
class MODEL_TENSOR(IntEnum):
@@ -107,13 +108,14 @@ class MODEL_TENSOR(IntEnum):
107108

108109

109110
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
110-
MODEL_ARCH.LLAMA: "llama",
111-
MODEL_ARCH.FALCON: "falcon",
112-
MODEL_ARCH.BAICHUAN:"baichuan",
113-
MODEL_ARCH.GPT2: "gpt2",
114-
MODEL_ARCH.GPTJ: "gptj",
115-
MODEL_ARCH.GPTNEOX: "gptneox",
116-
MODEL_ARCH.MPT: "mpt",
111+
MODEL_ARCH.LLAMA: "llama",
112+
MODEL_ARCH.FALCON: "falcon",
113+
MODEL_ARCH.BAICHUAN: "baichuan",
114+
MODEL_ARCH.GPT2: "gpt2",
115+
MODEL_ARCH.GPTJ: "gptj",
116+
MODEL_ARCH.GPTNEOX: "gptneox",
117+
MODEL_ARCH.MPT: "mpt",
118+
MODEL_ARCH.STARCODER: "starcoder",
117119
}
118120

119121
MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
@@ -171,6 +173,18 @@ class MODEL_TENSOR(IntEnum):
171173
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
172174
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
173175
},
176+
MODEL_ARCH.STARCODER: {
177+
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
178+
MODEL_TENSOR.POS_EMBD: "position_embd",
179+
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
180+
MODEL_TENSOR.OUTPUT: "output",
181+
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
182+
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
183+
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
184+
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
185+
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
186+
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
187+
},
174188
MODEL_ARCH.GPT2: {
175189
# TODO
176190
},

0 commit comments

Comments
 (0)