Skip to content

Commit 9c7e205

Browse files
guiyrthlky
andauthored
Comprehensive type checking for from_pretrained kwargs (#10758)
* More robust from_pretrained init_kwargs type checking * Corrected for Python 3.10 * Type checks subclasses and fixed type warnings * More type corrections and skip tokenizer type checking * make style && make quality * Updated docs and types for Lumina pipelines * Fixed check for empty signature * changed location of helper functions * make style --------- Co-authored-by: hlky <[email protected]>
1 parent 64dec70 commit 9c7e205

26 files changed

+208
-114
lines changed

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def __init__(
224224
vae: AutoencoderKL,
225225
text_encoder: CLIPTextModel,
226226
tokenizer: CLIPTokenizer,
227-
unet: UNet2DConditionModel,
227+
unet: Union[UNet2DConditionModel, UNetMotionModel],
228228
motion_adapter: MotionAdapter,
229229
scheduler: Union[
230230
DDIMScheduler,

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def __init__(
246246
vae: AutoencoderKL,
247247
text_encoder: CLIPTextModel,
248248
tokenizer: CLIPTokenizer,
249-
unet: UNet2DConditionModel,
249+
unet: Union[UNet2DConditionModel, UNetMotionModel],
250250
motion_adapter: MotionAdapter,
251251
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
252252
scheduler: Union[

src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ def __init__(
232232
Tuple[HunyuanDiT2DControlNetModel],
233233
HunyuanDiT2DMultiControlNetModel,
234234
],
235-
text_encoder_2=T5EncoderModel,
236-
tokenizer_2=MT5Tokenizer,
235+
text_encoder_2: Optional[T5EncoderModel] = None,
236+
tokenizer_2: Optional[MT5Tokenizer] = None,
237237
requires_safety_checker: bool = True,
238238
):
239239
super().__init__()

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
import torch
1919
from transformers import (
20-
BaseImageProcessor,
2120
CLIPTextModelWithProjection,
2221
CLIPTokenizer,
23-
PreTrainedModel,
22+
SiglipImageProcessor,
23+
SiglipVisionModel,
2424
T5EncoderModel,
2525
T5TokenizerFast,
2626
)
@@ -178,9 +178,9 @@ class StableDiffusion3ControlNetPipeline(
178178
Provides additional conditioning to the `unet` during the denoising process. If you set multiple
179179
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
180180
additional conditioning.
181-
image_encoder (`PreTrainedModel`, *optional*):
181+
image_encoder (`SiglipVisionModel`, *optional*):
182182
Pre-trained Vision Model for IP Adapter.
183-
feature_extractor (`BaseImageProcessor`, *optional*):
183+
feature_extractor (`SiglipImageProcessor`, *optional*):
184184
Image processor for IP Adapter.
185185
"""
186186

@@ -202,8 +202,8 @@ def __init__(
202202
controlnet: Union[
203203
SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
204204
],
205-
image_encoder: PreTrainedModel = None,
206-
feature_extractor: BaseImageProcessor = None,
205+
image_encoder: Optional[SiglipVisionModel] = None,
206+
feature_extractor: Optional[SiglipImageProcessor] = None,
207207
):
208208
super().__init__()
209209
if isinstance(controlnet, (list, tuple)):

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
import torch
1919
from transformers import (
20-
BaseImageProcessor,
2120
CLIPTextModelWithProjection,
2221
CLIPTokenizer,
23-
PreTrainedModel,
22+
SiglipImageProcessor,
23+
SiglipModel,
2424
T5EncoderModel,
2525
T5TokenizerFast,
2626
)
@@ -223,8 +223,8 @@ def __init__(
223223
controlnet: Union[
224224
SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
225225
],
226-
image_encoder: PreTrainedModel = None,
227-
feature_extractor: BaseImageProcessor = None,
226+
image_encoder: SiglipModel = None,
227+
feature_extractor: Optional[SiglipImageProcessor] = None,
228228
):
229229
super().__init__()
230230

src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import torch
1919

20+
from ...models import UNet1DModel
21+
from ...schedulers import SchedulerMixin
2022
from ...utils import is_torch_xla_available, logging
2123
from ...utils.torch_utils import randn_tensor
2224
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
@@ -49,7 +51,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
4951

5052
model_cpu_offload_seq = "unet"
5153

52-
def __init__(self, unet, scheduler):
54+
def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin):
5355
super().__init__()
5456
self.register_modules(unet=unet, scheduler=scheduler)
5557

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import torch
1818

19+
from ...models import UNet2DModel
1920
from ...schedulers import DDIMScheduler
2021
from ...utils import is_torch_xla_available
2122
from ...utils.torch_utils import randn_tensor
@@ -47,7 +48,7 @@ class DDIMPipeline(DiffusionPipeline):
4748

4849
model_cpu_offload_seq = "unet"
4950

50-
def __init__(self, unet, scheduler):
51+
def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler):
5152
super().__init__()
5253

5354
# make sure scheduler can always be converted to DDIM

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import torch
1919

20+
from ...models import UNet2DModel
21+
from ...schedulers import DDPMScheduler
2022
from ...utils import is_torch_xla_available
2123
from ...utils.torch_utils import randn_tensor
2224
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -47,7 +49,7 @@ class DDPMPipeline(DiffusionPipeline):
4749

4850
model_cpu_offload_seq = "unet"
4951

50-
def __init__(self, unet, scheduler):
52+
def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler):
5153
super().__init__()
5254
self.register_modules(unet=unet, scheduler=scheduler)
5355

src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class RePaintPipeline(DiffusionPipeline):
9191
scheduler: RePaintScheduler
9292
model_cpu_offload_seq = "unet"
9393

94-
def __init__(self, unet, scheduler):
94+
def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler):
9595
super().__init__()
9696
self.register_modules(unet=unet, scheduler=scheduler)
9797

src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,8 @@ def __init__(
207207
safety_checker: StableDiffusionSafetyChecker,
208208
feature_extractor: CLIPImageProcessor,
209209
requires_safety_checker: bool = True,
210-
text_encoder_2=T5EncoderModel,
211-
tokenizer_2=MT5Tokenizer,
210+
text_encoder_2: Optional[T5EncoderModel] = None,
211+
tokenizer_2: Optional[MT5Tokenizer] = None,
212212
):
213213
super().__init__()
214214

src/diffusers/pipelines/lumina/pipeline_lumina.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Callable, Dict, List, Optional, Tuple, Union
2121

2222
import torch
23-
from transformers import AutoModel, AutoTokenizer
23+
from transformers import GemmaPreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
2424

2525
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2626
from ...image_processor import VaeImageProcessor
@@ -144,13 +144,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
144144
Args:
145145
vae ([`AutoencoderKL`]):
146146
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
147-
text_encoder ([`AutoModel`]):
148-
Frozen text-encoder. Lumina-T2I uses
149-
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
150-
[t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
151-
tokenizer (`AutoModel`):
152-
Tokenizer of class
153-
[AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
147+
text_encoder ([`GemmaPreTrainedModel`]):
148+
Frozen Gemma text-encoder.
149+
tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
150+
Gemma tokenizer.
154151
transformer ([`Transformer2DModel`]):
155152
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
156153
scheduler ([`SchedulerMixin`]):
@@ -185,8 +182,8 @@ def __init__(
185182
transformer: LuminaNextDiT2DModel,
186183
scheduler: FlowMatchEulerDiscreteScheduler,
187184
vae: AutoencoderKL,
188-
text_encoder: AutoModel,
189-
tokenizer: AutoTokenizer,
185+
text_encoder: GemmaPreTrainedModel,
186+
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
190187
):
191188
super().__init__()
192189

src/diffusers/pipelines/lumina2/pipeline_lumina2.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import numpy as np
1919
import torch
20-
from transformers import AutoModel, AutoTokenizer
20+
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
2121

2222
from ...image_processor import VaeImageProcessor
2323
from ...loaders import Lumina2LoraLoaderMixin
@@ -143,13 +143,10 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
143143
Args:
144144
vae ([`AutoencoderKL`]):
145145
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
146-
text_encoder ([`AutoModel`]):
147-
Frozen text-encoder. Lumina-T2I uses
148-
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
149-
[t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
150-
tokenizer (`AutoModel`):
151-
Tokenizer of class
152-
[AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
146+
text_encoder ([`Gemma2PreTrainedModel`]):
147+
Frozen Gemma2 text-encoder.
148+
tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
149+
Gemma tokenizer.
153150
transformer ([`Transformer2DModel`]):
154151
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
155152
scheduler ([`SchedulerMixin`]):
@@ -165,8 +162,8 @@ def __init__(
165162
transformer: Lumina2Transformer2DModel,
166163
scheduler: FlowMatchEulerDiscreteScheduler,
167164
vae: AutoencoderKL,
168-
text_encoder: AutoModel,
169-
tokenizer: AutoTokenizer,
165+
text_encoder: Gemma2PreTrainedModel,
166+
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
170167
):
171168
super().__init__()
172169

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Callable, Dict, List, Optional, Tuple, Union
2121

2222
import torch
23-
from transformers import AutoModelForCausalLM, AutoTokenizer
23+
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
2424

2525
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2626
from ...image_processor import PixArtImageProcessor
@@ -160,8 +160,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
160160

161161
def __init__(
162162
self,
163-
tokenizer: AutoTokenizer,
164-
text_encoder: AutoModelForCausalLM,
163+
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
164+
text_encoder: Gemma2PreTrainedModel,
165165
vae: AutoencoderDC,
166166
transformer: SanaTransformer2DModel,
167167
scheduler: FlowMatchEulerDiscreteScheduler,

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import re
1818
import warnings
1919
from pathlib import Path
20-
from typing import Any, Callable, Dict, List, Optional, Union
20+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin
2121

2222
import requests
2323
import torch
@@ -1059,3 +1059,76 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
10591059
break
10601060
if has_transformers_component and not is_transformers_version(">", "4.47.1"):
10611061
raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
1062+
1063+
1064+
def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
1065+
"""
1066+
Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of
1067+
the correct type as well.
1068+
"""
1069+
if not isinstance(class_or_tuple, tuple):
1070+
class_or_tuple = (class_or_tuple,)
1071+
1072+
# Unpack unions
1073+
unpacked_class_or_tuple = []
1074+
for t in class_or_tuple:
1075+
if get_origin(t) is Union:
1076+
unpacked_class_or_tuple.extend(get_args(t))
1077+
else:
1078+
unpacked_class_or_tuple.append(t)
1079+
class_or_tuple = tuple(unpacked_class_or_tuple)
1080+
1081+
if Any in class_or_tuple:
1082+
return True
1083+
1084+
obj_type = type(obj)
1085+
# Classes with obj's type
1086+
class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}
1087+
1088+
# Singular types (e.g. int, ControlNet, ...)
1089+
# Untyped collections (e.g. List, but not List[int])
1090+
elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
1091+
if () in elem_class_or_tuple:
1092+
return True
1093+
# Typed lists or sets
1094+
elif obj_type in (list, set):
1095+
return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
1096+
# Typed tuples
1097+
elif obj_type is tuple:
1098+
return any(
1099+
# Tuples with any length and single type (e.g. Tuple[int, ...])
1100+
(len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
1101+
or
1102+
# Tuples with fixed length and any types (e.g. Tuple[int, str])
1103+
(len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
1104+
for t in elem_class_or_tuple
1105+
)
1106+
# Typed dicts
1107+
elif obj_type is dict:
1108+
return any(
1109+
all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
1110+
for kt, vt in elem_class_or_tuple
1111+
)
1112+
1113+
else:
1114+
return False
1115+
1116+
1117+
def _get_detailed_type(obj: Any) -> Type:
1118+
"""
1119+
Gets a detailed type for an object, including nested types for collections.
1120+
"""
1121+
obj_type = type(obj)
1122+
1123+
if obj_type in (list, set):
1124+
obj_origin_type = List if obj_type is list else Set
1125+
elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
1126+
return obj_origin_type[elems_type]
1127+
elif obj_type is tuple:
1128+
return Tuple[tuple(_get_detailed_type(x) for x in obj)]
1129+
elif obj_type is dict:
1130+
keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
1131+
values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
1132+
return Dict[keys_type, values_type]
1133+
else:
1134+
return obj_type

0 commit comments

Comments
 (0)