|
13 | 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
16 |
| -import enum |
17 | 16 | import fnmatch
|
18 | 17 | import importlib
|
19 | 18 | import inspect
|
|
22 | 21 | import sys
|
23 | 22 | from dataclasses import dataclass
|
24 | 23 | from pathlib import Path
|
25 |
| -from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin |
| 24 | +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin |
26 | 25 |
|
27 | 26 | import numpy as np
|
28 | 27 | import PIL.Image
|
@@ -864,26 +863,6 @@ def load_module(name, value):
|
864 | 863 |
|
865 | 864 | init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
866 | 865 |
|
867 |
| - for key in init_dict.keys(): |
868 |
| - if key not in passed_class_obj: |
869 |
| - continue |
870 |
| - if "scheduler" in key: |
871 |
| - continue |
872 |
| - |
873 |
| - class_obj = passed_class_obj[key] |
874 |
| - _expected_class_types = [] |
875 |
| - for expected_type in expected_types[key]: |
876 |
| - if isinstance(expected_type, enum.EnumMeta): |
877 |
| - _expected_class_types.extend(expected_type.__members__.keys()) |
878 |
| - else: |
879 |
| - _expected_class_types.append(expected_type.__name__) |
880 |
| - |
881 |
| - _is_valid_type = class_obj.__class__.__name__ in _expected_class_types |
882 |
| - if not _is_valid_type: |
883 |
| - logger.warning( |
884 |
| - f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}." |
885 |
| - ) |
886 |
| - |
887 | 866 | # Special case: safety_checker must be loaded separately when using `from_flax`
|
888 | 867 | if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
|
889 | 868 | raise NotImplementedError(
|
@@ -1003,10 +982,82 @@ def load_module(name, value):
|
1003 | 982 | f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
1004 | 983 | )
|
1005 | 984 |
|
1006 |
| - # 10. Instantiate the pipeline |
| 985 | + # 10. Type checking init arguments |
| 986 | + def is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool: |
| 987 | + if not isinstance(class_or_tuple, tuple): |
| 988 | + class_or_tuple = (class_or_tuple,) |
| 989 | + |
| 990 | + # Unpack unions |
| 991 | + unpacked_class_or_tuple = [] |
| 992 | + for t in class_or_tuple: |
| 993 | + if get_origin(t) is Union: |
| 994 | + unpacked_class_or_tuple.extend(get_args(t)) |
| 995 | + else: |
| 996 | + unpacked_class_or_tuple.append(t) |
| 997 | + class_or_tuple = tuple(unpacked_class_or_tuple) |
| 998 | + |
| 999 | + if Any in class_or_tuple: |
| 1000 | + return True |
| 1001 | + |
| 1002 | + obj_type = type(obj) |
| 1003 | + # Classes with obj's type |
| 1004 | + class_or_tuple = {t for t in class_or_tuple if (get_origin(t) or t) is obj_type} |
| 1005 | + |
| 1006 | + # Singular types (e.g. int, ControlNet, ...) |
| 1007 | + # Untyped collections (e.g. List, but not List[int]) |
| 1008 | + elem_class_or_tuple = {get_args(t) for t in class_or_tuple} |
| 1009 | + if () in elem_class_or_tuple: |
| 1010 | + return True |
| 1011 | + # Typed lists or sets |
| 1012 | + elif obj_type in (list, set): |
| 1013 | + return any(all(is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple) |
| 1014 | + # Typed tuples |
| 1015 | + elif obj_type is tuple: |
| 1016 | + return any( |
| 1017 | + # Tuples with any length and single type (e.g. Tuple[int, ...]) |
| 1018 | + (len(t) == 2 and t[-1] is Ellipsis and all(is_valid_type(x, t[0]) for x in obj)) |
| 1019 | + or |
| 1020 | + # Tuples with fixed length and any types (e.g. Tuple[int, str]) |
| 1021 | + (len(obj) == len(t) and all(is_valid_type(x, tt) for x, tt in zip(obj, t))) |
| 1022 | + for t in elem_class_or_tuple |
| 1023 | + ) |
| 1024 | + # Typed dicts |
| 1025 | + elif obj_type is dict: |
| 1026 | + return any( |
| 1027 | + all(is_valid_type(k, kt) and is_valid_type(v, vt) for k, v in obj.items()) |
| 1028 | + for kt, vt in elem_class_or_tuple |
| 1029 | + ) |
| 1030 | + |
| 1031 | + else: |
| 1032 | + return False |
| 1033 | + |
| 1034 | + def get_detailed_type(obj: Any) -> Type: |
| 1035 | + obj_type = type(obj) |
| 1036 | + |
| 1037 | + if obj_type in (list, set): |
| 1038 | + obj_origin_type = List if obj_type is list else Set |
| 1039 | + elems_type = Union[*{get_detailed_type(x) for x in obj}] |
| 1040 | + return obj_origin_type[elems_type] |
| 1041 | + elif obj_type is tuple: |
| 1042 | + return Tuple[tuple(get_detailed_type(x) for x in obj)] |
| 1043 | + elif obj_type is dict: |
| 1044 | + keys_type = Union[*{get_detailed_type(k) for k in obj.keys()}] |
| 1045 | + values_type = Union[*{get_detailed_type(k) for k in obj.values()}] |
| 1046 | + return Dict[keys_type, values_type] |
| 1047 | + else: |
| 1048 | + return obj_type |
| 1049 | + |
| 1050 | + for key, class_obj in init_kwargs.items(): |
| 1051 | + if "scheduler" in key: |
| 1052 | + continue |
| 1053 | + |
| 1054 | + if class_obj is not None and not is_valid_type(class_obj, expected_types[key]): |
| 1055 | + logger.warning(f"Expected types for {key}: {expected_types[key]}, got {get_detailed_type(class_obj)}.") |
| 1056 | + |
| 1057 | + # 11. Instantiate the pipeline |
1007 | 1058 | model = pipeline_class(**init_kwargs)
|
1008 | 1059 |
|
1009 |
| - # 11. Save where the model was instantiated from |
| 1060 | + # 12. Save where the model was instantiated from |
1010 | 1061 | model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
1011 | 1062 | if device_map is not None:
|
1012 | 1063 | setattr(model, "hf_device_map", final_device_map)
|
|
0 commit comments