14
14
15
15
import inspect
16
16
from dataclasses import dataclass
17
- from typing import Any , Callable , Dict , List , Optional , Union
17
+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
18
18
19
19
import numpy as np
20
20
import torch
66
66
... custom_pipeline="pipeline_animatediff_controlnet",
67
67
... ).to(device="cuda", dtype=torch.float16)
68
68
>>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
69
- ... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
69
+ ... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
70
70
... )
71
71
>>> pipe.enable_vae_slicing()
72
72
83
83
... height=768,
84
84
... conditioning_frames=conditioning_frames,
85
85
... num_inference_steps=12,
86
- ... ).frames[0]
86
+ ... )
87
87
88
88
>>> from diffusers.utils import export_to_gif
89
89
>>> export_to_gif(result.frames[0], "result.gif")
@@ -151,7 +151,7 @@ def __init__(
151
151
tokenizer : CLIPTokenizer ,
152
152
unet : UNet2DConditionModel ,
153
153
motion_adapter : MotionAdapter ,
154
- controlnet : Union [ControlNetModel , MultiControlNetModel ],
154
+ controlnet : Union [ControlNetModel , List [ ControlNetModel ], Tuple [ ControlNetModel ], MultiControlNetModel ],
155
155
scheduler : Union [
156
156
DDIMScheduler ,
157
157
PNDMScheduler ,
@@ -166,6 +166,9 @@ def __init__(
166
166
super ().__init__ ()
167
167
unet = UNetMotionModel .from_unet2d (unet , motion_adapter )
168
168
169
+ if isinstance (controlnet , (list , tuple )):
170
+ controlnet = MultiControlNetModel (controlnet )
171
+
169
172
self .register_modules (
170
173
vae = vae ,
171
174
text_encoder = text_encoder ,
@@ -488,6 +491,7 @@ def check_inputs(
488
491
prompt ,
489
492
height ,
490
493
width ,
494
+ num_frames ,
491
495
callback_steps ,
492
496
negative_prompt = None ,
493
497
prompt_embeds = None ,
@@ -557,31 +561,21 @@ def check_inputs(
557
561
or is_compiled
558
562
and isinstance (self .controlnet ._orig_mod , ControlNetModel )
559
563
):
560
- if isinstance (image , list ):
561
- for image_ in image :
562
- self .check_image (image_ , prompt , prompt_embeds )
563
- else :
564
- self .check_image (image , prompt , prompt_embeds )
564
+ if not isinstance (image , list ):
565
+ raise TypeError (f"For single controlnet, `image` must be of type `list` but got { type (image )} " )
566
+ if len (image ) != num_frames :
567
+ raise ValueError (f"Excepted image to have length { num_frames } but got { len (image )= } " )
565
568
elif (
566
569
isinstance (self .controlnet , MultiControlNetModel )
567
570
or is_compiled
568
571
and isinstance (self .controlnet ._orig_mod , MultiControlNetModel )
569
572
):
570
- if not isinstance (image , list ):
571
- raise TypeError ("For multiple controlnets: `image` must be type `list`" )
572
-
573
- # When `image` is a nested list:
574
- # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
575
- elif any (isinstance (i , list ) for i in image ):
576
- raise ValueError ("A single batch of multiple conditionings are supported at the moment." )
577
- elif len (image ) != len (self .controlnet .nets ):
578
- raise ValueError (
579
- f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got { len (image )} images and { len (self .controlnet .nets )} ControlNets."
580
- )
581
-
582
- for control_ in image :
583
- for image_ in control_ :
584
- self .check_image (image_ , prompt , prompt_embeds )
573
+ if not isinstance (image , list ) or not isinstance (image [0 ], list ):
574
+ raise TypeError (f"For multiple controlnets: `image` must be type list of lists but got { type (image )= } " )
575
+ if len (image [0 ]) != num_frames :
576
+ raise ValueError (f"Expected length of image sublist as { num_frames } but got { len (image [0 ])= } " )
577
+ if any (len (img ) != len (image [0 ]) for img in image ):
578
+ raise ValueError ("All conditioning frame batches for multicontrolnet must be same size" )
585
579
else :
586
580
assert False
587
581
@@ -913,6 +907,7 @@ def __call__(
913
907
prompt = prompt ,
914
908
height = height ,
915
909
width = width ,
910
+ num_frames = num_frames ,
916
911
callback_steps = callback_steps ,
917
912
negative_prompt = negative_prompt ,
918
913
callback_on_step_end_tensor_inputs = callback_on_step_end_tensor_inputs ,
@@ -1000,9 +995,7 @@ def __call__(
1000
995
do_classifier_free_guidance = self .do_classifier_free_guidance ,
1001
996
guess_mode = guess_mode ,
1002
997
)
1003
-
1004
998
cond_prepared_frames .append (prepared_frame )
1005
-
1006
999
conditioning_frames = cond_prepared_frames
1007
1000
else :
1008
1001
assert False
0 commit comments