Skip to content

Commit 6f74ef5

Browse files
authored
Fix torch_dtype in Kolors text encoder with transformers v4.49 (#10816)
* Fix `torch_dtype` in Kolors text encoder with `transformers` v4.49
* Default `torch_dtype` to `torch.float32` and warn when a non-`torch.dtype` value is passed
1 parent 9c7e205 commit 6f74ef5

File tree

8 files changed: +43 additions, −9 deletions

examples/community/checkpoint_merger.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,13 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
9292
token = kwargs.pop("token", None)
9393
variant = kwargs.pop("variant", None)
9494
revision = kwargs.pop("revision", None)
95-
torch_dtype = kwargs.pop("torch_dtype", None)
95+
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
9696
device_map = kwargs.pop("device_map", None)
9797

98+
if not isinstance(torch_dtype, torch.dtype):
99+
torch_dtype = torch.float32
100+
print(f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`.")
101+
98102
alpha = kwargs.pop("alpha", 0.5)
99103
interp = kwargs.pop("interp", None)
100104

src/diffusers/loaders/single_file.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,11 +360,17 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
360360
cache_dir = kwargs.pop("cache_dir", None)
361361
local_files_only = kwargs.pop("local_files_only", False)
362362
revision = kwargs.pop("revision", None)
363-
torch_dtype = kwargs.pop("torch_dtype", None)
363+
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
364364
disable_mmap = kwargs.pop("disable_mmap", False)
365365

366366
is_legacy_loading = False
367367

368+
if not isinstance(torch_dtype, torch.dtype):
369+
torch_dtype = torch.float32
370+
logger.warning(
371+
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
372+
)
373+
368374
# We shouldn't allow configuring individual models components through a Pipeline creation method
369375
# These model kwargs should be deprecated
370376
scaling_factor = kwargs.get("scaling_factor", None)

src/diffusers/loaders/single_file_model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,17 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
240240
subfolder = kwargs.pop("subfolder", None)
241241
revision = kwargs.pop("revision", None)
242242
config_revision = kwargs.pop("config_revision", None)
243-
torch_dtype = kwargs.pop("torch_dtype", None)
243+
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
244244
quantization_config = kwargs.pop("quantization_config", None)
245245
device = kwargs.pop("device", None)
246246
disable_mmap = kwargs.pop("disable_mmap", False)
247247

248+
if not isinstance(torch_dtype, torch.dtype):
249+
torch_dtype = torch.float32
250+
logger.warning(
251+
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
252+
)
253+
248254
if isinstance(pretrained_model_link_or_path_or_dict, dict):
249255
checkpoint = pretrained_model_link_or_path_or_dict
250256
else:

src/diffusers/models/modeling_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -866,7 +866,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
866866
local_files_only = kwargs.pop("local_files_only", None)
867867
token = kwargs.pop("token", None)
868868
revision = kwargs.pop("revision", None)
869-
torch_dtype = kwargs.pop("torch_dtype", None)
869+
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
870870
subfolder = kwargs.pop("subfolder", None)
871871
device_map = kwargs.pop("device_map", None)
872872
max_memory = kwargs.pop("max_memory", None)
@@ -879,6 +879,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
879879
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
880880
disable_mmap = kwargs.pop("disable_mmap", False)
881881

882+
if not isinstance(torch_dtype, torch.dtype):
883+
torch_dtype = torch.float32
884+
logger.warning(
885+
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
886+
)
887+
882888
allow_pickle = False
883889
if use_safetensors is None:
884890
use_safetensors = True

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
685685
token = kwargs.pop("token", None)
686686
revision = kwargs.pop("revision", None)
687687
from_flax = kwargs.pop("from_flax", False)
688-
torch_dtype = kwargs.pop("torch_dtype", None)
688+
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
689689
custom_pipeline = kwargs.pop("custom_pipeline", None)
690690
custom_revision = kwargs.pop("custom_revision", None)
691691
provider = kwargs.pop("provider", None)
@@ -702,6 +702,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
702702
use_onnx = kwargs.pop("use_onnx", None)
703703
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
704704

705+
if not isinstance(torch_dtype, torch.dtype):
706+
torch_dtype = torch.float32
707+
logger.warning(
708+
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
709+
)
710+
705711
if low_cpu_mem_usage and not is_accelerate_available():
706712
low_cpu_mem_usage = False
707713
logger.warning(
@@ -1826,7 +1832,7 @@ def from_pipe(cls, pipeline, **kwargs):
18261832
"""
18271833

18281834
original_config = dict(pipeline.config)
1829-
torch_dtype = kwargs.pop("torch_dtype", None)
1835+
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
18301836

18311837
# derive the pipeline class to instantiate
18321838
custom_pipeline = kwargs.pop("custom_pipeline", None)

tests/pipelines/kolors/test_kolors.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
8989
sample_size=128,
9090
)
9191
torch.manual_seed(0)
92-
text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
92+
text_encoder = ChatGLMModel.from_pretrained(
93+
"hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
94+
)
9395
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
9496

9597
components = {

tests/pipelines/kolors/test_kolors_img2img.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
9393
sample_size=128,
9494
)
9595
torch.manual_seed(0)
96-
text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
96+
text_encoder = ChatGLMModel.from_pretrained(
97+
"hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
98+
)
9799
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
98100

99101
components = {

tests/pipelines/pag/test_pag_kolors.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
9898
sample_size=128,
9999
)
100100
torch.manual_seed(0)
101-
text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
101+
text_encoder = ChatGLMModel.from_pretrained(
102+
"hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
103+
)
102104
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
103105

104106
components = {

Comments (0)