|
7 | 7 | import numpy as np
|
8 | 8 | import torch
|
9 | 9 | import transformers
|
10 |
| -from transformers import AutoModelForCausalLM, AutoTokenizer |
| 10 | +from accelerate import infer_auto_device_map, init_empty_weights |
| 11 | +from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, |
| 12 | + BitsAndBytesConfig) |
11 | 13 |
|
12 | 14 | import modules.shared as shared
|
13 | 15 |
|
@@ -94,39 +96,58 @@ def load_model(model_name):
|
94 | 96 |
|
95 | 97 | # Custom
|
96 | 98 | else:
|
97 |
| - command = "AutoModelForCausalLM.from_pretrained" |
98 |
| - params = ["low_cpu_mem_usage=True"] |
| 99 | + params = {"low_cpu_mem_usage": True} |
99 | 100 | if not shared.args.cpu and not torch.cuda.is_available():
|
100 | 101 | print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
|
101 | 102 | shared.args.cpu = True
|
102 | 103 |
|
103 | 104 | if shared.args.cpu:
|
104 |
| - params.append("low_cpu_mem_usage=True") |
105 |
| - params.append("torch_dtype=torch.float32") |
| 105 | + params["torch_dtype"] = torch.float32 |
106 | 106 | else:
|
107 |
| - params.append("device_map='auto'") |
108 |
| - params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16") |
| 107 | + params["device_map"] = 'auto' |
| 108 | + if shared.args.load_in_8bit: |
| 109 | + params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True) |
| 110 | + elif shared.args.bf16: |
| 111 | + params["torch_dtype"] = torch.bfloat16 |
| 112 | + else: |
| 113 | + params["torch_dtype"] = torch.float16 |
109 | 114 |
|
110 | 115 | if shared.args.gpu_memory:
|
111 | 116 | memory_map = shared.args.gpu_memory
|
112 |
| - max_memory = f"max_memory={{0: '{memory_map[0]}GiB'" |
113 |
| - for i in range(1, len(memory_map)): |
114 |
| - max_memory += (f", {i}: '{memory_map[i]}GiB'") |
115 |
| - max_memory += (f", 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}") |
116 |
| - params.append(max_memory) |
117 |
| - elif not shared.args.load_in_8bit: |
118 |
| - total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024)) |
119 |
| - suggestion = round((total_mem-1000)/1000)*1000 |
120 |
| - if total_mem-suggestion < 800: |
| 117 | + max_memory = {} |
| 118 | + for i in range(len(memory_map)): |
| 119 | + max_memory[i] = f'{memory_map[i]}GiB' |
| 120 | + max_memory['cpu'] = f'{shared.args.cpu_memory or 99}GiB' |
| 121 | + params['max_memory'] = max_memory |
| 122 | + else: |
| 123 | + total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024)) |
| 124 | + suggestion = round((total_mem-1000) / 1000) * 1000 |
| 125 | + if total_mem - suggestion < 800: |
121 | 126 | suggestion -= 1000
|
122 | 127 | suggestion = int(round(suggestion/1000))
|
123 | 128 | print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
|
124 |
| - params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}") |
125 |
| - if shared.args.disk: |
126 |
| - params.append(f"offload_folder='{shared.args.disk_cache_dir}'") |
| 129 | + |
| 130 | + max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'} |
| 131 | + params['max_memory'] = max_memory |
127 | 132 |
|
128 |
| - command = f"{command}(Path(f'models/{shared.model_name}'), {', '.join(set(params))})" |
129 |
| - model = eval(command) |
| 133 | + if shared.args.disk: |
| 134 | + params["offload_folder"] = shared.args.disk_cache_dir |
| 135 | + |
| 136 | + checkpoint = Path(f'models/{shared.model_name}') |
| 137 | + |
| 138 | + if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto': |
| 139 | + config = AutoConfig.from_pretrained(checkpoint) |
| 140 | + with init_empty_weights(): |
| 141 | + model = AutoModelForCausalLM.from_config(config) |
| 142 | + model.tie_weights() |
| 143 | + params['device_map'] = infer_auto_device_map( |
| 144 | + model, |
| 145 | + dtype=torch.int8, |
| 146 | + max_memory=params['max_memory'], |
| 147 | + no_split_module_classes = model._no_split_modules |
| 148 | + ) |
| 149 | + |
| 150 | + model = AutoModelForCausalLM.from_pretrained(checkpoint, **params) |
130 | 151 |
|
131 | 152 | # Loading the tokenizer
|
132 | 153 | if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():
|
|
0 commit comments