Skip to content

Commit 5ca27aa

Browse files
committed
Type checks subclasses and fixed type warnings
1 parent b1f26c5 commit 5ca27aa

17 files changed

+48
-38
lines changed

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def __init__(
224224
vae: AutoencoderKL,
225225
text_encoder: CLIPTextModel,
226226
tokenizer: CLIPTokenizer,
227-
unet: UNet2DConditionModel,
227+
unet: Union[UNet2DConditionModel, UNetMotionModel],
228228
motion_adapter: MotionAdapter,
229229
scheduler: Union[
230230
DDIMScheduler,

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def __init__(
246246
vae: AutoencoderKL,
247247
text_encoder: CLIPTextModel,
248248
tokenizer: CLIPTokenizer,
249-
unet: UNet2DConditionModel,
249+
unet: Union[UNet2DConditionModel, UNetMotionModel],
250250
motion_adapter: MotionAdapter,
251251
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
252252
scheduler: Union[

src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ def __init__(
232232
Tuple[HunyuanDiT2DControlNetModel],
233233
HunyuanDiT2DMultiControlNetModel,
234234
],
235-
text_encoder_2=T5EncoderModel,
236-
tokenizer_2=MT5Tokenizer,
235+
text_encoder_2: Optional[T5EncoderModel] = None,
236+
tokenizer_2: Optional[MT5Tokenizer] = None,
237237
requires_safety_checker: bool = True,
238238
):
239239
super().__init__()

src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import torch
1919

20+
from ...models import UNet1DModel
21+
from ...schedulers import SchedulerMixin
2022
from ...utils import is_torch_xla_available, logging
2123
from ...utils.torch_utils import randn_tensor
2224
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
@@ -49,7 +51,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
4951

5052
model_cpu_offload_seq = "unet"
5153

52-
def __init__(self, unet, scheduler):
54+
def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin):
5355
super().__init__()
5456
self.register_modules(unet=unet, scheduler=scheduler)
5557

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import torch
1818

19+
from ...models import UNet2DModel
1920
from ...schedulers import DDIMScheduler
2021
from ...utils import is_torch_xla_available
2122
from ...utils.torch_utils import randn_tensor
@@ -47,7 +48,7 @@ class DDIMPipeline(DiffusionPipeline):
4748

4849
model_cpu_offload_seq = "unet"
4950

50-
def __init__(self, unet, scheduler):
51+
def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler):
5152
super().__init__()
5253

5354
# make sure scheduler can always be converted to DDIM

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import torch
1919

20+
from ...models import UNet2DModel
21+
from ...schedulers import DDPMScheduler
2022
from ...utils import is_torch_xla_available
2123
from ...utils.torch_utils import randn_tensor
2224
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -47,7 +49,7 @@ class DDPMPipeline(DiffusionPipeline):
4749

4850
model_cpu_offload_seq = "unet"
4951

50-
def __init__(self, unet, scheduler):
52+
def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler):
5153
super().__init__()
5254
self.register_modules(unet=unet, scheduler=scheduler)
5355

src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class RePaintPipeline(DiffusionPipeline):
9191
scheduler: RePaintScheduler
9292
model_cpu_offload_seq = "unet"
9393

94-
def __init__(self, unet, scheduler):
94+
def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler):
9595
super().__init__()
9696
self.register_modules(unet=unet, scheduler=scheduler)
9797

src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,8 @@ def __init__(
207207
safety_checker: StableDiffusionSafetyChecker,
208208
feature_extractor: CLIPImageProcessor,
209209
requires_safety_checker: bool = True,
210-
text_encoder_2=T5EncoderModel,
211-
tokenizer_2=MT5Tokenizer,
210+
text_encoder_2: Optional[T5EncoderModel] = None,
211+
tokenizer_2: Optional[MT5Tokenizer] = None,
212212
):
213213
super().__init__()
214214

src/diffusers/pipelines/lumina/pipeline_lumina.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import List, Optional, Tuple, Union
2121

2222
import torch
23-
from transformers import AutoModel, AutoTokenizer
23+
from transformers import PreTrainedModel, PreTrainedTokenizerBase
2424

