Commit 730d9c6

convert.py : advanced option (#2753)
* Allow convert.py to convert to q8_0
  Fix issue with bounded_parallel_map and greedy consuming iterator
  Display elapsed time during conversion
* Add --concurrency option
  Minor improvements to help text
  Clean up bounded_parallel_map function a bit
* Massive speed improvement thanks to Cebtenzzre
* Refactor types
1 parent c7d92e6 commit 730d9c6
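As a usage illustration (the model directory here is hypothetical, not part of this commit), the options added by this change would be invoked along these lines:

    python3 convert.py models/7B/ --outtype q8_0 --concurrency 8

Per the help text added below, q8_0 output may be very slow; --concurrency (default 8) controls how many tensors are loaded and converted in flight, and without --outfile the result defaults to ggml-model-q8_0.gguf next to the model files.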

convert.py

Lines changed: 133 additions & 73 deletions
@@ -3,6 +3,7 @@
 import gguf
 import argparse
 import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
 import copy
 import enum
 import faulthandler
@@ -17,13 +18,14 @@
 import signal
 import struct
 import sys
+import time
 import zipfile
 import numpy as np
 
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
-from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, Union)
+from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, TypeVar, Union)
 from sentencepiece import SentencePieceProcessor # type: ignore
 
 if TYPE_CHECKING:
@@ -37,30 +39,70 @@
 ARCH=gguf.MODEL_ARCH.LLAMA
 NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
 
+DEFAULT_CONCURRENCY = 8
 #
 # data types
 #
 
 @dataclass(frozen=True)
-class UnquantizedDataType:
+class DataType:
     name: str
+    dtype: 'np.dtype[Any]'
+    valid_conversions: List[str]
 
-DT_F16 = UnquantizedDataType('F16')
-DT_F32 = UnquantizedDataType('F32')
-DT_I32 = UnquantizedDataType('I32')
-DT_BF16 = UnquantizedDataType('BF16')
+    def elements_to_bytes(self, n_elements: int) -> int:
+        return n_elements * self.dtype.itemsize
 
-DataType = Union[UnquantizedDataType]
+@dataclass(frozen=True)
+class UnquantizedDataType(DataType):
+    pass
 
-DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
-    DT_BF16: np.dtype(np.uint16),
-    DT_F16: np.dtype(np.float16),
-    DT_F32: np.dtype(np.float32),
-    DT_I32: np.dtype(np.int32),
-}
+DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
+DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
+DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
+DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
+
+@dataclass(frozen=True)
+class QuantizedDataType(DataType):
+    block_size: int
+    quantized_dtype: 'np.dtype[Any]'
+    ggml_type: gguf.GGMLQuantizationType
 
-NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \
-    {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
+    def quantize(self, arr: NDArray) -> NDArray:
+        raise NotImplementedError(f'Quantization for {self.name} not implemented')
+
+    def elements_to_bytes(self, n_elements: int) -> int:
+        assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}'
+        return self.quantized_dtype.itemsize * (n_elements // self.block_size)
+
+@dataclass(frozen=True)
+class Q8_0QuantizedDataType(QuantizedDataType):
+    # Mini Q8_0 quantization in Python!
+    def quantize(self, arr: NDArray) -> NDArray:
+        assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}'
+        assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
+        n_blocks = arr.size // self.block_size
+        blocks = arr.reshape((n_blocks, self.block_size))
+        # Much faster implementation of block quantization contributed by @Cebtenzzre
+        def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[Tuple[Any, Any]]:
+            d = abs(blocks).max(axis = 1) / np.float32(127)
+            with np.errstate(divide = 'ignore'):
+                qs = (blocks / d[:, None]).round()
+            qs[d == 0] = 0
+            yield from zip(d, qs)
+        return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype)
+
+DT_Q8_0 = Q8_0QuantizedDataType('Q8_0',
+                                dtype = np.dtype(np.float32), valid_conversions = [],
+                                ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32,
+                                quantized_dtype = np.dtype([('d', '<f2'), ('qs', 'i1', (32,))]))
+
+# Quantized types skipped here because they may also map to np.float32
+NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = {}
+for dt in (DT_BF16, DT_F16, DT_F32, DT_I32):
+    if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE:
+        raise ValueError(f'Invalid duplicate data type {dt}')
+    NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt
 
 SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
     'BF16': DT_BF16,
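For context, the Q8_0 layout introduced above stores each block of 32 float32 values as one float16 scale 'd' plus 32 signed 8-bit codes 'qs'. A minimal standalone round-trip sketch mirroring that quantizer (the helper names are illustrative, not part of the commit; np.fromiter with a structured dtype needs a reasonably recent NumPy):

    import numpy as np

    BLOCK = 32
    Q8_0_DTYPE = np.dtype([('d', '<f2'), ('qs', 'i1', (BLOCK,))])

    def quantize_q8_0(arr: np.ndarray) -> np.ndarray:
        # One scale per block of 32 values; codes fall in [-127, 127].
        blocks = arr.reshape((-1, BLOCK))
        d = np.abs(blocks).max(axis = 1) / np.float32(127)
        with np.errstate(divide = 'ignore', invalid = 'ignore'):
            qs = (blocks / d[:, None]).round()
        qs[d == 0] = 0
        return np.fromiter(zip(d, qs), count = len(blocks), dtype = Q8_0_DTYPE)

    def dequantize_q8_0(q: np.ndarray) -> np.ndarray:
        # Reconstruct float32 values: per-block scale times the int8 codes.
        return (q['d'][:, None].astype(np.float32) * q['qs']).reshape(-1)

    x = np.random.rand(64).astype(np.float32)
    err = np.abs(x - dequantize_q8_0(quantize_q8_0(x))).max()
    print(f"max round-trip error: {err:.5f}")  # bounded by half a quantization step

Each 32-value block thus takes 2 + 32 = 34 bytes, which is what elements_to_bytes computes from the quantized_dtype above.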
@@ -73,20 +115,22 @@ class UnquantizedDataType:
 # TODO: rename to LLAMAFileType
 # TODO: move to `gguf.py`
 class GGMLFileType(enum.IntEnum):
-    AllF32 = 0
-    MostlyF16 = 1 # except 1d tensors
+    AllF32     = 0
+    MostlyF16  = 1  # except 1d tensors
+    MostlyQ8_0 = 7  # except 1d tensors
 
     def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
-        if len(tensor.shape) == 1:
-            # 1D tensors are always F32.
-            return DT_F32
-        elif self == GGMLFileType.AllF32:
-            return DT_F32
-        elif self == GGMLFileType.MostlyF16:
-            return DT_F16
-        else:
+        dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self)
+        if dt is None:
             raise ValueError(self)
+        # 1D tensors are always F32.
+        return dt if len(tensor.shape) > 1 else DT_F32
 
+GGML_FILE_TYPE_TO_DATA_TYPE: Dict[GGMLFileType, DataType] = {
+    GGMLFileType.AllF32    : DT_F32,
+    GGMLFileType.MostlyF16 : DT_F16,
+    GGMLFileType.MostlyQ8_0: DT_Q8_0,
+}
 
 #
 # hparams loading
@@ -415,7 +459,7 @@ def __init__(self, ndarray: NDArray) -> None:
         self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
 
     def astype(self, data_type: DataType) -> Tensor:
-        dtype = DATA_TYPE_TO_NUMPY[data_type]
+        dtype = data_type.dtype
         if self.data_type == DT_BF16:
             self.ndarray = bf16_to_fp32(self.ndarray)
         return UnquantizedTensor(self.ndarray.astype(dtype))
@@ -454,22 +498,6 @@ def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, conv
 GGMLCompatibleTensor = Union[UnquantizedTensor]
 
 
-class DeferredPermutedTensor(Tensor):
-    def __init__(self, base: Tensor, n_head: int, n_head_kv: int) -> None:
-        self.base = base
-        self.n_head = n_head
-        self.data_type = self.base.data_type
-
-    def astype(self, data_type: DataType) -> Tensor:
-        return self.base.astype(data_type).permute(self.n_head, self.n_head_kv)
-
-    def to_ggml(self) -> GGMLCompatibleTensor:
-        return self.base.to_ggml().permute(self.n_head, self.n_head_kv)
-
-    def permute(self, n_head: int, n_head_kv: int) -> Tensor:
-        raise Exception("shouldn't permute twice")
-
-
 @dataclass
 class LazyTensor:
     _load: Callable[[], Tensor]
@@ -479,7 +507,9 @@ class LazyTensor:
 
     def load(self) -> Tensor:
         ret = self._load()
-        assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description)
+        # Should be okay if it maps to the same numpy type?
+        assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \
+            (self.data_type, ret.data_type, self.description)
         return ret
 
     def astype(self, data_type: DataType) -> 'LazyTensor':
@@ -490,8 +520,8 @@ def load() -> Tensor:
         return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')
 
     def validate_conversion_to(self, data_type: DataType) -> None:
-        if data_type == self.data_type:
-            return
+        if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions:
+            raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.')
 
 
 LazyModel = Dict[str, LazyTensor]
@@ -617,9 +647,7 @@ def persistent_load(self, pid: Any) -> Any:
         info = self.zip_file.getinfo(filename)
 
         def load(offset: int, elm_count: int) -> NDArray:
-            dtype = DATA_TYPE_TO_NUMPY.get(data_type)
-            if dtype is None:
-                raise Exception("tensor stored in unsupported format")
+            dtype = data_type.dtype
             fp = self.zip_file.open(info)
             fp.seek(offset * dtype.itemsize)
             size = elm_count * dtype.itemsize
@@ -683,7 +711,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
 
     def convert(info: Dict[str, Any]) -> LazyTensor:
         data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
-        numpy_dtype = DATA_TYPE_TO_NUMPY[data_type]
+        numpy_dtype = data_type.dtype
         shape: List[int] = info['shape']
         begin, end = info['data_offsets']
         assert 0 <= begin <= end <= len(byte_buf)
@@ -723,23 +751,35 @@ def lazy_load_file(path: Path) -> ModelPlus:
 In = TypeVar('In')
 Out = TypeVar('Out')
 
-def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int) -> Iterable[Out]:
+def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]:
     '''Parallel map, but with backpressure. If the caller doesn't call `next`
     fast enough, this will stop calling `func` at some point rather than
     letting results pile up in memory. Specifically, there is a max of one
     output value buffered per thread.'''
-    with concurrent.futures.ThreadPoolExecutor() as executor:
+    if concurrency < 2:
+        yield from map(func, iterable)
+        # Not reached.
+    iterable = iter(iterable)
+    with factory(max_workers = max_workers) as executor:
         futures: List[concurrent.futures.Future[Out]] = []
-        items_rev = list(iterable)[::-1]
-        for i in range(min(concurrency, len(items_rev))):
-            futures.append(executor.submit(func, items_rev.pop()))
+        done = False
+        for _ in range(concurrency):
+            try:
+                futures.append(executor.submit(func, next(iterable)))
+            except StopIteration:
+                done = True
+                break
+
         while futures:
             result = futures.pop(0).result()
-            if items_rev:
-                futures.append(executor.submit(func, items_rev.pop()))
+            while not done and len(futures) < concurrency:
+                try:
+                    futures.append(executor.submit(func, next(iterable)))
+                except StopIteration:
+                    done = True
+                    break
             yield result
 
-
 def check_vocab_size(params: Params, vocab: Vocab) -> None:
     if params.n_vocab != vocab.vocab_size:
         assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
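To make the backpressure behaviour above concrete, a small hypothetical usage (slow_square is illustrative only): the rewritten function now pulls from the input iterator lazily, keeping at most `concurrency` calls in flight, instead of draining the whole iterable up front as the old list(iterable) version did.

    import time

    def slow_square(x: int) -> int:
        time.sleep(0.1)  # stand-in for expensive per-tensor work
        return x * x

    # Items are submitted only as results are consumed, so a slow consumer
    # never forces more than `concurrency` pending results into memory.
    for out in bounded_parallel_map(slow_square, iter(range(10)), concurrency = 4):
        print(out)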
@@ -804,12 +844,11 @@ def add_meta_vocab(self, vocab: Vocab) -> None:
         self.gguf.add_token_types(toktypes)
 
     def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
-        n_elements = 1
-        for dim in tensor.shape:
-            n_elements *= dim
-        data_type = DATA_TYPE_TO_NUMPY[tensor.data_type]
-        data_nbytes = n_elements * data_type.itemsize
-        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes)
+        n_elements = int(np.prod(tensor.shape))
+        raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
+        data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
+        data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
+        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype)
 
     def write_meta(self) -> None:
         self.gguf.write_header_to_file()
@@ -835,7 +874,20 @@ def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None:
         of.close()
 
     @staticmethod
-    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
+    def do_item(item: Tuple[str, LazyTensor]) -> Tuple[DataType, NDArray]:
+        name, lazy_tensor = item
+        tensor = lazy_tensor.load().to_ggml()
+        return (lazy_tensor.data_type, tensor.ndarray)
+
+    @staticmethod
+    def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
+        dt, arr = item
+        if not isinstance(dt, QuantizedDataType):
+            return arr
+        return dt.quantize(arr)
+
+    @staticmethod
+    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
         check_vocab_size(params, vocab)
 
         of = OutputFile(fname_out)
@@ -851,16 +903,19 @@ def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -
         of.write_meta()
         of.write_tensor_info()
 
-        def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
-            name, lazy_tensor = item
-            return lazy_tensor.load().to_ggml().ndarray
-
         # tensor data
-        ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8)
+        ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
+        if ftype == GGMLFileType.MostlyQ8_0:
+            ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
+        else:
+            ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
+
+        start = time.time()
         for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
+            elapsed = time.time() - start
            size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
            padi = len(str(len(model)))
-            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type}")
+            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}")
             of.gguf.write_tensor_data(ndarray)
 
         of.close()
@@ -872,6 +927,8 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFi
         return GGMLFileType.AllF32
     if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)):
         return GGMLFileType.MostlyF16
+    if output_type_str == "q8_0":
+        return GGMLFileType.MostlyQ8_0
 
     name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
 
@@ -918,7 +975,7 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
             print(f"skipping tensor {name_new}")
             continue
         else:
-            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
+            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
             out[name_new] = lazy_tensor
 
     return out
@@ -1023,6 +1080,7 @@ def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
     namestr = {
         GGMLFileType.AllF32: "f32",
         GGMLFileType.MostlyF16: "f16",
+        GGMLFileType.MostlyQ8_0:"q8_0",
     }[file_type]
     ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
     if ret in model_paths:
@@ -1046,12 +1104,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
     parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
     parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
-    parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
+    parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
     parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
     parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
     parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
     parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
+    parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
     args = parser.parse_args(args_in)
 
     if args.dump_single:
@@ -1073,6 +1132,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         params.ftype = {
             "f32": GGMLFileType.AllF32,
             "f16": GGMLFileType.MostlyF16,
+            "q8_0": GGMLFileType.MostlyQ8_0,
         }[args.outtype]
 
     print(f"params = {params}")
@@ -1104,7 +1164,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         params.ftype = ftype
     print(f"Writing {outfile}, format {ftype}")
 
-    OutputFile.write_all(outfile, params, model, vocab)
+    OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency)
     print(f"Wrote {outfile}")
 
 

0 commit comments
