Skip to content

Commit d75ea3c

Browse files
authored
device_map in load_model_dict_into_meta (#10851)
* `device_map` in `load_model_dict_into_meta`
* `_LOW_CPU_MEM_USAGE_DEFAULT`
* fix `is_peft_version` and `is_bitsandbytes_version`
1 parent b27d4ed commit d75ea3c

File tree

4 files changed

+23
-18
lines changed

4 files changed

+23
-18
lines changed

src/diffusers/loaders/transformer_flux.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
ImageProjection,
1818
MultiIPAdapterImageProjection,
1919
)
20-
from ..models.modeling_utils import load_model_dict_into_meta
20+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
2121
from ..utils import (
2222
is_accelerate_available,
2323
is_torch_version,
@@ -36,7 +36,7 @@ class FluxTransformer2DLoadersMixin:
3636
Load layers into a [`FluxTransformer2DModel`].
3737
"""
3838

39-
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
39+
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
4040
if low_cpu_mem_usage:
4141
if is_accelerate_available():
4242
from accelerate import init_empty_weights
@@ -82,11 +82,12 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
8282
if not low_cpu_mem_usage:
8383
image_projection.load_state_dict(updated_state_dict, strict=True)
8484
else:
85-
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
85+
device_map = {"": self.device}
86+
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
8687

8788
return image_projection
8889

89-
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
90+
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
9091
from ..models.attention_processor import (
9192
FluxIPAdapterJointAttnProcessor2_0,
9293
)
@@ -151,15 +152,15 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
151152
if not low_cpu_mem_usage:
152153
attn_procs[name].load_state_dict(value_dict)
153154
else:
154-
device = self.device
155+
device_map = {"": self.device}
155156
dtype = self.dtype
156-
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
157+
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
157158

158159
key_id += 1
159160

160161
return attn_procs
161162

162-
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
163+
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
163164
if not isinstance(state_dicts, list):
164165
state_dicts = [state_dicts]
165166

src/diffusers/loaders/transformer_sd3.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -75,8 +75,9 @@ def _convert_ip_adapter_attn_to_diffusers(
7575
if not low_cpu_mem_usage:
7676
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
7777
else:
78+
device_map = {"": self.device}
7879
load_model_dict_into_meta(
79-
attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
80+
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
8081
)
8182

8283
return attn_procs
@@ -144,7 +145,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(
144145
if not low_cpu_mem_usage:
145146
image_proj.load_state_dict(updated_state_dict, strict=True)
146147
else:
147-
load_model_dict_into_meta(image_proj, updated_state_dict, device=self.device, dtype=self.dtype)
148+
device_map = {"": self.device}
149+
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
148150

149151
return image_proj
150152

src/diffusers/loaders/unet.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@
3030
IPAdapterPlusImageProjection,
3131
MultiIPAdapterImageProjection,
3232
)
33-
from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
33+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
3434
from ..utils import (
3535
USE_PEFT_BACKEND,
3636
_get_model_file,
@@ -143,7 +143,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
143143
adapter_name = kwargs.pop("adapter_name", None)
144144
_pipeline = kwargs.pop("_pipeline", None)
145145
network_alphas = kwargs.pop("network_alphas", None)
146-
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
146+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
147147
allow_pickle = False
148148

149149
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
@@ -540,7 +540,7 @@ def _get_custom_diffusion_state_dict(self):
540540

541541
return state_dict
542542

543-
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
543+
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
544544
if low_cpu_mem_usage:
545545
if is_accelerate_available():
546546
from accelerate import init_empty_weights
@@ -753,11 +753,12 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
753753
if not low_cpu_mem_usage:
754754
image_projection.load_state_dict(updated_state_dict, strict=True)
755755
else:
756-
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
756+
device_map = {"": self.device}
757+
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
757758

758759
return image_projection
759760

760-
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
761+
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
761762
from ..models.attention_processor import (
762763
IPAdapterAttnProcessor,
763764
IPAdapterAttnProcessor2_0,
@@ -846,13 +847,14 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
846847
else:
847848
device = next(iter(value_dict.values())).device
848849
dtype = next(iter(value_dict.values())).dtype
849-
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
850+
device_map = {"": device}
851+
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
850852

851853
key_id += 2
852854

853855
return attn_procs
854856

855-
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
857+
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
856858
if not isinstance(state_dicts, list):
857859
state_dicts = [state_dicts]
858860

src/diffusers/utils/import_utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -815,7 +815,7 @@ def is_peft_version(operation: str, version: str):
815815
version (`str`):
816816
A version string
817817
"""
818-
if not _peft_version:
818+
if not _peft_available:
819819
return False
820820
return compare_versions(parse(_peft_version), operation, version)
821821

@@ -829,7 +829,7 @@ def is_bitsandbytes_version(operation: str, version: str):
829829
version (`str`):
830830
A version string
831831
"""
832-
if not _bitsandbytes_version:
832+
if not _bitsandbytes_available:
833833
return False
834834
return compare_versions(parse(_bitsandbytes_version), operation, version)
835835

0 commit comments

Comments
 (0)