2525
from ...image_processor import VaeImageProcessor
2626
from ...models import AutoencoderKL
@@ -143,13 +143,13 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
143143
Args:
144144
vae ([`AutoencoderKL`]):
145145
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
146-
text_encoder ([`AutoModel`]):
146+
text_encoder ([`PreTrainedModel`]):
147147
Frozen text-encoder. Lumina-T2I uses
148148
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
149149
[t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
150-
tokenizer (`AutoModel`):
150+
tokenizer (`AutoTokenizer`):
151151
Tokenizer of class
152-
[AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
152+
[AutoTokenizer](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoTokenizer).
153153
transformer ([`Transformer2DModel`]):
154154
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
155155
scheduler ([`SchedulerMixin`]):
@@ -180,8 +180,8 @@ def __init__(
180180
transformer: LuminaNextDiT2DModel,
181181
scheduler: FlowMatchEulerDiscreteScheduler,
182182
vae: AutoencoderKL,
183-
text_encoder: AutoModel,
184-
tokenizer: AutoTokenizer,
183+
text_encoder: PreTrainedModel,
184+
tokenizer: PreTrainedTokenizerBase,
185185
):
186186
super().__init__()
187187

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Callable, Dict, List, Optional, Tuple, Union
2121

2222
import torch
23-
from transformers import AutoModelForCausalLM, AutoTokenizer
23+
from transformers import PreTrainedModel, PreTrainedTokenizerBase
2424

2525
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2626
from ...image_processor import PixArtImageProcessor
@@ -160,8 +160,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
160160

161161
def __init__(
162162
self,
163-
tokenizer: AutoTokenizer,
164-
text_encoder: AutoModelForCausalLM,
163+
tokenizer: PreTrainedTokenizerBase,
164+
text_encoder: PreTrainedModel,
165165
vae: AutoencoderDC,
166166
transformer: SanaTransformer2DModel,
167167
scheduler: FlowMatchEulerDiscreteScheduler,

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1001,7 +1001,7 @@ def is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bo
10011001

10021002
obj_type = type(obj)
10031003
# Classes with obj's type
1004-
class_or_tuple = {t for t in class_or_tuple if (get_origin(t) or t) is obj_type}
1004+
class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}
10051005

10061006
# Singular types (e.g. int, ControlNet, ...)
10071007
# Untyped collections (e.g. List, but not List[int])

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2121

2222
import torch
23-
from transformers import AutoModelForCausalLM, AutoTokenizer
23+
from transformers import PreTrainedModel, PreTrainedTokenizerBase
2424

2525
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2626
from ...image_processor import PixArtImageProcessor
@@ -200,8 +200,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
200200

201201
def __init__(
202202
self,
203-
tokenizer: AutoTokenizer,
204-
text_encoder: AutoModelForCausalLM,
203+
tokenizer: PreTrainedTokenizerBase,
204+
text_encoder: PreTrainedModel,
205205
vae: AutoencoderDC,
206206
transformer: SanaTransformer2DModel,
207207
scheduler: DPMSolverMultistepScheduler,

src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from typing import Callable, Dict, List, Optional, Union
1616

1717
import torch
18-
from transformers import CLIPTextModel, CLIPTokenizer
18+
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
1919

2020
from ...models import StableCascadeUNet
2121
from ...schedulers import DDPMWuerstchenScheduler
@@ -65,7 +65,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
6565
Args:
6666
tokenizer (`CLIPTokenizer`):
6767
The CLIP tokenizer.
68-
text_encoder (`CLIPTextModel`):
68+
text_encoder (`CLIPTextModelWithProjection`):
6969
The CLIP text encoder.
7070
decoder ([`StableCascadeUNet`]):
7171
The Stable Cascade decoder unet.
@@ -93,7 +93,7 @@ def __init__(
9393
self,
9494
decoder: StableCascadeUNet,
9595
tokenizer: CLIPTokenizer,
96-
text_encoder: CLIPTextModel,
96+
text_encoder: CLIPTextModelWithProjection,
9797
scheduler: DDPMWuerstchenScheduler,
9898
vqgan: PaellaVQModel,
9999
latent_dim_scale: float = 10.67,

src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import PIL
1717
import torch
18-
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
18+
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
1919

2020
from ...models import StableCascadeUNet
2121
from ...schedulers import DDPMWuerstchenScheduler
@@ -52,22 +52,26 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
5252
Args:
5353
tokenizer (`CLIPTokenizer`):
5454
The decoder tokenizer to be used for text inputs.
55-
text_encoder (`CLIPTextModel`):
55+
text_encoder (`CLIPTextModelWithProjection`):
5656
The decoder text encoder to be used for text inputs.
5757
decoder (`StableCascadeUNet`):
5858
The decoder model to be used for decoder image generation pipeline.
5959
scheduler (`DDPMWuerstchenScheduler`):
6060
The scheduler to be used for decoder image generation pipeline.
6161
vqgan (`PaellaVQModel`):
6262
The VQGAN model to be used for decoder image generation pipeline.
63-
feature_extractor ([`~transformers.CLIPImageProcessor`]):
64-
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
65-
image_encoder ([`CLIPVisionModelWithProjection`]):
66-
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
6763
prior_prior (`StableCascadeUNet`):
6864
The prior model to be used for prior pipeline.
65+
prior_text_encoder (`CLIPTextModelWithProjection`):
66+
The prior text encoder to be used for text inputs.
67+
prior_tokenizer (`CLIPTokenizer`):
68+
The prior tokenizer to be used for text inputs.
6969
prior_scheduler (`DDPMWuerstchenScheduler`):
7070
The scheduler to be used for prior pipeline.
71+
prior_feature_extractor ([`~transformers.CLIPImageProcessor`]):
72+
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
73+
prior_image_encoder ([`CLIPVisionModelWithProjection`]):
74+
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
7175
"""
7276

7377
_load_connected_pipes = True
@@ -76,12 +80,12 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
7680
def __init__(
7781
self,
7882
tokenizer: CLIPTokenizer,
79-
text_encoder: CLIPTextModel,
83+
text_encoder: CLIPTextModelWithProjection,
8084
decoder: StableCascadeUNet,
8185
scheduler: DDPMWuerstchenScheduler,
8286
vqgan: PaellaVQModel,
8387
prior_prior: StableCascadeUNet,
84-
prior_text_encoder: CLIPTextModel,
88+
prior_text_encoder: CLIPTextModelWithProjection,
8589
prior_tokenizer: CLIPTokenizer,
8690
prior_scheduler: DDPMWuerstchenScheduler,
8791
prior_feature_extractor: Optional[CLIPImageProcessor] = None,

src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def __init__(
141141
image_noising_scheduler: KarrasDiffusionSchedulers,
142142
# regular denoising components
143143
tokenizer: CLIPTokenizer,
144-
text_encoder: CLIPTextModelWithProjection,
144+
text_encoder: CLIPTextModel,
145145
unet: UNet2DConditionModel,
146146
scheduler: KarrasDiffusionSchedulers,
147147
# vae

tests/fixtures/custom_pipeline/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import torch
2020

21-
from diffusers import DiffusionPipeline, ImagePipelineOutput
21+
from diffusers import DiffusionPipeline, ImagePipelineOutput, SchedulerMixin, UNet2DModel
2222

2323

2424
class CustomLocalPipeline(DiffusionPipeline):
@@ -33,7 +33,7 @@ class CustomLocalPipeline(DiffusionPipeline):
3333
[`DDPMScheduler`], or [`DDIMScheduler`].
3434
"""
3535

36-
def __init__(self, unet, scheduler):
36+
def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
3737
super().__init__()
3838
self.register_modules(unet=unet, scheduler=scheduler)
3939

tests/fixtures/custom_pipeline/what_ever.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import torch
2020

21+
from diffusers import SchedulerMixin, UNet2DModel
2122
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
2223

2324

@@ -33,7 +34,7 @@ class CustomLocalPipeline(DiffusionPipeline):
3334
[`DDPMScheduler`], or [`DDIMScheduler`].
3435
"""
3536

36-
def __init__(self, unet, scheduler):
37+
def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
3738
super().__init__()
3839
self.register_modules(unet=unet, scheduler=scheduler)
3940

0 commit comments

Comments
 (0)