import os
import re
from collections import OrderedDict
-from functools import partial
+from functools import partial, wraps
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
...
from torch import Tensor, nn

from .. import __version__
+from ..quantizers import DiffusersAutoQuantizer
+from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
    CONFIG_NAME,
    FLAX_WEIGHTS_NAME,
@@ -128,6 +130,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
    _supports_gradient_checkpointing = False
    _keys_to_ignore_on_load_unexpected = None
    _no_split_modules = None
+    _keep_in_fp32_modules = []

    def __init__(self):
        super().__init__()
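The new `_keep_in_fp32_modules` hook mirrors the attribute of the same name in transformers: a subclass lists module names that must not be downcast. A minimal sketch of how a subclass might opt in (class and module names are illustrative, not from this PR):

```python
# Sketch, assuming a hypothetical ModelMixin subclass: any module whose name matches
# an entry in _keep_in_fp32_modules is kept in torch.float32 when the model is loaded
# in float16 or through a quantizer that requests fp32 modules.
from torch import nn

from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config


class TinyModel(ModelMixin, ConfigMixin):  # hypothetical example class
    _keep_in_fp32_modules = ["norm"]  # illustrative module name

    @register_to_config
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)  # would stay in fp32 under the rule above
```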
@@ -407,6 +410,18 @@ def save_pretrained(
                create_pr=create_pr,
            )

+    def dequantize(self):
+        """
+        Potentially dequantize the model in case it has been quantized by a quantization method that supports
+        dequantization.
+        """
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+
+        if hf_quantizer is None:
+            raise ValueError("You need to first quantize your model in order to dequantize it")
+
+        return hf_quantizer.dequantize(self)
+
    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
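A short usage sketch for the new `dequantize()` entry point, assuming a checkpoint whose config already carries a `quantization_config` from a backend that supports dequantization (the repo id below is a placeholder):

```python
# Usage sketch: load a pre-quantized checkpoint, then restore higher-precision weights.
# "user/sd3-transformer-nf4" is a placeholder id; bitsandbytes must be installed.
from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel.from_pretrained("user/sd3-transformer-nf4")
model.dequantize()  # raises ValueError if the model was never quantized
```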
@@ -625,8 +640,42 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
            **kwargs,
        )

-        # determine quantization config.
-        ##############################
+        # determine initial quantization config.
+        ###############################
+        pre_quantized = getattr(config, "quantization_config", None) is not None
+        if pre_quantized or quantization_config is not None:
+            if pre_quantized:
+                config.quantization_config = DiffusersAutoQuantizer.merge_quantization_configs(
+                    config.quantization_config, quantization_config
+                )
+            else:
+                config.quantization_config = quantization_config
+            hf_quantizer = DiffusersAutoQuantizer.from_config(config.quantization_config, pre_quantized=pre_quantized)
+        else:
+            hf_quantizer = None
+
+        if hf_quantizer is not None:
+            hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
+            torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+            device_map = hf_quantizer.update_device_map(device_map)
+
+            # In order to ensure popular quantization methods are supported. Can be disabled with `disable_telemetry`
+            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+
+            # Force-set to `True` for more memory efficiency
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+                logger.warning("`low_cpu_mem_usage` was None, now set to True since model is quantized.")
+
+        # Check if `_keep_in_fp32_modules` is not None
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+        )
+        if use_keep_in_fp32_modules:
+            keep_in_fp32_modules = cls._keep_in_fp32_modules
+        else:
+            keep_in_fp32_modules = []
+        ###############################

        # Determine if we're loading from a directory of sharded checkpoints.
        is_sharded = False
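The block above resolves the quantizer from either a `quantization_config` stored in the checkpoint (`pre_quantized`) or one passed by the caller, then lets the quantizer adjust `torch_dtype`, `device_map`, and `low_cpu_mem_usage`. Roughly how a caller exercises the second branch, assuming the `BitsAndBytesConfig` that ships with the diffusers quantizers (checkpoint id and 4-bit settings are illustrative; requires `bitsandbytes` and a CUDA device):

```python
# Sketch: on-the-fly 4-bit quantization at load time via the new quantization_config kwarg.
import torch

from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # example checkpoint
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)
```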
@@ -733,6 +782,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
            with accelerate.init_empty_weights():
                model = cls.from_config(config, **unused_kwargs)

+            if hf_quantizer is not None:
+                hf_quantizer.preprocess_model(
+                    model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
+                )
+
+            # We store the original dtype for quantized models as we cannot easily retrieve it
+            # once the weights have been quantized
+            # Note that once you have loaded a quantized model, you can't change its dtype so this will
+            # remain a single source of truth
+            config._pre_quantization_dtype = torch_dtype
+
            # if device_map is None, load the state dict and move the params from meta device to the cpu
            if device_map is None and not is_sharded:
                param_device = "cpu"
@@ -754,6 +814,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
                    device=param_device,
                    dtype=torch_dtype,
                    model_name_or_path=pretrained_model_name_or_path,
+                    hf_quantizer=hf_quantizer,
+                    keep_in_fp32_modules=keep_in_fp32_modules,
                )

                if cls._keys_to_ignore_on_load_unexpected is not None:
@@ -769,7 +831,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
                # Load weights and dispatch according to the device_map
                # by default the device_map is None and the weights are loaded on the CPU
                force_hook = True
-                device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
+                device_map = _determine_device_map(
+                    model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
+                )
                if device_map is None and is_sharded:
                    # we load the parameters on the cpu
                    device_map = {"": "cpu"}
@@ -863,6 +927,47 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):

        return model

+    @wraps(torch.nn.Module.cuda)
+    def cuda(self, *args, **kwargs):
+        # Checks if the model has been loaded in 8-bit
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            raise ValueError(
+                "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the"
+                " model has already been set to the correct devices and casted to the correct `dtype`."
+            )
+        else:
+            return super().cuda(*args, **kwargs)
+
+    @wraps(torch.nn.Module.to)
+    def to(self, *args, **kwargs):
+        # Checks if the model has been loaded in 8-bit
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            raise ValueError(
+                "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                " model has already been set to the correct devices and casted to the correct `dtype`."
+            )
+        return super().to(*args, **kwargs)
+
+    def half(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.half()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been casted to the correct `dtype`."
+            )
+        else:
+            return super().half(*args)
+
+    def float(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.float()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been casted to the correct `dtype`."
+            )
+        else:
+            return super().float(*args)
+
    @classmethod
    def _load_pretrained_model(
        cls,
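The overrides above make dtype and device casts fail loudly for quantized models instead of silently corrupting the quantized weights. A behavior sketch, assuming `model` was loaded with a bitsandbytes `quantization_config` as in the earlier example:

```python
# Per the guards added above, both calls are expected to raise ValueError for a
# bitsandbytes-quantized model: the weights are already placed and in their final dtype.
try:
    model.half()
except ValueError as err:
    print(f"half() blocked: {err}")

try:
    model.to("cuda")
except ValueError as err:
    print(f"to() blocked: {err}")
```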