Skip to content

Comprehensive type checking for from_pretrained kwargs #10758

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Feb 22, 2025
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 75 additions & 24 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import fnmatch
import importlib
import inspect
Expand All @@ -22,7 +21,7 @@
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin

import numpy as np
import PIL.Image
Expand Down Expand Up @@ -864,26 +863,6 @@ def load_module(name, value):

init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

for key in init_dict.keys():
if key not in passed_class_obj:
continue
if "scheduler" in key:
continue

class_obj = passed_class_obj[key]
_expected_class_types = []
for expected_type in expected_types[key]:
if isinstance(expected_type, enum.EnumMeta):
_expected_class_types.extend(expected_type.__members__.keys())
else:
_expected_class_types.append(expected_type.__name__)

_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
if not _is_valid_type:
logger.warning(
f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
)

# Special case: safety_checker must be loaded separately when using `from_flax`
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
raise NotImplementedError(
Expand Down Expand Up @@ -1003,10 +982,82 @@ def load_module(name, value):
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)

# 10. Instantiate the pipeline
# 10. Type checking init arguments
def is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
if not isinstance(class_or_tuple, tuple):
class_or_tuple = (class_or_tuple,)

# Unpack unions
unpacked_class_or_tuple = []
for t in class_or_tuple:
if get_origin(t) is Union:
unpacked_class_or_tuple.extend(get_args(t))
else:
unpacked_class_or_tuple.append(t)
class_or_tuple = tuple(unpacked_class_or_tuple)

if Any in class_or_tuple:
return True

obj_type = type(obj)
# Classes with obj's type
class_or_tuple = {t for t in class_or_tuple if (get_origin(t) or t) is obj_type}

# Singular types (e.g. int, ControlNet, ...)
# Untyped collections (e.g. List, but not List[int])
elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
if () in elem_class_or_tuple:
return True
# Typed lists or sets
elif obj_type in (list, set):
return any(all(is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
# Typed tuples
elif obj_type is tuple:
return any(
# Tuples with any length and single type (e.g. Tuple[int, ...])
(len(t) == 2 and t[-1] is Ellipsis and all(is_valid_type(x, t[0]) for x in obj))
or
# Tuples with fixed length and any types (e.g. Tuple[int, str])
(len(obj) == len(t) and all(is_valid_type(x, tt) for x, tt in zip(obj, t)))
for t in elem_class_or_tuple
)
# Typed dicts
elif obj_type is dict:
return any(
all(is_valid_type(k, kt) and is_valid_type(v, vt) for k, v in obj.items())
for kt, vt in elem_class_or_tuple
)

else:
return False

def get_detailed_type(obj: Any) -> Type:
obj_type = type(obj)

if obj_type in (list, set):
obj_origin_type = List if obj_type is list else Set
elems_type = Union[tuple({get_detailed_type(x) for x in obj})]
return obj_origin_type[elems_type]
elif obj_type is tuple:
return Tuple[tuple(get_detailed_type(x) for x in obj)]
elif obj_type is dict:
keys_type = Union[tuple({get_detailed_type(k) for k in obj.keys()})]
values_type = Union[tuple({get_detailed_type(k) for k in obj.values()})]
return Dict[keys_type, values_type]
else:
return obj_type

for key, class_obj in init_kwargs.items():
if "scheduler" in key:
continue

if class_obj is not None and not is_valid_type(class_obj, expected_types[key]):
logger.warning(f"Expected types for {key}: {expected_types[key]}, got {get_detailed_type(class_obj)}.")

# 11. Instantiate the pipeline
model = pipeline_class(**init_kwargs)

# 11. Save where the model was instantiated from
# 12. Save where the model was instantiated from
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
if device_map is not None:
setattr(model, "hf_device_map", final_device_map)
Expand Down
Loading