Skip to content

Commit 6653047

Browse files
committed
add max_memory argument
1 parent 26e322d commit 6653047

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
392392
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
393393
more information about each option see [designing a device
394394
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
395+
max_memory (`Dict`, *optional*):
396+
A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available for each
397+
GPU and the available CPU RAM if unset.
395398
offload_folder (`str` or `os.PathLike`, *optional*):
396399
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
397400
offload_state_dict (`bool`, *optional*):
@@ -439,6 +442,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
439442
torch_dtype = kwargs.pop("torch_dtype", None)
440443
subfolder = kwargs.pop("subfolder", None)
441444
device_map = kwargs.pop("device_map", None)
445+
max_memory = kwargs.pop("max_memory", None)
442446
offload_folder = kwargs.pop("offload_folder", None)
443447
offload_state_dict = kwargs.pop("offload_state_dict", False)
444448
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
@@ -512,6 +516,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
512516
revision=revision,
513517
subfolder=subfolder,
514518
device_map=device_map,
519+
max_memory=max_memory,
515520
offload_folder=offload_folder,
516521
offload_state_dict=offload_state_dict,
517522
user_agent=user_agent,
@@ -621,6 +626,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
621626
model,
622627
model_file,
623628
device_map,
629+
max_memory=max_memory,
624630
offload_folder=offload_folder,
625631
offload_state_dict=offload_state_dict,
626632
dtype=torch_dtype,

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def load_sub_model(
355355
provider: Any,
356356
sess_options: Any,
357357
device_map: Optional[Union[Dict[str, torch.device], str]],
358+
max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
358359
offload_folder: Optional[Union[str, os.PathLike]],
359360
offload_state_dict: bool,
360361
model_variants: Dict[str, str],
@@ -419,6 +420,7 @@ def load_sub_model(
419420
# This makes sure that the weights won't be initialized which significantly speeds up loading.
420421
if is_diffusers_model or is_transformers_model:
421422
loading_kwargs["device_map"] = device_map
423+
loading_kwargs["max_memory"] = max_memory
422424
loading_kwargs["offload_folder"] = offload_folder
423425
loading_kwargs["offload_state_dict"] = offload_state_dict
424426
loading_kwargs["variant"] = model_variants.pop(name, None)
@@ -813,6 +815,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
813815
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
814816
more information about each option see [designing a device
815817
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
818+
max_memory (`Dict`, *optional*):
819+
A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available for each
820+
GPU and the available CPU RAM if unset.
816821
offload_folder (`str` or `os.PathLike`, *optional*):
817822
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
818823
offload_state_dict (`bool`, *optional*):
@@ -884,6 +889,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
884889
provider = kwargs.pop("provider", None)
885890
sess_options = kwargs.pop("sess_options", None)
886891
device_map = kwargs.pop("device_map", None)
892+
max_memory = kwargs.pop("max_memory", None)
887893
offload_folder = kwargs.pop("offload_folder", None)
888894
offload_state_dict = kwargs.pop("offload_state_dict", False)
889895
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
@@ -1059,6 +1065,7 @@ def load_module(name, value):
10591065
provider=provider,
10601066
sess_options=sess_options,
10611067
device_map=device_map,
1068+
max_memory=max_memory,
10621069
offload_folder=offload_folder,
10631070
offload_state_dict=offload_state_dict,
10641071
model_variants=model_variants,

0 commit comments

Comments
 (0)