@@ -355,6 +355,7 @@ def load_sub_model(
355
355
provider : Any ,
356
356
sess_options : Any ,
357
357
device_map : Optional [Union [Dict [str , torch .device ], str ]],
358
+ max_memory : Optional [Dict [Union [int , str ], Union [int , str ]]],
358
359
offload_folder : Optional [Union [str , os .PathLike ]],
359
360
offload_state_dict : bool ,
360
361
model_variants : Dict [str , str ],
@@ -419,6 +420,7 @@ def load_sub_model(
419
420
# This makes sure that the weights won't be initialized which significantly speeds up loading.
420
421
if is_diffusers_model or is_transformers_model :
421
422
loading_kwargs ["device_map" ] = device_map
423
+ loading_kwargs ["max_memory" ] = max_memory
422
424
loading_kwargs ["offload_folder" ] = offload_folder
423
425
loading_kwargs ["offload_state_dict" ] = offload_state_dict
424
426
loading_kwargs ["variant" ] = model_variants .pop (name , None )
@@ -813,6 +815,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
813
815
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
814
816
more information about each option see [designing a device
815
817
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
818
+ max_memory (`Dict`, *optional*):
819
+ A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
820
+ GPU and the available CPU RAM if unset.
816
821
offload_folder (`str` or `os.PathLike`, *optional*):
817
822
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
818
823
offload_state_dict (`bool`, *optional*):
@@ -884,6 +889,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
884
889
provider = kwargs .pop ("provider" , None )
885
890
sess_options = kwargs .pop ("sess_options" , None )
886
891
device_map = kwargs .pop ("device_map" , None )
892
+ max_memory = kwargs .pop ("max_memory" , None )
887
893
offload_folder = kwargs .pop ("offload_folder" , None )
888
894
offload_state_dict = kwargs .pop ("offload_state_dict" , False )
889
895
low_cpu_mem_usage = kwargs .pop ("low_cpu_mem_usage" , _LOW_CPU_MEM_USAGE_DEFAULT )
@@ -1059,6 +1065,7 @@ def load_module(name, value):
1059
1065
provider = provider ,
1060
1066
sess_options = sess_options ,
1061
1067
device_map = device_map ,
1068
+ max_memory = max_memory ,
1062
1069
offload_folder = offload_folder ,
1063
1070
offload_state_dict = offload_state_dict ,
1064
1071
model_variants = model_variants ,
0 commit comments