
Commit f4feee1

rotary embedding refactor 2: update comments, fix dtype for use_real=False (#9312)
1 parent 210fa1e commit f4feee1

9 files changed: +207 −83 lines changed

src/diffusers/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -31,7 +31,7 @@
     "loaders": ["FromOriginalModelMixin"],
     "models": [],
     "pipelines": [],
-    "quantizers": [],
+    "quantizers": ["BitsAndBytesConfig"],
     "schedulers": [],
     "utils": [
         "OptionalDependencyNotAvailable",
@@ -155,7 +155,7 @@
             "StableDiffusionMixin",
         ]
     )
-    _import_structure["quantizers"] = ["HfQuantizer"]
+    _import_structure["quantizers"] = ["DiffusersQuantizer"]
     _import_structure["schedulers"].extend(
         [
             "AmusedScheduler",
@@ -527,6 +527,7 @@

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .configuration_utils import ConfigMixin
+    from .quantizers import BitsAndBytesConfig

     try:
         if not is_onnx_available():
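With the lazy-import entries above, `BitsAndBytesConfig` becomes importable from the package root. A minimal sketch of the resulting usage (assuming a diffusers build that contains this commit; constructing the config does not by itself quantize anything):

```python
# Hedged sketch: the quantization config newly exported at the top level.
from diffusers import BitsAndBytesConfig

# Mirrors the transformers config of the same name; actual quantization only
# happens when this object is passed to a model's from_pretrained().
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
print(bnb_config.load_in_8bit)  # True
```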

src/diffusers/models/embeddings.py

Lines changed: 8 additions & 4 deletions
@@ -514,7 +514,7 @@ def get_1d_rotary_pos_embed(
     linear_factor=1.0,
     ntk_factor=1.0,
     repeat_interleave_real=True,
-    freqs_dtype=torch.float32,  # torch.float32 (hunyuan, stable audio), torch.float64 (flux)
+    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
 ):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -551,15 +551,18 @@ def get_1d_rotary_pos_embed(
     t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
     freqs = torch.outer(t, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
+        # flux, hunyuan-dit, cogvideox
         freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
         freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
+        # stable audio
         freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
         freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
         return freqs_cos, freqs_sin
     else:
-        freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float()  # complex64  # [S, D/2]
+        # lumina
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
         return freqs_cis


@@ -590,11 +593,11 @@ def apply_rotary_emb(
         cos, sin = cos.to(x.device), sin.to(x.device)

         if use_real_unbind_dim == -1:
-            # Use for example in Lumina
+            # Used for flux, cogvideox, hunyuan-dit
             x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
             x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
         elif use_real_unbind_dim == -2:
-            # Use for example in Stable Audio
+            # Used for Stable Audio
             x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
             x_rotated = torch.cat([-x_imag, x_real], dim=-1)
         else:
@@ -604,6 +607,7 @@ def apply_rotary_emb(

         return out
     else:
+        # used for lumina
         x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
         freqs_cis = freqs_cis.unsqueeze(2)
         x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
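Dropping the trailing `.float()` in the `use_real=False` branch means the returned complex tensor now follows `freqs_dtype` instead of always collapsing to complex64. A rough sketch of the observable effect, assuming a build with this change:

```python
import torch

from diffusers.models.embeddings import get_1d_rotary_pos_embed

# use_real=False (the Lumina path) returns a single complex tensor.
cis_fp64 = get_1d_rotary_pos_embed(dim=64, pos=16, use_real=False, freqs_dtype=torch.float64)
cis_fp32 = get_1d_rotary_pos_embed(dim=64, pos=16, use_real=False, freqs_dtype=torch.float32)

# Before the fix both calls came back as complex64; now float64 frequencies
# are expected to stay in complex128.
print(cis_fp64.dtype, cis_fp32.dtype)
```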

src/diffusers/models/model_loading_utils.py

Lines changed: 66 additions & 5 deletions
@@ -53,11 +53,38 @@


 # Adapted from `transformers` (see modeling_utils.py)
-def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
+def _determine_device_map(
+    model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
+):
     if isinstance(device_map, str):
+        special_dtypes = {}
+
+        if hf_quantizer is not None:
+            special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
+
+        special_dtypes.update(
+            {
+                name: torch.float32
+                for name, _ in model.named_parameters()
+                if any(m in name for m in keep_in_fp32_modules)
+            }
+        )
+
+        target_dtype = torch_dtype
+        if hf_quantizer is not None:
+            target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)
+
         no_split_modules = model._get_no_split_modules(device_map)
         device_map_kwargs = {"no_split_module_classes": no_split_modules}

+        if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
+            device_map_kwargs["special_dtypes"] = special_dtypes
+        elif len(special_dtypes) > 0:
+            logger.warning(
+                "This model has some weights that should be kept in higher precision, you need to upgrade "
+                "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
+            )
+
         if device_map != "sequential":
             max_memory = get_balanced_memory(
                 model,
@@ -69,8 +96,14 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
         else:
             max_memory = get_max_memory(max_memory)

+        if hf_quantizer is not None:
+            max_memory = hf_quantizer.adjust_max_memory(max_memory)
+
         device_map_kwargs["max_memory"] = max_memory
-        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
+        device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
+
+        if hf_quantizer is not None:
+            hf_quantizer.validate_environment(device_map=device_map)

     return device_map

@@ -136,29 +169,57 @@ def load_model_dict_into_meta(
     device: Optional[Union[str, torch.device]] = None,
     dtype: Optional[Union[str, torch.dtype]] = None,
     model_name_or_path: Optional[str] = None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
 ) -> List[str]:
     device = device or torch.device("cpu")
     dtype = dtype or torch.float32
+    is_quantized = hf_quantizer is not None

     accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())

     unexpected_keys = []
     empty_state_dict = model.state_dict()
+    is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
+
     for param_name, param in state_dict.items():
         if param_name not in empty_state_dict:
             unexpected_keys.append(param_name)
             continue

+        # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
+        # in int/uint/bool and not cast them.
+        is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
+        if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn:
+            if (
+                keep_in_fp32_modules is not None
+                and any(
+                    module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
+                )
+                and dtype == torch.float16
+            ):
+                param = param.to(torch.float32)
+            else:
+                param = param.to(dtype)
+
         if empty_state_dict[param_name].shape != param.shape:
             model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
             raise ValueError(
                 f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
             )

-        if accepts_dtype:
-            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+        if (
+            not is_quantized
+            or (not hf_quantizer.requires_parameters_quantization)
+            or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device))
+        ):
+            if accepts_dtype:
+                set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+            else:
+                set_module_tensor_to_device(model, param_name, device, value=param)
         else:
-            set_module_tensor_to_device(model, param_name, device, value=param)
+            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
+
     return unexpected_keys

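The `special_dtypes` plumbing in `_determine_device_map` ultimately forwards per-parameter dtype overrides to accelerate's `infer_auto_device_map` (when the installed accelerate version accepts that argument). A toy sketch of that underlying mechanism, using a made-up two-layer module rather than a real diffusers model:

```python
import torch
from accelerate import infer_auto_device_map, init_empty_weights

# Hypothetical stand-in for a diffusers model; only parameter names and shapes matter here.
with init_empty_weights():
    toy = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.LayerNorm(1024))

# Mirrors the keep_in_fp32_modules filter above: parameters whose names match a
# listed module are planned in float32, everything else in float16.
keep_in_fp32_modules = ["1"]  # the LayerNorm submodule of the toy model
special_dtypes = {
    name: torch.float32
    for name, _ in toy.named_parameters()
    if any(m in name for m in keep_in_fp32_modules)
}

device_map = infer_auto_device_map(toy, dtype=torch.float16, special_dtypes=special_dtypes)
print(device_map)
```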

src/diffusers/models/modeling_utils.py

Lines changed: 109 additions & 4 deletions
@@ -20,7 +20,7 @@
 import os
 import re
 from collections import OrderedDict
-from functools import partial
+from functools import partial, wraps
 from pathlib import Path
 from typing import Any, Callable, List, Optional, Tuple, Union

@@ -31,6 +31,8 @@
 from torch import Tensor, nn

 from .. import __version__
+from ..quantizers import DiffusersAutoQuantizer
+from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
@@ -128,6 +130,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
     _no_split_modules = None
+    _keep_in_fp32_modules = []

     def __init__(self):
         super().__init__()
@@ -407,6 +410,18 @@ def save_pretrained(
             create_pr=create_pr,
         )

+    def dequantize(self):
+        """
+        Potentially dequantize the model in case it has been quantized by a quantization method that support
+        dequantization.
+        """
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+
+        if hf_quantizer is None:
+            raise ValueError("You need to first quantize your model in order to dequantize it")
+
+        return hf_quantizer.dequantize(self)
+
     @classmethod
     @validate_hf_hub_args
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
@@ -625,8 +640,42 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
             **kwargs,
         )

-        # determine quantization config.
-        ##############################
+        # determine initial quantization config.
+        ###############################
+        pre_quantized = getattr(config, "quantization_config", None) is not None
+        if pre_quantized or quantization_config is not None:
+            if pre_quantized:
+                config.quantization_config = DiffusersAutoQuantizer.merge_quantization_configs(
+                    config.quantization_config, quantization_config
+                )
+            else:
+                config.quantization_config = quantization_config
+            hf_quantizer = DiffusersAutoQuantizer.from_config(config.quantization_config, pre_quantized=pre_quantized)
+        else:
+            hf_quantizer = None
+
+        if hf_quantizer is not None:
+            hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
+            torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+            device_map = hf_quantizer.update_device_map(device_map)

+            # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
+            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+
+            # Force-set to `True` for more mem efficiency
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+                logger.warning("`low_cpu_mem_usage` was None, now set to True since model is quantized.")
+
+        # Check if `_keep_in_fp32_modules` is not None
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+        )
+        if use_keep_in_fp32_modules:
+            keep_in_fp32_modules = cls._keep_in_fp32_modules
+        else:
+            keep_in_fp32_modules = []
+        ###############################

         # Determine if we're loading from a directory of sharded checkpoints.
         is_sharded = False
@@ -733,6 +782,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
             with accelerate.init_empty_weights():
                 model = cls.from_config(config, **unused_kwargs)

+            if hf_quantizer is not None:
+                hf_quantizer.preprocess_model(
+                    model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
+                )
+
+                # We store the original dtype for quantized models as we cannot easily retrieve it
+                # once the weights have been quantized
+                # Note that once you have loaded a quantized model, you can't change its dtype so this will
+                # remain a single source of truth
+                config._pre_quantization_dtype = torch_dtype
+
             # if device_map is None, load the state dict and move the params from meta device to the cpu
             if device_map is None and not is_sharded:
                 param_device = "cpu"
@@ -754,6 +814,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
                     device=param_device,
                     dtype=torch_dtype,
                     model_name_or_path=pretrained_model_name_or_path,
+                    hf_quantizer=hf_quantizer,
+                    keep_in_fp32_modules=keep_in_fp32_modules,
                 )

                 if cls._keys_to_ignore_on_load_unexpected is not None:
@@ -769,7 +831,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
             # Load weights and dispatch according to the device_map
             # by default the device_map is None and the weights are loaded on the CPU
             force_hook = True
-            device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
+            device_map = _determine_device_map(
+                model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
+            )
             if device_map is None and is_sharded:
                 # we load the parameters on the cpu
                 device_map = {"": "cpu"}
@@ -863,6 +927,47 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):

         return model

+    @wraps(torch.nn.Module.cuda)
+    def cuda(self, *args, **kwargs):
+        # Checks if the model has been loaded in 8-bit
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            raise ValueError(
+                "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the"
+                " model has already been set to the correct devices and casted to the correct `dtype`."
+            )
+        else:
+            return super().cuda(*args, **kwargs)
+
+    @wraps(torch.nn.Module.to)
+    def to(self, *args, **kwargs):
+        # Checks if the model has been loaded in 8-bit
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            raise ValueError(
+                "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                " model has already been set to the correct devices and casted to the correct `dtype`."
+            )
+        return super().to(*args, **kwargs)
+
+    def half(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.half()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been casted to the correct `dtype`."
+            )
+        else:
+            return super().half(*args)
+
+    def float(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.float()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been casted to the correct `dtype`."
+            )
+        else:
+            return super().float(*args)
+
     @classmethod
     def _load_pretrained_model(
         cls,
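Taken together, these changes let `ModelMixin.from_pretrained` accept a `quantization_config` and then refuse dtype or device casts on the loaded model. A hedged end-to-end sketch; the checkpoint and `subfolder` are illustrative assumptions, and running it needs `bitsandbytes`, `accelerate`, and a CUDA device:

```python
import torch

from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# Illustrative model choice; any ModelMixin subclass with bitsandbytes-compatible
# linear layers should behave the same way.
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed checkpoint
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

# The new cuda()/to()/half()/float() overrides guard a quantized model against
# being moved or re-cast after loading.
try:
    model.half()
except ValueError as err:
    print(err)

# dequantize() hands the model back to the quantizer; it raises if the active
# backend does not support dequantization.
# model = model.dequantize()
```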
