Skip to content

Commit ce75466

Browse files
committed
More robust from_pretrained init_kwargs type checking
1 parent 9f5ad1d commit ce75466

File tree

1 file changed

+75
-24
lines changed

1 file changed

+75
-24
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 75 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16-
import enum
1716
import fnmatch
1817
import importlib
1918
import inspect
@@ -22,7 +21,7 @@
2221
import sys
2322
from dataclasses import dataclass
2423
from pathlib import Path
25-
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
24+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin
2625

2726
import numpy as np
2827
import PIL.Image
@@ -864,26 +863,6 @@ def load_module(name, value):
864863

865864
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
866865

867-
for key in init_dict.keys():
868-
if key not in passed_class_obj:
869-
continue
870-
if "scheduler" in key:
871-
continue
872-
873-
class_obj = passed_class_obj[key]
874-
_expected_class_types = []
875-
for expected_type in expected_types[key]:
876-
if isinstance(expected_type, enum.EnumMeta):
877-
_expected_class_types.extend(expected_type.__members__.keys())
878-
else:
879-
_expected_class_types.append(expected_type.__name__)
880-
881-
_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
882-
if not _is_valid_type:
883-
logger.warning(
884-
f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
885-
)
886-
887866
# Special case: safety_checker must be loaded separately when using `from_flax`
888867
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
889868
raise NotImplementedError(
@@ -1003,10 +982,82 @@ def load_module(name, value):
1003982
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
1004983
)
1005984

1006-
# 10. Instantiate the pipeline
985+
# 10. Type checking init arguments
986+
def is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
    """
    Check whether `obj` conforms to at least one of the given type annotations.

    Plain classes are matched with `isinstance`, so instances of subclasses are accepted. For parameterized
    collections (e.g. `List[int]`, `Tuple[int, ...]`, `Dict[str, int]`) every element of `obj` is checked
    recursively against the annotation's type arguments.

    Args:
        obj: The object to check.
        class_or_tuple: A single annotation or a tuple of annotations. `Union[...]` / `Optional[...]` entries
            are flattened into their members. `Any` matches everything.

    Returns:
        `True` if `obj` matches at least one annotation, `False` otherwise.
    """
    if not isinstance(class_or_tuple, tuple):
        class_or_tuple = (class_or_tuple,)

    # Flatten unions so that every candidate is a single (possibly parameterized) annotation
    candidates = []
    for t in class_or_tuple:
        if get_origin(t) is Union:
            candidates.extend(get_args(t))
        else:
            candidates.append(t)

    if Any in candidates:
        return True

    def _obj_is_instance_of(annotation: Any) -> bool:
        # `List[int]` -> `list`, while a plain class is its own runtime class
        runtime_class = get_origin(annotation) or annotation
        # Non-class annotations (e.g. `...`, strings, TypeVars) cannot be used with `isinstance`
        if not isinstance(runtime_class, type):
            return False
        try:
            return isinstance(obj, runtime_class)
        except TypeError:
            # Some classes (e.g. non-runtime-checkable protocols) refuse instance checks
            return False

    # Type arguments of every annotation whose runtime class matches `obj`
    elem_type_args = {get_args(t) for t in candidates if _obj_is_instance_of(t)}

    # Singular types (e.g. int, ControlNet, ...)
    # Untyped collections (e.g. List, but not List[int])
    if () in elem_type_args:
        return True

    # Typed lists or sets
    if isinstance(obj, (list, set)):
        return any(all(is_valid_type(x, args) for x in obj) for args in elem_type_args)

    # Typed tuples
    if isinstance(obj, tuple):
        return any(
            # Tuples with any length and single type (e.g. Tuple[int, ...])
            (len(args) == 2 and args[-1] is Ellipsis and all(is_valid_type(x, args[0]) for x in obj))
            or
            # Tuples with fixed length and any types (e.g. Tuple[int, str])
            (len(obj) == len(args) and all(is_valid_type(x, arg) for x, arg in zip(obj, args)))
            for args in elem_type_args
        )

    # Typed dicts: only annotations with exactly a key type and a value type can be checked
    if isinstance(obj, dict):
        return any(
            all(is_valid_type(k, args[0]) and is_valid_type(v, args[1]) for k, v in obj.items())
            for args in elem_type_args
            if len(args) == 2
        )

    return False
1034+
def get_detailed_type(obj: Any) -> Type:
    """
    Build a `typing` annotation describing `obj`, including the types of the elements of built-in collections.

    Lists, sets, tuples and dicts are described recursively (e.g. `[1, "a"]` -> `List[Union[int, str]]`). Empty
    lists, sets and dicts are reported as the bare `List`, `Set` and `Dict`. Any other object is described by its
    own class.

    Args:
        obj: The object to describe.

    Returns:
        The class of `obj`, or a parameterized `typing` generic for built-in collections.
    """
    obj_type = type(obj)

    def _union_of_detailed_types(items):
        # `dict.fromkeys` removes duplicates while keeping first-seen order, so the result is deterministic
        unique_types = tuple(dict.fromkeys(get_detailed_type(x) for x in items))
        # `Union` cannot be built from zero types, so signal "no elements" to the caller instead
        return Union[unique_types] if unique_types else None

    if obj_type in (list, set):
        obj_origin_type = List if obj_type is list else Set
        elems_type = _union_of_detailed_types(obj)
        return obj_origin_type if elems_type is None else obj_origin_type[elems_type]

    if obj_type is tuple:
        return Tuple[tuple(get_detailed_type(x) for x in obj)]

    if obj_type is dict:
        if not obj:
            return Dict
        keys_type = _union_of_detailed_types(obj.keys())
        values_type = _union_of_detailed_types(obj.values())
        return Dict[keys_type, values_type]

    return obj_type
1050+
# Warn (without raising) about every init argument whose runtime type does not match its expected types.
for key, class_obj in init_kwargs.items():
    # Keys containing "scheduler" and arguments that are `None` are not checked
    if "scheduler" in key or class_obj is None:
        continue

    if is_valid_type(class_obj, expected_types[key]):
        continue

    logger.warning(f"Expected types for {key}: {expected_types[key]}, got {get_detailed_type(class_obj)}.")
1057+
# 11. Instantiate the pipeline
10071058
model = pipeline_class(**init_kwargs)
10081059

1009-
# 11. Save where the model was instantiated from
1060+
# 12. Save where the model was instantiated from
10101061
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
10111062
if device_map is not None:
10121063
setattr(model, "hf_device_map", final_device_map)

0 commit comments

Comments
 (0)