Commit 345b6de

refactor quant models loader and add support of OPT
1 parent 2c4699a commit 345b6de

1 file changed: 9 additions, 17 deletions


modules/quantized_LLaMA.py renamed to modules/quant_loader.py

Lines changed: 9 additions & 17 deletions
@@ -7,28 +7,20 @@
 import modules.shared as shared
 
 sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
-from llama import load_quant
 
 
 # 4-bit LLaMA
-def load_quantized_LLaMA(model_name):
-    if shared.args.load_in_4bit:
-        bits = 4
+def load_quant(model_name, model_type):
+    if model_type == 'llama':
+        from llama import load_quant
+    elif model_type == 'opt':
+        from opt import load_quant
     else:
-        bits = shared.args.gptq_bits
+        print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
+        exit()
 
     path_to_model = Path(f'models/{model_name}')
-    pt_model = ''
-    if path_to_model.name.lower().startswith('llama-7b'):
-        pt_model = f'llama-7b-{bits}bit.pt'
-    elif path_to_model.name.lower().startswith('llama-13b'):
-        pt_model = f'llama-13b-{bits}bit.pt'
-    elif path_to_model.name.lower().startswith('llama-30b'):
-        pt_model = f'llama-30b-{bits}bit.pt'
-    elif path_to_model.name.lower().startswith('llama-65b'):
-        pt_model = f'llama-65b-{bits}bit.pt'
-    else:
-        pt_model = f'{model_name}-{bits}bit.pt'
+    pt_model = f'{model_name}-{shared.args.gptq_bits}bit.pt'
 
     # Try to find the .pt both in models/ and in the subfolder
     pt_path = None
@@ -40,7 +32,7 @@ def load_quantized_LLaMA(model_name):
         print(f"Could not find {pt_model}, exiting...")
         exit()
 
-    model = load_quant(path_to_model, str(pt_path), bits)
+    model = load_quant(path_to_model, str(pt_path), shared.args.gptq_bits)
 
     # Multiple GPUs or GPU+CPU
     if shared.args.gpu_memory:
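After this refactor the loader no longer guesses LLaMA-specific checkpoint names: it expects models/<model_name>-<gptq_bits>bit.pt and picks the GPTQ-for-LLaMa implementation (llama or opt) from the new model_type argument. The snippet below is a hypothetical caller sketch, not code from this commit; the function name and the "detect the type from the folder name" heuristic are assumptions made for illustration.

# Hypothetical caller sketch (not part of this commit).
# Assumes modules/quant_loader.py from this diff is importable and that the
# model folder name starts with the architecture ('llama-7b', 'opt-6.7b', ...).
from modules import quant_loader

def load_pre_quantized(model_name):
    name = model_name.lower()
    if name.startswith('llama'):
        model_type = 'llama'
    elif name.startswith('opt'):
        model_type = 'opt'
    else:
        raise ValueError("Only 'llama' and 'opt' pre-quantized models are supported")
    # quant_loader looks for models/<model_name>-<gptq_bits>bit.pt,
    # e.g. models/llama-7b-4bit.pt when --gptq-bits is 4.
    return quant_loader.load_quant(model_name, model_type)

Keeping the architecture choice in the caller lets quant_loader stay a thin wrapper around the llama/opt load_quant implementations shipped with GPTQ-for-LLaMa.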
