import modules.shared as shared

sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
- from llama import load_quant
# 4-bit LLaMA
- def load_quantized_LLaMA(model_name):
-     if shared.args.load_in_4bit:
-         bits = 4
+ def load_quant(model_name, model_type):
+     if model_type == 'llama':
+         from llama import load_quant
+     elif model_type == 'opt':
+         from opt import load_quant
    else:
-         bits = shared.args.gptq_bits
+         print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
+         exit()

    path_to_model = Path(f'models/{model_name}')
-     pt_model = ''
-     if path_to_model.name.lower().startswith('llama-7b'):
-         pt_model = f'llama-7b-{bits}bit.pt'
-     elif path_to_model.name.lower().startswith('llama-13b'):
-         pt_model = f'llama-13b-{bits}bit.pt'
-     elif path_to_model.name.lower().startswith('llama-30b'):
-         pt_model = f'llama-30b-{bits}bit.pt'
-     elif path_to_model.name.lower().startswith('llama-65b'):
-         pt_model = f'llama-65b-{bits}bit.pt'
-     else:
-         pt_model = f'{model_name}-{bits}bit.pt'
+     pt_model = f'{model_name}-{shared.args.gptq_bits}bit.pt'

    # Try to find the .pt both in models/ and in the subfolder
    pt_path = None
@@ -40,7 +32,7 @@ def load_quantized_LLaMA(model_name):
        print(f"Could not find {pt_model}, exiting...")
        exit()

-     model = load_quant(path_to_model, str(pt_path), bits)
+     model = load_quant(path_to_model, str(pt_path), shared.args.gptq_bits)

    # Multiple GPUs or GPU+CPU
    if shared.args.gpu_memory:
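For context, below is a minimal caller-side sketch of how the refactored load_quant(model_name, model_type) might be invoked. The infer_model_type helper, the modules.quant_loader import path, and the gptq_bits check are assumptions for illustration only; they are not part of this commit.

# Hypothetical caller-side sketch (assumed module path, helper name, and CLI wiring).
import modules.shared as shared
from modules.quant_loader import load_quant  # assumed location of the wrapper above


def infer_model_type(model_name):
    # Assumption: model folders are named like 'llama-7b' or 'opt-6.7b'.
    name = model_name.lower()
    if name.startswith('llama'):
        return 'llama'
    if name.startswith('opt'):
        return 'opt'
    return None


if shared.args.gptq_bits > 0:  # e.g. started with --gptq_bits 4
    model_name = 'llama-7b'
    # With --gptq_bits 4 the wrapper looks for 'llama-7b-4bit.pt'
    # in models/ or in models/llama-7b/.
    model = load_quant(model_name, infer_model_type(model_name))

This mirrors the new contract in the diff: the wrapper only dispatches on model_type and builds the '{model_name}-{gptq_bits}bit.pt' filename, so choosing the backend is now the caller's responsibility.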