Comprehensive type checking for from_pretrained kwargs #10758

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 14 commits · Feb 22, 2025
Changes from 9 commits
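This diff makes `DiffusionPipeline.from_pretrained` validate every passed component against the type annotations on the pipeline's `__init__`, replacing the earlier name-string comparison, and it tightens or adds annotations across several pipelines so the check has accurate types to validate against. A mismatch produces a logged warning, not an error. A minimal sketch of the intended user-facing behavior (the checkpoint ID and the exact warning text here are illustrative, not taken from this PR):

```python
from diffusers import AutoencoderKL, StableDiffusionPipeline

# Deliberately pass a VAE where the text encoder is expected.
vae = AutoencoderKL.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae"
)
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", text_encoder=vae
)
# Loading still succeeds, but a warning along these lines is logged:
#   Expected types for text_encoder: (<class '...CLIPTextModel'>,),
#   got <class '...AutoencoderKL'>.
```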
@@ -224,7 +224,7 @@ def __init__(
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
+        unet: Union[UNet2DConditionModel, UNetMotionModel],
         motion_adapter: MotionAdapter,
         scheduler: Union[
             DDIMScheduler,
@@ -246,7 +246,7 @@ def __init__(
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
+        unet: Union[UNet2DConditionModel, UNetMotionModel],
         motion_adapter: MotionAdapter,
         controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: Union[
@@ -232,8 +232,8 @@ def __init__(
             Tuple[HunyuanDiT2DControlNetModel],
             HunyuanDiT2DMultiControlNetModel,
         ],
-        text_encoder_2=T5EncoderModel,
-        tokenizer_2=MT5Tokenizer,
+        text_encoder_2: Optional[T5EncoderModel] = None,
+        tokenizer_2: Optional[MT5Tokenizer] = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
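Worth noting why these two lines are a genuine fix and not just an annotation change: `text_encoder_2=T5EncoderModel` made the *class object* the default value, so constructing the pipeline without a second text encoder silently stored the class itself where an instance or `None` belonged. A tiny standalone illustration of the old behavior (not diffusers code):

```python
from transformers import T5EncoderModel

def old_style(text_encoder_2=T5EncoderModel):  # class used as a default value
    return text_encoder_2

print(old_style())  # <class 'transformers...T5EncoderModel'>: the class leaks through as a value
```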
@@ -17,10 +17,10 @@

 import torch
 from transformers import (
-    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
-    PreTrainedModel,
+    SiglipImageProcessor,
+    SiglipVisionModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
@@ -178,9 +178,9 @@ class StableDiffusion3ControlNetPipeline(
             Provides additional conditioning to the `unet` during the denoising process. If you set multiple
             ControlNets as a list, the outputs from each ControlNet are added together to create one combined
             additional conditioning.
-        image_encoder (`PreTrainedModel`, *optional*):
+        image_encoder (`SiglipVisionModel`, *optional*):
             Pre-trained Vision Model for IP Adapter.
-        feature_extractor (`BaseImageProcessor`, *optional*):
+        feature_extractor (`SiglipImageProcessor`, *optional*):
             Image processor for IP Adapter.
     """

@@ -202,8 +202,8 @@ def __init__(
         controlnet: Union[
             SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
         ],
-        image_encoder: PreTrainedModel = None,
-        feature_extractor: BaseImageProcessor = None,
+        image_encoder: Optional[SiglipVisionModel] = None,
+        feature_extractor: Optional[SiglipImageProcessor] = None,
     ):
         super().__init__()
         if isinstance(controlnet, (list, tuple)):
@@ -17,10 +17,10 @@

 import torch
 from transformers import (
-    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
-    PreTrainedModel,
+    SiglipImageProcessor,
+    SiglipModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
@@ -223,8 +223,8 @@ def __init__(
         controlnet: Union[
             SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
         ],
-        image_encoder: PreTrainedModel = None,
-        feature_extractor: BaseImageProcessor = None,
+        image_encoder: SiglipModel = None,
+        feature_extractor: Optional[SiglipImageProcessor] = None,
     ):
         super().__init__()

@@ -17,6 +17,8 @@

 import torch

+from ...models import UNet1DModel
+from ...schedulers import SchedulerMixin
 from ...utils import is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
@@ -49,7 +51,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):

     model_cpu_offload_seq = "unet"

-    def __init__(self, unet, scheduler):
+    def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin):
         super().__init__()
         self.register_modules(unet=unet, scheduler=scheduler)

src/diffusers/pipelines/ddim/pipeline_ddim.py: 2 additions & 1 deletion
@@ -16,6 +16,7 @@

 import torch

+from ...models import UNet2DModel
 from ...schedulers import DDIMScheduler
 from ...utils import is_torch_xla_available
 from ...utils.torch_utils import randn_tensor
@@ -47,7 +48,7 @@ class DDIMPipeline(DiffusionPipeline):

     model_cpu_offload_seq = "unet"

-    def __init__(self, unet, scheduler):
+    def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler):
         super().__init__()

         # make sure scheduler can always be converted to DDIM
src/diffusers/pipelines/ddpm/pipeline_ddpm.py: 3 additions & 1 deletion
@@ -17,6 +17,8 @@

 import torch

+from ...models import UNet2DModel
+from ...schedulers import DDPMScheduler
 from ...utils import is_torch_xla_available
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -47,7 +49,7 @@ class DDPMPipeline(DiffusionPipeline):

     model_cpu_offload_seq = "unet"

-    def __init__(self, unet, scheduler):
+    def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler):
         super().__init__()
         self.register_modules(unet=unet, scheduler=scheduler)

@@ -91,7 +91,7 @@ class RePaintPipeline(DiffusionPipeline):
     scheduler: RePaintScheduler
     model_cpu_offload_seq = "unet"

-    def __init__(self, unet, scheduler):
+    def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler):
         super().__init__()
         self.register_modules(unet=unet, scheduler=scheduler)

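The annotations added in the hunks above are what the new check in `pipeline_utils.py` (below) consumes: expected types are read off each `__init__` parameter, and a parameter without an annotation, like the old `def __init__(self, unet, scheduler)`, surfaces as `inspect.Signature.empty` and is skipped. A rough standalone sketch of that extraction step (the helper name is hypothetical; the real logic lives inside the loader's signature machinery):

```python
import inspect
from typing import Union, get_args, get_origin

def expected_types_of(cls) -> dict:
    """Map each __init__ parameter name to a tuple of accepted types."""
    expected = {}
    for name, param in inspect.signature(cls.__init__).parameters.items():
        if name == "self":
            continue
        ann = param.annotation
        if ann is inspect.Signature.empty:
            # Unannotated: there is nothing to validate against, so the
            # loader skips the parameter entirely.
            expected[name] = (inspect.Signature.empty,)
        elif get_origin(ann) is Union:
            expected[name] = get_args(ann)  # Union[A, B] -> (A, B)
        else:
            expected[name] = (ann,)
    return expected

class Toy:
    def __init__(self, unet, scheduler: Union[int, str]): ...

print(expected_types_of(Toy))
# {'unet': (<class 'inspect._empty'>,), 'scheduler': (<class 'int'>, <class 'str'>)}
```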
src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py: 2 additions & 2 deletions
@@ -207,8 +207,8 @@ def __init__(
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
-        text_encoder_2=T5EncoderModel,
-        tokenizer_2=MT5Tokenizer,
+        text_encoder_2: Optional[T5EncoderModel] = None,
+        tokenizer_2: Optional[MT5Tokenizer] = None,
     ):
         super().__init__()

src/diffusers/pipelines/lumina/pipeline_lumina.py: 7 additions & 10 deletions
@@ -20,7 +20,7 @@
 from typing import List, Optional, Tuple, Union

 import torch
-from transformers import AutoModel, AutoTokenizer
+from transformers import GemmaPreTrainedModel, GemmaTokenizer, GemmaTokenizerFast

 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL
@@ -143,13 +143,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
-        text_encoder ([`AutoModel`]):
-            Frozen text-encoder. Lumina-T2I uses
-            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
-            [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
-        tokenizer (`AutoModel`):
-            Tokenizer of class
-            [AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
+        text_encoder ([`GemmaPreTrainedModel`]):
+            Frozen Gemma text-encoder.
+        tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
+            Gemma tokenizer.
         transformer ([`Transformer2DModel`]):
             A text conditioned `Transformer2DModel` to denoise the encoded image latents.
         scheduler ([`SchedulerMixin`]):
@@ -180,8 +177,8 @@ def __init__(
         transformer: LuminaNextDiT2DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
         vae: AutoencoderKL,
-        text_encoder: AutoModel,
-        tokenizer: AutoTokenizer,
+        text_encoder: GemmaPreTrainedModel,
+        tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
     ):
         super().__init__()

src/diffusers/pipelines/lumina2/pipeline_lumina2.py: 7 additions & 10 deletions
@@ -17,7 +17,7 @@

 import numpy as np
 import torch
-from transformers import AutoModel, AutoTokenizer
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast

 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL
@@ -150,13 +150,10 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline):
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
-        text_encoder ([`AutoModel`]):
-            Frozen text-encoder. Lumina-T2I uses
-            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
-            [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
-        tokenizer (`AutoModel`):
-            Tokenizer of class
-            [AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
+        text_encoder ([`Gemma2PreTrainedModel`]):
+            Frozen Gemma2 text-encoder.
+        tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
+            Gemma tokenizer.
         transformer ([`Transformer2DModel`]):
             A text conditioned `Transformer2DModel` to denoise the encoded image latents.
         scheduler ([`SchedulerMixin`]):
@@ -172,8 +169,8 @@ def __init__(
         transformer: Lumina2Transformer2DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
         vae: AutoencoderKL,
-        text_encoder: AutoModel,
-        tokenizer: AutoTokenizer,
+        text_encoder: Gemma2PreTrainedModel,
+        tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
     ):
         super().__init__()

src/diffusers/pipelines/pag/pipeline_pag_sana.py: 3 additions & 3 deletions
@@ -20,7 +20,7 @@
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PixArtImageProcessor
@@ -160,8 +160,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):

     def __init__(
         self,
-        tokenizer: AutoTokenizer,
-        text_encoder: AutoModelForCausalLM,
+        tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+        text_encoder: Gemma2PreTrainedModel,
         vae: AutoencoderDC,
         transformer: SanaTransformer2DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
src/diffusers/pipelines/pipeline_utils.py: 83 additions & 24 deletions
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import enum
 import fnmatch
 import importlib
 import inspect
@@ -22,7 +21,7 @@
 import sys
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin

 import numpy as np
 import PIL.Image
@@ -864,26 +863,6 @@ def load_module(name, value):

         init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

-        for key in init_dict.keys():
-            if key not in passed_class_obj:
-                continue
-            if "scheduler" in key:
-                continue
-
-            class_obj = passed_class_obj[key]
-            _expected_class_types = []
-            for expected_type in expected_types[key]:
-                if isinstance(expected_type, enum.EnumMeta):
-                    _expected_class_types.extend(expected_type.__members__.keys())
-                else:
-                    _expected_class_types.append(expected_type.__name__)
-
-            _is_valid_type = class_obj.__class__.__name__ in _expected_class_types
-            if not _is_valid_type:
-                logger.warning(
-                    f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
-                )
-
         # Special case: safety_checker must be loaded separately when using `from_flax`
         if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
             raise NotImplementedError(
@@ -1003,10 +982,90 @@ def load_module(name, value):
                 f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
             )

-        # 10. Instantiate the pipeline
+        # 10. Type checking init arguments
+        def is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
+            if not isinstance(class_or_tuple, tuple):
+                class_or_tuple = (class_or_tuple,)
+
+            # Unpack unions
+            unpacked_class_or_tuple = []
+            for t in class_or_tuple:
+                if get_origin(t) is Union:
+                    unpacked_class_or_tuple.extend(get_args(t))
+                else:
+                    unpacked_class_or_tuple.append(t)
+            class_or_tuple = tuple(unpacked_class_or_tuple)
+
+            if Any in class_or_tuple:
+                return True
+
+            obj_type = type(obj)
+            # Classes with obj's type
+            class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}
+
+            # Singular types (e.g. int, ControlNet, ...)
+            # Untyped collections (e.g. List, but not List[int])
+            elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
+            if () in elem_class_or_tuple:
+                return True
+            # Typed lists or sets
+            elif obj_type in (list, set):
+                return any(all(is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
+            # Typed tuples
+            elif obj_type is tuple:
+                return any(
+                    # Tuples with any length and single type (e.g. Tuple[int, ...])
+                    (len(t) == 2 and t[-1] is Ellipsis and all(is_valid_type(x, t[0]) for x in obj))
+                    or
+                    # Tuples with fixed length and any types (e.g. Tuple[int, str])
+                    (len(obj) == len(t) and all(is_valid_type(x, tt) for x, tt in zip(obj, t)))
+                    for t in elem_class_or_tuple
+                )
+            # Typed dicts
+            elif obj_type is dict:
+                return any(
+                    all(is_valid_type(k, kt) and is_valid_type(v, vt) for k, v in obj.items())
+                    for kt, vt in elem_class_or_tuple
+                )
+
+            else:
+                return False
+
+        def get_detailed_type(obj: Any) -> Type:
+            obj_type = type(obj)
+
+            if obj_type in (list, set):
+                obj_origin_type = List if obj_type is list else Set
+                elems_type = Union[tuple({get_detailed_type(x) for x in obj})]
+                return obj_origin_type[elems_type]
+            elif obj_type is tuple:
+                return Tuple[tuple(get_detailed_type(x) for x in obj)]
+            elif obj_type is dict:
+                keys_type = Union[tuple({get_detailed_type(k) for k in obj.keys()})]
+                values_type = Union[tuple({get_detailed_type(k) for k in obj.values()})]
+                return Dict[keys_type, values_type]
+            else:
+                return obj_type
+
+        for kw, arg in init_kwargs.items():
+            # Too complex to validate with type annotation alone
+            if "scheduler" in kw:
+                continue
+            # Many tokenizer annotations don't include its "Fast" variant, so skip this
+            # e.g T5Tokenizer but not T5TokenizerFast
+            elif "tokenizer" in kw:
+                continue
+            elif (
+                arg is not None
+                and not expected_types[kw] == (inspect.Signature.empty,)  # no type annotations
+                and not is_valid_type(arg, expected_types[kw])
+            ):
+                logger.warning(f"Expected types for {kw}: {expected_types[kw]}, got {get_detailed_type(arg)}.")
+
+        # 11. Instantiate the pipeline
         model = pipeline_class(**init_kwargs)

-        # 11. Save where the model was instantiated from
+        # 12. Save where the model was instantiated from
         model.register_to_config(_name_or_path=pretrained_model_name_or_path)
         if device_map is not None:
             setattr(model, "hf_device_map", final_device_map)
Loading