import gguf
import argparse
import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import copy
import enum
import faulthandler
…
import signal
import struct
import sys
+import time
import zipfile
import numpy as np

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from pathlib import Path
-from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, Union)
+from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, TypeVar, Union)
from sentencepiece import SentencePieceProcessor  # type: ignore

if TYPE_CHECKING:
…
ARCH = gguf.MODEL_ARCH.LLAMA
NAMES = gguf.MODEL_TENSOR_NAMES[ARCH]

+DEFAULT_CONCURRENCY = 8
#
# data types
#

@dataclass(frozen=True)
-class UnquantizedDataType:
+class DataType:
    name: str
+    dtype: 'np.dtype[Any]'
+    valid_conversions: List[str]

-DT_F16  = UnquantizedDataType('F16')
-DT_F32  = UnquantizedDataType('F32')
-DT_I32  = UnquantizedDataType('I32')
-DT_BF16 = UnquantizedDataType('BF16')
+    def elements_to_bytes(self, n_elements: int) -> int:
+        return n_elements * self.dtype.itemsize

-DataType = Union[UnquantizedDataType]
+@dataclass(frozen=True)
+class UnquantizedDataType(DataType):
+    pass

-DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
-    DT_BF16: np.dtype(np.uint16),
-    DT_F16:  np.dtype(np.float16),
-    DT_F32:  np.dtype(np.float32),
-    DT_I32:  np.dtype(np.int32),
-}
+DT_F16  = UnquantizedDataType('F16',  dtype=np.dtype(np.float16), valid_conversions=['F32', 'Q8_0'])
+DT_F32  = UnquantizedDataType('F32',  dtype=np.dtype(np.float32), valid_conversions=['F16', 'Q8_0'])
+DT_I32  = UnquantizedDataType('I32',  dtype=np.dtype(np.int16),   valid_conversions=[])
+DT_BF16 = UnquantizedDataType('BF16', dtype=np.dtype(np.uint16),  valid_conversions=['F32', 'F16', 'Q8_0'])
+
+@dataclass(frozen=True)
+class QuantizedDataType(DataType):
+    block_size: int
+    quantized_dtype: 'np.dtype[Any]'
+    ggml_type: gguf.GGMLQuantizationType

-NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \
-    {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
+    def quantize(self, arr: NDArray) -> NDArray:
+        raise NotImplementedError(f'Quantization for {self.name} not implemented')
+
+    def elements_to_bytes(self, n_elements: int) -> int:
+        assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}'
+        return self.quantized_dtype.itemsize * (n_elements // self.block_size)
+
+@dataclass(frozen=True)
+class Q8_0QuantizedDataType(QuantizedDataType):
+    # Mini Q8_0 quantization in Python!
+    def quantize(self, arr: NDArray) -> NDArray:
+        assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}'
+        assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
+        n_blocks = arr.size // self.block_size
+        blocks = arr.reshape((n_blocks, self.block_size))
+        # Much faster implementation of block quantization contributed by @Cebtenzzre
+        def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[Tuple[Any, Any]]:
+            d = abs(blocks).max(axis=1) / np.float32(127)
+            with np.errstate(divide='ignore'):
+                qs = (blocks / d[:, None]).round()
+            qs[d == 0] = 0
+            yield from zip(d, qs)
+        return np.fromiter(quantize_blocks_q8_0(blocks), count=n_blocks, dtype=self.quantized_dtype)
+
+DT_Q8_0 = Q8_0QuantizedDataType('Q8_0',
+                                dtype=np.dtype(np.float32), valid_conversions=[],
+                                ggml_type=gguf.GGMLQuantizationType.Q8_0, block_size=32,
+                                quantized_dtype=np.dtype([('d', '<f2'), ('qs', 'i1', (32,))]))
+
+# Quantized types skipped here because they may also map to np.float32
+NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = {}
+for dt in (DT_BF16, DT_F16, DT_F32, DT_I32):
+    if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE:
+        raise ValueError(f'Invalid duplicate data type {dt}')
+    NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt

SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
    'BF16': DT_BF16,
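
To make the new quantized path concrete, here is a standalone sketch of what DT_Q8_0.quantize() produces (an editorial illustration assuming the definitions above are in scope; it is not part of the diff): each 32-element block is reduced to a float16 scale d = max|x| / 127 plus 32 signed 8-bit quants, and multiplying them back together approximately reconstructs the input.

    import numpy as np

    arr = np.random.randn(64).astype(np.float32)   # two 32-element blocks
    packed = DT_Q8_0.quantize(arr)                 # structured array with fields 'd' and 'qs'
    assert packed.dtype.itemsize == 34             # 2-byte fp16 scale + 32 int8 quants per block

    # Rough round trip: d * qs matches the input to within about one quantization step.
    dequant = (packed['d'][:, None].astype(np.float32) * packed['qs']).reshape(-1)
    assert np.allclose(dequant, arr, atol=float(packed['d'].max()))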
@@ -73,20 +115,22 @@ class UnquantizedDataType:
# TODO: rename to LLAMAFileType
# TODO: move to `gguf.py`
class GGMLFileType(enum.IntEnum):
-    AllF32    = 0
-    MostlyF16 = 1  # except 1d tensors
+    AllF32     = 0
+    MostlyF16  = 1  # except 1d tensors
+    MostlyQ8_0 = 7  # except 1d tensors

    def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
-        if len(tensor.shape) == 1:
-            # 1D tensors are always F32.
-            return DT_F32
-        elif self == GGMLFileType.AllF32:
-            return DT_F32
-        elif self == GGMLFileType.MostlyF16:
-            return DT_F16
-        else:
+        dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self)
+        if dt is None:
            raise ValueError(self)
+        # 1D tensors are always F32.
+        return dt if len(tensor.shape) > 1 else DT_F32

+GGML_FILE_TYPE_TO_DATA_TYPE: Dict[GGMLFileType, DataType] = {
+    GGMLFileType.AllF32    : DT_F32,
+    GGMLFileType.MostlyF16 : DT_F16,
+    GGMLFileType.MostlyQ8_0: DT_Q8_0,
+}

#
# hparams loading
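
The lookup-table rewrite keeps the old behaviour; a quick editorial sketch (not part of the diff, assuming the definitions above are in scope) of how the new two-step decision plays out for q8_0 output:

    ft = GGMLFileType.MostlyQ8_0
    dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(ft)               # DT_Q8_0
    weight_shape, norm_shape = (4096, 4096), (4096,)
    print((dt if len(weight_shape) > 1 else DT_F32).name)  # 'Q8_0' for 2-D weight matrices
    print((dt if len(norm_shape) > 1 else DT_F32).name)    # 'F32'  - 1-D tensors stay F32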
@@ -415,7 +459,7 @@ def __init__(self, ndarray: NDArray) -> None:
        self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]

    def astype(self, data_type: DataType) -> Tensor:
-        dtype = DATA_TYPE_TO_NUMPY[data_type]
+        dtype = data_type.dtype
        if self.data_type == DT_BF16:
            self.ndarray = bf16_to_fp32(self.ndarray)
        return UnquantizedTensor(self.ndarray.astype(dtype))
@@ -454,22 +498,6 @@ def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, conv
GGMLCompatibleTensor = Union[UnquantizedTensor]


-class DeferredPermutedTensor(Tensor):
-    def __init__(self, base: Tensor, n_head: int, n_head_kv: int) -> None:
-        self.base = base
-        self.n_head = n_head
-        self.data_type = self.base.data_type
-
-    def astype(self, data_type: DataType) -> Tensor:
-        return self.base.astype(data_type).permute(self.n_head, self.n_head_kv)
-
-    def to_ggml(self) -> GGMLCompatibleTensor:
-        return self.base.to_ggml().permute(self.n_head, self.n_head_kv)
-
-    def permute(self, n_head: int, n_head_kv: int) -> Tensor:
-        raise Exception("shouldn't permute twice")
-
-
@dataclass
class LazyTensor:
    _load: Callable[[], Tensor]
@@ -479,7 +507,9 @@ class LazyTensor:

    def load(self) -> Tensor:
        ret = self._load()
-        assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description)
+        # Should be okay if it maps to the same numpy type?
+        assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \
+            (self.data_type, ret.data_type, self.description)
        return ret

    def astype(self, data_type: DataType) -> 'LazyTensor':
@@ -490,8 +520,8 @@ def load() -> Tensor:
        return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')

    def validate_conversion_to(self, data_type: DataType) -> None:
-        if data_type == self.data_type:
-            return
+        if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions:
+            raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.')


LazyModel = Dict[str, LazyTensor]
@@ -617,9 +647,7 @@ def persistent_load(self, pid: Any) -> Any:
        info = self.zip_file.getinfo(filename)

        def load(offset: int, elm_count: int) -> NDArray:
-            dtype = DATA_TYPE_TO_NUMPY.get(data_type)
-            if dtype is None:
-                raise Exception("tensor stored in unsupported format")
+            dtype = data_type.dtype
            fp = self.zip_file.open(info)
            fp.seek(offset * dtype.itemsize)
            size = elm_count * dtype.itemsize
@@ -683,7 +711,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:

    def convert(info: Dict[str, Any]) -> LazyTensor:
        data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
-        numpy_dtype = DATA_TYPE_TO_NUMPY[data_type]
+        numpy_dtype = data_type.dtype
        shape: List[int] = info['shape']
        begin, end = info['data_offsets']
        assert 0 <= begin <= end <= len(byte_buf)
@@ -723,23 +751,35 @@ def lazy_load_file(path: Path) -> ModelPlus:
In = TypeVar('In')
Out = TypeVar('Out')

-def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int) -> Iterable[Out]:
+def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]:
    '''Parallel map, but with backpressure. If the caller doesn't call `next`
    fast enough, this will stop calling `func` at some point rather than
    letting results pile up in memory. Specifically, there is a max of one
    output value buffered per thread.'''
-    with concurrent.futures.ThreadPoolExecutor() as executor:
+    if concurrency < 2:
+        yield from map(func, iterable)
+        # Not reached.
+    iterable = iter(iterable)
+    with factory(max_workers=max_workers) as executor:
        futures: List[concurrent.futures.Future[Out]] = []
-        items_rev = list(iterable)[::-1]
-        for i in range(min(concurrency, len(items_rev))):
-            futures.append(executor.submit(func, items_rev.pop()))
+        done = False
+        for _ in range(concurrency):
+            try:
+                futures.append(executor.submit(func, next(iterable)))
+            except StopIteration:
+                done = True
+                break
+
        while futures:
            result = futures.pop(0).result()
-            if items_rev:
-                futures.append(executor.submit(func, items_rev.pop()))
+            while not done and len(futures) < concurrency:
+                try:
+                    futures.append(executor.submit(func, next(iterable)))
+                except StopIteration:
+                    done = True
+                    break
            yield result

-
def check_vocab_size(params: Params, vocab: Vocab) -> None:
    if params.n_vocab != vocab.vocab_size:
        assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
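
For reference, the reworked generator can be exercised on its own (an editorial sketch, not part of the diff; slow_load is a made-up stand-in for a tensor load). It now accepts any iterable instead of materializing a list, takes an optional worker count and executor factory, and still yields results in submission order while keeping at most `concurrency` items in flight:

    import time

    def slow_load(i: int) -> int:
        time.sleep(0.1)   # stand-in for an expensive load/convert step
        return i * i

    for value in bounded_parallel_map(slow_load, range(8), concurrency=4):
        print(value)      # 0, 1, 4, ... in order, with bounded buffering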
@@ -804,12 +844,11 @@ def add_meta_vocab(self, vocab: Vocab) -> None:
        self.gguf.add_token_types(toktypes)

    def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
-        n_elements = 1
-        for dim in tensor.shape:
-            n_elements *= dim
-        data_type = DATA_TYPE_TO_NUMPY[tensor.data_type]
-        data_nbytes = n_elements * data_type.itemsize
-        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes)
+        n_elements = int(np.prod(tensor.shape))
+        raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
+        data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
+        data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
+        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype)

    def write_meta(self) -> None:
        self.gguf.write_header_to_file()
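
The byte count for the tensor-info header now comes from elements_to_bytes() rather than a flat n_elements * itemsize, which is what lets quantized tensors report the right size. A quick editorial check of the arithmetic (not part of the diff):

    import numpy as np

    q8_0_block = np.dtype([('d', '<f2'), ('qs', 'i1', (32,))])
    assert q8_0_block.itemsize == 34                   # fp16 scale + 32 int8 quants per 32-element block

    n_elements = 4096 * 4096                           # e.g. one 4096 x 4096 weight matrix
    print((n_elements // 32) * q8_0_block.itemsize)    # 17825792 bytes as Q8_0
    print(n_elements * np.dtype(np.float32).itemsize)  # 67108864 bytes as F32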
@@ -835,7 +874,20 @@ def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None:
        of.close()

    @staticmethod
-    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
+    def do_item(item: Tuple[str, LazyTensor]) -> Tuple[DataType, NDArray]:
+        name, lazy_tensor = item
+        tensor = lazy_tensor.load().to_ggml()
+        return (lazy_tensor.data_type, tensor.ndarray)
+
+    @staticmethod
+    def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
+        dt, arr = item
+        if not isinstance(dt, QuantizedDataType):
+            return arr
+        return dt.quantize(arr)
+
+    @staticmethod
+    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
        check_vocab_size(params, vocab)

        of = OutputFile(fname_out)
@@ -851,16 +903,19 @@ def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -
        of.write_meta()
        of.write_tensor_info()

-        def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
-            name, lazy_tensor = item
-            return lazy_tensor.load().to_ggml().ndarray
-
        # tensor data
-        ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8)
+        ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency)
+        if ftype == GGMLFileType.MostlyQ8_0:
+            ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency, factory=ProcessPoolExecutor)
+        else:
+            ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
+
+        start = time.time()
        for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
+            elapsed = time.time() - start
            size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
            padi = len(str(len(model)))
-            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type}")
+            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}")
            of.gguf.write_tensor_data(ndarray)

        of.close()
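
Presumably the Q8_0 path routes maybe_do_quantize through a ProcessPoolExecutor because the block quantizer is CPU-bound Python/numpy code that threads alone would serialize on the GIL, while the do_item loading stage stays on the default thread pool since it is mostly I/O; making both helpers staticmethods also keeps them picklable for the process pool. This likely explains the "q8_0 may be very slow" note added to the --outtype help text below.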
@@ -872,6 +927,8 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFi
        return GGMLFileType.AllF32
    if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)):
        return GGMLFileType.MostlyF16
+    if output_type_str == "q8_0":
+        return GGMLFileType.MostlyQ8_0

    name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
@@ -918,7 +975,7 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
            print(f"skipping tensor {name_new}")
            continue
        else:
-            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
+            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
            out[name_new] = lazy_tensor

    return out
@@ -1023,6 +1080,7 @@ def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
    namestr = {
        GGMLFileType.AllF32:    "f32",
        GGMLFileType.MostlyF16: "f16",
+        GGMLFileType.MostlyQ8_0: "q8_0",
    }[file_type]
    ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
    if ret in model_paths:
@@ -1046,12 +1104,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
    parser.add_argument("--dump",        action="store_true", help="don't convert, just show what's in the model")
    parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
    parser.add_argument("--vocab-only",  action="store_true", help="extract only the vocab")
-    parser.add_argument("--outtype",     choices=["f32", "f16"], help="output format (default: based on input)")
+    parser.add_argument("--outtype",     choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
    parser.add_argument("--vocab-dir",   type=Path, help="directory containing tokenizer.model, if separate from model file")
    parser.add_argument("--outfile",     type=Path, help="path to write to; default: based on input")
    parser.add_argument("model",         type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
    parser.add_argument("--vocabtype",   choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
    parser.add_argument("--ctx",         type=int, help="model training context (default: based on input)")
+    parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default=DEFAULT_CONCURRENCY)
    args = parser.parse_args(args_in)

    if args.dump_single:
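
With these options in place, a Q8_0 conversion would presumably be invoked along the lines of: python convert.py /path/to/model --outtype q8_0 --concurrency 8, with --outfile available to override the default ggml-model-q8_0.gguf name (see the default_outfile change above).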
@@ -1073,6 +1132,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
        params.ftype = {
            "f32": GGMLFileType.AllF32,
            "f16": GGMLFileType.MostlyF16,
+            "q8_0": GGMLFileType.MostlyQ8_0,
        }[args.outtype]

    print(f"params = {params}")
@@ -1104,7 +1164,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
    params.ftype = ftype
    print(f"Writing {outfile}, format {ftype}")

-    OutputFile.write_all(outfile, params, model, vocab)
+    OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency=args.concurrency)
    print(f"Wrote {outfile}")