@@ -7,7 +7,8 @@
 import numpy as np
 import torch
 import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
+from accelerate import infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
 
 import modules.shared as shared
 
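Note on the new imports: BitsAndBytesConfig replaces the old bare load_in_8bit=True flag with an explicit quantization config, and accelerate's helpers let the loader plan layer placement before any weights are read. A minimal sketch of how these pieces are typically combined; the model name and memory limits below are placeholders, not values from this repository:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "facebook/opt-1.3b"  # placeholder checkpoint

quant_config = BitsAndBytesConfig(
    load_in_8bit=True,                      # quantize linear layers to int8
    llm_int8_enable_fp32_cpu_offload=True,  # modules offloaded to CPU stay in fp32
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",                       # let accelerate decide GPU/CPU placement
    quantization_config=quant_config,
    max_memory={0: "6GiB", "cpu": "24GiB"},  # same per-device cap format built in load_model()
)
tokenizer = AutoTokenizer.from_pretrained(model_id)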
@@ -94,39 +95,61 @@ def load_model(model_name):
 
     # Custom
     else:
-        command = "AutoModelForCausalLM.from_pretrained"
-        params = ["low_cpu_mem_usage=True"]
+        params = {"low_cpu_mem_usage": True}
         if not shared.args.cpu and not torch.cuda.is_available():
             print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
             shared.args.cpu = True
 
         if shared.args.cpu:
-            params.append("low_cpu_mem_usage=True")
-            params.append("torch_dtype=torch.float32")
+            params["torch_dtype"] = torch.float32
         else:
-            params.append("device_map='auto'")
-            params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16")
+            params["device_map"] = 'auto'
+            if shared.args.load_in_8bit:
+                params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
+            elif shared.args.bf16:
+                params["torch_dtype"] = torch.bfloat16
+            else:
+                params["torch_dtype"] = torch.float16
 
             if shared.args.gpu_memory:
                 memory_map = shared.args.gpu_memory
-                max_memory = f"max_memory={{0: '{memory_map[0]}GiB'"
+                max_memory = { 0: f'{memory_map[0]}GiB' }
                 for i in range(1, len(memory_map)):
-                    max_memory += (f", {i}: '{memory_map[i]}GiB'")
-                max_memory += (f", 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
-                params.append(max_memory)
-            elif not shared.args.load_in_8bit:
-                total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
-                suggestion = round((total_mem-1000)/1000)*1000
-                if total_mem-suggestion < 800:
+                    max_memory[i] = f'{memory_map[i]}GiB'
+                max_memory['cpu'] = f'{shared.args.cpu_memory or 99}GiB'
+                params['max_memory'] = max_memory
+            else:
+                total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
+                suggestion = round((total_mem - 1000) / 1000) * 1000
+                if total_mem - suggestion < 800:
                     suggestion -= 1000
                 suggestion = int(round(suggestion/1000))
                 print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
-                params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
-            if shared.args.disk:
-                params.append(f"offload_folder='{shared.args.disk_cache_dir}'")
+
+                max_memory = {
+                    0: f'{suggestion}GiB',
+                    'cpu': f'{shared.args.cpu_memory or 99}GiB'
+                }
+                params['max_memory'] = max_memory
 
-        command = f"{command}(Path(f'models/{shared.model_name}'), {', '.join(set(params))})"
-        model = eval(command)
+        if shared.args.disk:
+            params["offload_folder"] = shared.args.disk_cache_dir
+
+        checkpoint = Path(f'models/{shared.model_name}')
+
+        if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
+            config = AutoConfig.from_pretrained(checkpoint)
+            with init_empty_weights():
+                model = AutoModelForCausalLM.from_config(config)
+            model.tie_weights()
+            params['device_map'] = infer_auto_device_map(
+                model,
+                dtype=torch.int8,
+                max_memory=params['max_memory'],
+                no_split_module_classes=model._no_split_modules
+            )
+
+        model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
 
     # Loading the tokenizer
     if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():
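The device-map inference added before the final from_pretrained call is worth a standalone illustration: when 8-bit loading and explicit max_memory limits are both active, the loader first builds an empty (meta-tensor) copy of the model so accelerate can plan placement against int8 weight sizes without reading anything from disk. A self-contained sketch of the same technique, with a placeholder checkpoint path and memory budget:

import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

checkpoint = "models/example-model"       # placeholder path
max_memory = {0: "6GiB", "cpu": "24GiB"}  # per-device limits, as built in load_model()

# Instantiate the architecture on the meta device: no memory is allocated for weights.
config = AutoConfig.from_pretrained(checkpoint)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)
model.tie_weights()

# Plan placement using int8 weight sizes and keep each transformer block on one device.
device_map = infer_auto_device_map(
    model,
    dtype=torch.int8,
    max_memory=max_memory,
    no_split_module_classes=model._no_split_modules,
)
print(device_map)  # maps module names to a GPU index, 'cpu', or 'disk'

Passing dtype=torch.int8 here makes the planner budget for quantized weight sizes rather than fp16 ones, which is why the precomputed map replaces the plain 'auto' string in that branch.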