Commit 210fa1e

up

1 parent d44ef85 commit 210fa1e

3 files changed: +140 −0 lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 4 additions & 0 deletions

@@ -529,6 +529,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        quantization_config = kwargs.pop("quantization_config", None)
 
         allow_pickle = False
         if use_safetensors is None:
@@ -624,6 +625,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             **kwargs,
         )
 
+        # determine quantization config.
+        ##############################
+
         # Determine if we're loading from a directory of sharded checkpoints.
         is_sharded = False
         index_file = None
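At this commit the new `quantization_config` kwarg is only popped, with the placeholder comment marking where it will be resolved, but the intended call shape is already visible. A minimal sketch of how a caller would pass it, assuming the `BitsAndBytesConfig` exported in the next file and an illustrative checkpoint id:

from diffusers import UNet2DConditionModel
from diffusers.quantizers import BitsAndBytesConfig

# Illustrative checkpoint id; any diffusers model repo with a UNet works.
model = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)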

src/diffusers/quantizers/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -27,12 +27,14 @@
 
 if is_torch_available():
     _import_structure["base"] = ["DiffusersQuantizer"]
+
     if is_bitsandbytes_available() and is_accelerate_available():
         _import_structure["bitsandbytes"] = [
             "set_module_quantized_tensor_to_device",
             "replace_with_bnb_linear",
             "dequantize_bnb_weight",
             "dequantize_and_replace",
+            "BitsAndBytesConfig"
         ]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -46,6 +48,7 @@
             replace_with_bnb_linear,
             set_module_quantized_tensor_to_device,
         )
+        from .quantization_config import BitsAndBytesConfig
 
 else:
     import sys
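With those two additions, `BitsAndBytesConfig` is importable from the subpackage, lazily through `_import_structure` at runtime and eagerly under `TYPE_CHECKING` or `DIFFUSERS_SLOW_IMPORT`. A minimal sketch, assuming bitsandbytes and accelerate are installed, since the export sits behind both availability checks:

from diffusers.quantizers import BitsAndBytesConfig

# The load_in_* flags are what the auto machinery in the next file uses to
# pick between the "_4bit" and "_8bit" quantizer variants.
config = BitsAndBytesConfig(load_in_4bit=True)
print(config.load_in_4bit)  # True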

src/diffusers/quantizers/auto.py

Lines changed: 133 additions & 0 deletions

@@ -0,0 +1,133 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+from typing import Dict, Optional, Union
+
+from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
+from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod
+
+
+AUTO_QUANTIZER_MAPPING = {
+    "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
+    "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
+}
+
+AUTO_QUANTIZATION_CONFIG_MAPPING = {
+    "bitsandbytes_4bit": BitsAndBytesConfig,
+    "bitsandbytes_8bit": BitsAndBytesConfig,
+}
+
+
+class DiffusersAutoQuantizationConfig:
+    """
+    The auto diffusers quantization config class that takes care of automatically dispatching to the correct
+    quantization config given a quantization config stored in a dictionary.
+    """
+
+    @classmethod
+    def from_dict(cls, quantization_config_dict: Dict):
+        quant_method = quantization_config_dict.get("quant_method", None)
+        # We need special care for bnb models to make sure everything stays backward-compatible.
+        if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
+            suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
+            quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
+        elif quant_method is None:
+            raise ValueError(
+                "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized."
+            )
+
+        if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
+            raise ValueError(
+                f"Unknown quantization type, got {quant_method} - supported types are:"
+                f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
+            )
+
+        target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
+        return target_cls.from_dict(quantization_config_dict)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
+        if getattr(model_config, "quantization_config", None) is None:
+            raise ValueError(
+                f"Did not find a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
+            )
+        quantization_config_dict = model_config.quantization_config
+        quantization_config = cls.from_dict(quantization_config_dict)
+        # Update with potential kwargs that are passed through from_pretrained.
+        quantization_config.update(kwargs)
+        return quantization_config
+
+
+class DiffusersAutoQuantizer:
+    """
+    The auto diffusers quantizer class that takes care of automatically instantiating the correct
+    `DiffusersQuantizer` given the `QuantizationConfig`.
+    """
+
+    @classmethod
+    def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
+        # Convert it to a QuantizationConfig if the q_config is a dict
+        if isinstance(quantization_config, dict):
+            quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config)
+
+        quant_method = quantization_config.quant_method
+
+        # Again, we need special care for bnb as we have a single quantization config
+        # class for both 4-bit and 8-bit quantization.
+        if quant_method == QuantizationMethod.BITS_AND_BYTES:
+            if quantization_config.load_in_8bit:
+                quant_method += "_8bit"
+            else:
+                quant_method += "_4bit"
+
+        if quant_method not in AUTO_QUANTIZER_MAPPING.keys():
+            raise ValueError(
+                f"Unknown quantization type, got {quant_method} - supported types are:"
+                f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
+            )
+
+        target_cls = AUTO_QUANTIZER_MAPPING[quant_method]
+        return target_cls(quantization_config, **kwargs)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        quantization_config = DiffusersAutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        return cls.from_config(quantization_config)
+
+    @classmethod
+    def merge_quantization_configs(
+        cls,
+        quantization_config: Union[dict, QuantizationConfigMixin],
+        quantization_config_from_args: Optional[QuantizationConfigMixin],
+    ):
+        """
+        Handles situations where both a quantization_config from args and a quantization_config from the model
+        config are present.
+        """
+        if quantization_config_from_args is not None:
+            warning_msg = (
+                "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
+                " already has a `quantization_config` attribute. The `quantization_config` from the model will be used."
+            )
+        else:
+            warning_msg = ""
+
+        if isinstance(quantization_config, dict):
+            quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config)
+
+        if warning_msg != "":
+            warnings.warn(warning_msg)
+
+        return quantization_config
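Taken together, the dispatch in this new file can be exercised end to end. A minimal sketch under the same availability assumptions; the dict passed to `from_dict` deliberately omits `quant_method` so it exercises the bnb backward-compatibility branch:

from diffusers.quantizers import BitsAndBytesConfig
from diffusers.quantizers.auto import DiffusersAutoQuantizationConfig, DiffusersAutoQuantizer

# Legacy-style dict without `quant_method`: from_dict infers the
# "bitsandbytes_8bit" key from the load_in_8bit flag.
config = DiffusersAutoQuantizationConfig.from_dict({"load_in_8bit": True})
print(type(config).__name__)  # BitsAndBytesConfig

# from_config appends the "_8bit"/"_4bit" suffix before the mapping lookup.
quantizer = DiffusersAutoQuantizer.from_config(config)
print(type(quantizer).__name__)  # BnB8BitDiffusersQuantizer

# When both the checkpoint and the caller supply a config, the model's own
# config wins and a warning is emitted.
merged = DiffusersAutoQuantizer.merge_quantization_configs(
    {"load_in_8bit": True},                 # from the model's config
    BitsAndBytesConfig(load_in_4bit=True),  # from the caller's kwargs
)

Note that `DiffusersAutoQuantizationConfig.from_pretrained` calls a `cls.load_config` that this class does not itself define, so the load-from-checkpoint path presumably relies on wiring added elsewhere in the PR.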
