Commit 60ffa84

[bitsandbytes] follow-ups (#9730)
* bnb follow ups.
* add a warning when dtypes mismatch.
* fix-copies
* clear cache.
* check_if_quantized_param
* add a check on shape.
* updates
* docs
* improve readability.
* resources.
* fix
1 parent 0f079b9 commit 60ffa84

File tree: 8 files changed (+123 additions, −65 deletions)

docs/source/en/quantization/bitsandbytes.md
Lines changed: 8 additions & 15 deletions

@@ -59,19 -59,7 @@ model_8bit = FluxTransformer2DModel.from_pretrained(
 model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype
 ```
 
-Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights.
-
-```py
-from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
-
-quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-
-model_8bit = FluxTransformer2DModel.from_pretrained(
-    "black-forest-labs/FLUX.1-dev",
-    subfolder="transformer",
-    quantization_config=quantization_config
-)
-```
+Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
 
 </hfoption>
 <hfoption id="4-bit">

@@ -131,7 +119,7 @@ from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
 quantization_config = BitsAndBytesConfig(load_in_4bit=True)
 
 model_4bit = FluxTransformer2DModel.from_pretrained(
-    "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
+    "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
 )
 ```

@@ -264,4 +252,9 @@ double_quant_model = SD3Transformer2DModel.from_pretrained(
     quantization_config=double_quant_config,
 )
 model.dequantize()
-```
+```
+
+## Resources
+
+* [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)
+* [Training](https://gist.github.com/sayakpaul/05afd428bc089b47af7c016e42004527)
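
The reworded doc line above points to the usual serialization flow. A minimal sketch of that flow follows; the destination repo id and local path are illustrative and not part of this commit.

```py
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
)

# Push to the Hub: the quantization config.json is uploaded first, then the quantized weights.
model_4bit.push_to_hub("your-username/FLUX.1-dev-nf4")  # illustrative repo id

# Or serialize the 4-bit weights locally.
model_4bit.save_pretrained("./flux.1-dev-nf4")  # illustrative path
```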

src/diffusers/models/model_loading_utils.py
Lines changed: 16 additions & 9 deletions

@@ -211,21 +211,28 @@ def load_model_dict_into_meta(
                     set_module_kwargs["dtype"] = dtype
 
         # bnb params are flattened.
-        if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
-            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
-            raise ValueError(
-                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
-            )
+        if empty_state_dict[param_name].shape != param.shape:
+            if (
+                is_quant_method_bnb
+                and hf_quantizer.pre_quantized
+                and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+            ):
+                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
+            elif not is_quant_method_bnb:
+                model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+                raise ValueError(
+                    f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+                )
 
-        if not is_quantized or (
-            not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device)
+        if is_quantized and (
+            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
         ):
+            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
+        else:
             if accepts_dtype:
                 set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
             else:
                 set_module_tensor_to_device(model, param_name, device, value=param)
-        else:
-            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
 
     return unexpected_keys
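
The `# bnb params are flattened.` comment is the motivation for the new branch: a pre-quantized bitsandbytes 4-bit weight is stored packed, two 4-bit values per `uint8` and flattened to `((n + 1) // 2, 1)`, so its loaded shape never matches the meta model's expected shape and the check is delegated to the quantizer. A small illustration, not from this commit (assumes a CUDA device and `bitsandbytes` installed; the tensor shape is arbitrary):

```py
import torch
import bitsandbytes as bnb

weight = torch.randn(3072, 3072, dtype=torch.float16)  # arbitrary shape for illustration
packed = bnb.nn.Params4bit(weight, requires_grad=False).cuda()  # quantization happens on the device move

print(weight.shape)       # torch.Size([3072, 3072])
print(packed.data.shape)  # torch.Size([4718592, 1]) == ((3072 * 3072 + 1) // 2, 1)
```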

src/diffusers/quantizers/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .auto import DiffusersAutoQuantizationConfig, DiffusersAutoQuantizer
+from .auto import DiffusersAutoQuantizer
 from .base import DiffusersQuantizer

src/diffusers/quantizers/auto.py
Lines changed: 15 additions & 26 deletions

@@ -33,10 +33,10 @@
 }
 
 
-class DiffusersAutoQuantizationConfig:
+class DiffusersAutoQuantizer:
     """
-    The auto diffusers quantization config class that takes care of automatically dispatching to the correct
-    quantization config given a quantization config stored in a dictionary.
+    The auto diffusers quantizer class that takes care of automatically instantiating to the correct
+    `DiffusersQuantizer` given the `QuantizationConfig`.
     """
 
     @classmethod

@@ -60,31 +60,11 @@ def from_dict(cls, quantization_config_dict: Dict):
         target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
         return target_cls.from_dict(quantization_config_dict)
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-        model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
-        if getattr(model_config, "quantization_config", None) is None:
-            raise ValueError(
-                f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
-            )
-        quantization_config_dict = model_config.quantization_config
-        quantization_config = cls.from_dict(quantization_config_dict)
-        # Update with potential kwargs that are passed through from_pretrained.
-        quantization_config.update(kwargs)
-        return quantization_config
-
-
-class DiffusersAutoQuantizer:
-    """
-    The auto diffusers quantizer class that takes care of automatically instantiating to the correct
-    `DiffusersQuantizer` given the `QuantizationConfig`.
-    """
-
     @classmethod
     def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
         # Convert it to a QuantizationConfig if the q_config is a dict
         if isinstance(quantization_config, dict):
-            quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config)
+            quantization_config = cls.from_dict(quantization_config)
 
         quant_method = quantization_config.quant_method
 

@@ -107,7 +87,16 @@ def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict],
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-        quantization_config = DiffusersAutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
+        if getattr(model_config, "quantization_config", None) is None:
+            raise ValueError(
+                f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
+            )
+        quantization_config_dict = model_config.quantization_config
+        quantization_config = cls.from_dict(quantization_config_dict)
+        # Update with potential kwargs that are passed through from_pretrained.
+        quantization_config.update(kwargs)
+
         return cls.from_config(quantization_config)
 
     @classmethod

@@ -129,7 +118,7 @@ def merge_quantization_configs(
         warning_msg = ""
 
         if isinstance(quantization_config, dict):
-            quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config)
+            quantization_config = cls.from_dict(quantization_config)
 
         if warning_msg != "":
             warnings.warn(warning_msg)
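
With `DiffusersAutoQuantizationConfig` folded into `DiffusersAutoQuantizer`, config resolution and quantizer instantiation now go through a single class. A rough usage sketch, not from this commit (assumes `bitsandbytes` is installed):

```py
from diffusers.quantizers import DiffusersAutoQuantizer

# A dict-style config is converted via from_dict and dispatched to the matching quantizer class.
quantizer = DiffusersAutoQuantizer.from_config({"load_in_4bit": True, "bnb_4bit_quant_type": "nf4"})
print(type(quantizer).__name__)  # expected: BnB4BitDiffusersQuantizer
```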

src/diffusers/quantizers/base.py
Lines changed: 8 additions & 5 deletions

@@ -134,7 +134,7 @@ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str,
         """adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
         return max_memory
 
-    def check_quantized_param(
+    def check_if_quantized_param(
         self,
         model: "ModelMixin",
         param_value: "torch.Tensor",

@@ -152,10 +152,13 @@ def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter":
         """
         takes needed components from state_dict and creates quantized param.
         """
-        if not hasattr(self, "check_quantized_param"):
-            raise AttributeError(
-                f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}."
-            )
+        return
+
+    def check_quantized_param_shape(self, *args, **kwargs):
+        """
+        checks if the quantized param has expected shape.
+        """
+        return True
 
     def validate_environment(self, *args, **kwargs):
         """

src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
Lines changed: 12 additions & 3 deletions

@@ -106,7 +106,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
         else:
             raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.")
 
-    def check_quantized_param(
+    def check_if_quantized_param(
         self,
         model: "ModelMixin",
         param_value: "torch.Tensor",

@@ -204,6 +204,16 @@ def create_quantized_param(
 
         module._parameters[tensor_name] = new_value
 
+    def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape):
+        n = current_param_shape.numel()
+        inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
+        if loaded_param_shape != inferred_shape:
+            raise ValueError(
+                f"Expected the flattened shape of the current param ({param_name}) to be {loaded_param_shape} but is {inferred_shape}."
+            )
+        else:
+            return True
+
     def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
         # need more space for buffers that are created during quantization
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}

@@ -330,7 +340,6 @@ def __init__(self, quantization_config, **kwargs):
         if self.quantization_config.llm_int8_skip_modules is not None:
             self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
 
-    # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4-bit->8-bit
     def validate_environment(self, *args, **kwargs):
         if not torch.cuda.is_available():
             raise RuntimeError("No GPU found. A GPU is needed for quantization.")

@@ -404,7 +413,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
         logger.info("target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization")
         return torch.int8
 
-    def check_quantized_param(
+    def check_if_quantized_param(
         self,
         model: "ModelMixin",
         param_value: "torch.Tensor",
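
A worked example of the shape rule `check_quantized_param_shape` enforces, runnable on CPU (the shapes are illustrative):

```py
import torch

current_param_shape = torch.Size([3072, 64])   # shape the meta model expects for a weight
n = current_param_shape.numel()                # 196608 elements
print(((n + 1) // 2, 1))                       # (98304, 1): packed 4-bit storage, two values per uint8
print((torch.Size([3072]).numel(),))           # (3072,): a param with "bias" in its name stays flat
```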

tests/quantization/bnb/test_4bit.py
Lines changed: 46 additions & 3 deletions

@@ -13,10 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import gc
+import os
 import tempfile
 import unittest
 
 import numpy as np
+import safetensors.torch
 
 from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
 from diffusers.utils import logging

@@ -118,6 +120,9 @@ def get_dummy_inputs(self):
 
 class BnB4BitBasicTests(Base4bitTests):
     def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
         # Models
         self.model_fp16 = SD3Transformer2DModel.from_pretrained(
             self.model_name, subfolder="transformer", torch_dtype=torch.float16

@@ -232,7 +237,7 @@ def test_linear_are_4bit(self):
 
     def test_config_from_pretrained(self):
         transformer_4bit = FluxTransformer2DModel.from_pretrained(
-            "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
+            "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
         )
         linear = get_some_linear_layer(transformer_4bit)
         self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)

@@ -312,9 +317,42 @@ def test_bnb_4bit_wrong_config(self):
         with self.assertRaises(ValueError):
             _ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add")
 
+    def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
+        r"""
+        Test if loading with an incorrect state dict raises an error.
+        """
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            nf4_config = BitsAndBytesConfig(load_in_4bit=True)
+            model_4bit = SD3Transformer2DModel.from_pretrained(
+                self.model_name, subfolder="transformer", quantization_config=nf4_config
+            )
+            model_4bit.save_pretrained(tmpdirname)
+            del model_4bit
+
+            with self.assertRaises(ValueError) as err_context:
+                state_dict = safetensors.torch.load_file(
+                    os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
+                )
+
+                # corrupt the state dict
+                key_to_target = "context_embedder.weight"  # can be other keys too.
+                compatible_param = state_dict[key_to_target]
+                corrupted_param = torch.randn(compatible_param.shape[0] - 1, 1)
+                state_dict[key_to_target] = bnb.nn.Params4bit(corrupted_param, requires_grad=False)
+                safetensors.torch.save_file(
+                    state_dict, os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
+                )
+
+                _ = SD3Transformer2DModel.from_pretrained(tmpdirname)
+
+            assert key_to_target in str(err_context.exception)
+
 
 class BnB4BitTrainingTests(Base4bitTests):
     def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
         nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_quant_type="nf4",

@@ -360,6 +398,9 @@ def test_training(self):
 @require_transformers_version_greater("4.44.0")
 class SlowBnb4BitTests(Base4bitTests):
     def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
         nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_quant_type="nf4",

@@ -447,8 +488,10 @@ def test_moving_to_cpu_throws_warning(self):
 @require_transformers_version_greater("4.44.0")
 class SlowBnb4BitFluxTests(Base4bitTests):
     def setUp(self) -> None:
-        # TODO: Copy sayakpaul/flux.1-dev-nf4-pkg to testing repo.
-        model_id = "sayakpaul/flux.1-dev-nf4-pkg"
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
         t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
         transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
         self.pipeline_4bit = DiffusionPipeline.from_pretrained(

tests/quantization/bnb/test_mixed_int8.py
Lines changed: 17 additions & 3 deletions

@@ -117,6 +117,9 @@ def get_dummy_inputs(self):
 
 class BnB8bitBasicTests(Base8bitTests):
     def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
         # Models
         self.model_fp16 = SD3Transformer2DModel.from_pretrained(
             self.model_name, subfolder="transformer", torch_dtype=torch.float16

@@ -238,7 +241,7 @@ def test_llm_skip(self):
 
     def test_config_from_pretrained(self):
         transformer_8bit = FluxTransformer2DModel.from_pretrained(
-            "sayakpaul/flux.1-dev-int8-pkg", subfolder="transformer"
+            "hf-internal-testing/flux.1-dev-int8-pkg", subfolder="transformer"
         )
         linear = get_some_linear_layer(transformer_8bit)
         self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)

@@ -296,6 +299,9 @@ def test_device_and_dtype_assignment(self):
 
 class BnB8bitTrainingTests(Base8bitTests):
     def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
         mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
         self.model_8bit = SD3Transformer2DModel.from_pretrained(
             self.model_name, subfolder="transformer", quantization_config=mixed_int8_config

@@ -337,6 +343,9 @@ def test_training(self):
 @require_transformers_version_greater("4.44.0")
 class SlowBnb8bitTests(Base8bitTests):
     def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
         mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
         model_8bit = SD3Transformer2DModel.from_pretrained(
             self.model_name, subfolder="transformer", quantization_config=mixed_int8_config

@@ -427,8 +436,10 @@ def test_generate_quality_dequantize(self):
 @require_transformers_version_greater("4.44.0")
 class SlowBnb8bitFluxTests(Base8bitTests):
     def setUp(self) -> None:
-        # TODO: Copy sayakpaul/flux.1-dev-int8-pkg to testing repo.
-        model_id = "sayakpaul/flux.1-dev-int8-pkg"
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        model_id = "hf-internal-testing/flux.1-dev-int8-pkg"
         t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
         transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
         self.pipeline_8bit = DiffusionPipeline.from_pretrained(

@@ -466,6 +477,9 @@ def test_quality(self):
 @slow
 class BaseBnb8bitSerializationTests(Base8bitTests):
     def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
         quantization_config = BitsAndBytesConfig(
             load_in_8bit=True,
         )
