Skip to content

Commit 83cb20a

Browse files
author
awoo
committed
Add support for --gpu-memory witn --load-in-8bit
1 parent 23a5e88 commit 83cb20a

File tree

1 file changed

+43
-20
lines changed

1 file changed

+43
-20
lines changed

modules/models.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import numpy as np
88
import torch
99
import transformers
10-
from transformers import AutoModelForCausalLM, AutoTokenizer
10+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
11+
from accelerate import infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
1112

1213
import modules.shared as shared
1314

@@ -94,39 +95,61 @@ def load_model(model_name):
9495

9596
# Custom
9697
else:
97-
command = "AutoModelForCausalLM.from_pretrained"
98-
params = ["low_cpu_mem_usage=True"]
98+
params = {"low_cpu_mem_usage": True}
9999
if not shared.args.cpu and not torch.cuda.is_available():
100100
print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
101101
shared.args.cpu = True
102102

103103
if shared.args.cpu:
104-
params.append("low_cpu_mem_usage=True")
105-
params.append("torch_dtype=torch.float32")
104+
params["torch_dtype"] = torch.float32
106105
else:
107-
params.append("device_map='auto'")
108-
params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16")
106+
params["device_map"] = 'auto'
107+
if shared.args.load_in_8bit:
108+
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
109+
elif shared.args.bf16:
110+
params["torch_dtype"] = torch.bfloat16
111+
else:
112+
params["torch_dtype"] = torch.float16
109113

110114
if shared.args.gpu_memory:
111115
memory_map = shared.args.gpu_memory
112-
max_memory = f"max_memory={{0: '{memory_map[0]}GiB'"
116+
max_memory = { 0: f'{memory_map[0]}GiB' }
113117
for i in range(1, len(memory_map)):
114-
max_memory += (f", {i}: '{memory_map[i]}GiB'")
115-
max_memory += (f", 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
116-
params.append(max_memory)
117-
elif not shared.args.load_in_8bit:
118-
total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
119-
suggestion = round((total_mem-1000)/1000)*1000
120-
if total_mem-suggestion < 800:
118+
max_memory[i] = f'{memory_map[i]}GiB'
119+
max_memory['cpu'] = f'{shared.args.cpu_memory or 99}GiB'
120+
params['max_memory'] = max_memory
121+
else:
122+
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
123+
suggestion = round((total_mem - 1000) / 1000) * 1000
124+
if total_mem - suggestion < 800:
121125
suggestion -= 1000
122126
suggestion = int(round(suggestion/1000))
123127
print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
124-
params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
125-
if shared.args.disk:
126-
params.append(f"offload_folder='{shared.args.disk_cache_dir}'")
128+
129+
max_memory = {
130+
0: f'{suggestion}GiB',
131+
'cpu': f'{shared.args.cpu_memory or 99}GiB'
132+
}
133+
params['max_memory'] = max_memory
127134

128-
command = f"{command}(Path(f'models/{shared.model_name}'), {', '.join(set(params))})"
129-
model = eval(command)
135+
if shared.args.disk:
136+
params["offload_folder"] = shared.args.disk_cache_dir
137+
138+
checkpoint = Path(f'models/{shared.model_name}')
139+
140+
if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
141+
config = AutoConfig.from_pretrained(checkpoint)
142+
with init_empty_weights():
143+
model = AutoModelForCausalLM.from_config(config)
144+
model.tie_weights()
145+
params['device_map'] = infer_auto_device_map(
146+
model,
147+
dtype=torch.int8,
148+
max_memory=params['max_memory'],
149+
no_split_module_classes = model._no_split_modules
150+
)
151+
152+
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
130153

131154
# Loading the tokenizer
132155
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():

0 commit comments

Comments
 (0)