Commit 119d734

[AnimateDiff+Controlnet] Fix multicontrolnet support (#6551)
* fix multicontrolnet support
* update README with multicontrolnet example
1 parent cb4b3f0 commit 119d734

2 files changed (+94, -28)

examples/community/README.md

Lines changed: 75 additions & 2 deletions
@@ -2989,7 +2989,7 @@ pipe = DiffusionPipeline.from_pretrained(
     custom_pipeline="pipeline_animatediff_controlnet",
 ).to(device="cuda", dtype=torch.float16)
 pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
-    model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
+    model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
 )
 pipe.enable_vae_slicing()
 
@@ -3005,7 +3005,7 @@ result = pipe(
     width=512,
     height=768,
     conditioning_frames=conditioning_frames,
-    num_inference_steps=12,
+    num_inference_steps=20,
 ).frames[0]
 
 from diffusers.utils import export_to_gif
@@ -3029,6 +3029,79 @@ export_to_gif(result.frames[0], "result.gif")
 </tr>
 </table>
 
+You can also use multiple controlnets at once!
+
+```python
+import imageio
+import requests
+import torch
+from io import BytesIO
+
+from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter
+from diffusers.pipelines import DiffusionPipeline
+from diffusers.schedulers import DPMSolverMultistepScheduler
+from PIL import Image
+
+motion_id = "guoyww/animatediff-motion-adapter-v1-5-2"
+adapter = MotionAdapter.from_pretrained(motion_id)
+controlnet1 = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
+controlnet2 = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+
+model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+pipe = DiffusionPipeline.from_pretrained(
+    model_id,
+    motion_adapter=adapter,
+    controlnet=[controlnet1, controlnet2],
+    vae=vae,
+    custom_pipeline="pipeline_animatediff_controlnet",
+).to(device="cuda", dtype=torch.float16)
+pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
+    model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
+)
+pipe.enable_vae_slicing()
+
+def load_video(file_path: str):
+    images = []
+
+    if file_path.startswith(('http://', 'https://')):
+        # If the file_path is a URL
+        response = requests.get(file_path)
+        response.raise_for_status()
+        content = BytesIO(response.content)
+        vid = imageio.get_reader(content)
+    else:
+        # Assuming it's a local file path
+        vid = imageio.get_reader(file_path)
+
+    for frame in vid:
+        pil_image = Image.fromarray(frame)
+        images.append(pil_image)
+
+    return images
+
+video = load_video("dance.gif")
+
+# You need to install controlnet_aux first: `pip install controlnet_aux`
+from controlnet_aux.processor import Processor
+
+p1 = Processor("openpose_full")
+cn1 = [p1(frame) for frame in video]
+
+p2 = Processor("canny")
+cn2 = [p2(frame) for frame in video]
+
+prompt = "astronaut in space, dancing"
+negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
+result = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    width=512,
+    height=768,
+    conditioning_frames=[cn1, cn2],
+    num_inference_steps=20,
+)
+
+from diffusers.utils import export_to_gif
+export_to_gif(result.frames[0], "result.gif")
+```
+
 ### DemoFusion
 
 This pipeline is the official implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973).
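Before the pipeline changes below, one detail of the new example is worth spelling out: with multiple controlnets, `conditioning_frames` is no longer a flat list of frames but a list of per-controlnet frame lists, ordered to match the `controlnet` argument. A minimal sketch of the expected nesting (placeholder blank frames, illustrative only):

```python
from PIL import Image

# Shape contract for conditioning_frames (sketch, not library code):
#   single controlnet -> flat list, one conditioning image per video frame
#   multi controlnet  -> one sublist per controlnet, each num_frames long,
#                        ordered like the controlnet list given to the pipeline
num_frames = 16
pose_frames = [Image.new("RGB", (512, 768))] * num_frames  # for controlnet1 (openpose)
edge_frames = [Image.new("RGB", (512, 768))] * num_frames  # for controlnet2 (canny)

single = pose_frames                # conditioning_frames=single
multi = [pose_frames, edge_frames]  # conditioning_frames=multi

assert all(len(frames) == num_frames for frames in multi)
```

Since `MultiControlNetModel` pairs each sublist with its controlnet in order, swapping the sublists would silently condition the openpose controlnet on canny maps and vice versa, so the ordering is worth double-checking.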

examples/community/pipeline_animatediff_controlnet.py

Lines changed: 19 additions & 26 deletions
@@ -14,7 +14,7 @@
 
 import inspect
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -66,7 +66,7 @@
 ...     custom_pipeline="pipeline_animatediff_controlnet",
 ... ).to(device="cuda", dtype=torch.float16)
 >>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
-...     model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
+...     model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
 ... )
 >>> pipe.enable_vae_slicing()
 
@@ -83,7 +83,7 @@
 ...     height=768,
 ...     conditioning_frames=conditioning_frames,
 ...     num_inference_steps=12,
-... ).frames[0]
+... )
 
 >>> from diffusers.utils import export_to_gif
 >>> export_to_gif(result.frames[0], "result.gif")
@@ -151,7 +151,7 @@ def __init__(
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         motion_adapter: MotionAdapter,
-        controlnet: Union[ControlNetModel, MultiControlNetModel],
+        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: Union[
             DDIMScheduler,
             PNDMScheduler,
@@ -166,6 +166,9 @@ def __init__(
         super().__init__()
         unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
 
+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNetModel(controlnet)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
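This normalization mirrors what the stock Stable Diffusion controlnet pipelines do: accept a bare `ControlNetModel`, a list/tuple of models, or an already-built `MultiControlNetModel`, and canonicalize sequences so downstream `isinstance` checks only ever see two types. A standalone sketch of the pattern (the `normalize_controlnet` helper is illustrative, not part of the commit):

```python
from typing import List, Tuple, Union

from diffusers import ControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel


def normalize_controlnet(
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel, ...], MultiControlNetModel],
) -> Union[ControlNetModel, MultiControlNetModel]:
    # Wrap bare sequences; pass canonical types through unchanged.
    if isinstance(controlnet, (list, tuple)):
        return MultiControlNetModel(controlnet)
    return controlnet
```

Accepting plain lists keeps the README example ergonomic (`controlnet=[controlnet1, controlnet2]`) while letting the `isinstance(self.controlnet, MultiControlNetModel)` checks further down stay simple.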
@@ -488,6 +491,7 @@ def check_inputs(
         prompt,
         height,
         width,
+        num_frames,
         callback_steps,
         negative_prompt=None,
         prompt_embeds=None,
@@ -557,31 +561,21 @@ def check_inputs(
             or is_compiled
             and isinstance(self.controlnet._orig_mod, ControlNetModel)
         ):
-            if isinstance(image, list):
-                for image_ in image:
-                    self.check_image(image_, prompt, prompt_embeds)
-            else:
-                self.check_image(image, prompt, prompt_embeds)
+            if not isinstance(image, list):
+                raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(image)}")
+            if len(image) != num_frames:
+                raise ValueError(f"Expected image to have length {num_frames} but got {len(image)=}")
         elif (
             isinstance(self.controlnet, MultiControlNetModel)
             or is_compiled
             and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
         ):
-            if not isinstance(image, list):
-                raise TypeError("For multiple controlnets: `image` must be type `list`")
-
-            # When `image` is a nested list:
-            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
-            elif any(isinstance(i, list) for i in image):
-                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
-            elif len(image) != len(self.controlnet.nets):
-                raise ValueError(
-                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
-                )
-
-            for control_ in image:
-                for image_ in control_:
-                    self.check_image(image_, prompt, prompt_embeds)
+            if not isinstance(image, list) or not isinstance(image[0], list):
+                raise TypeError(f"For multiple controlnets: `image` must be a list of lists but got {type(image)=}")
+            if len(image[0]) != num_frames:
+                raise ValueError(f"Expected length of image sublist to be {num_frames} but got {len(image[0])=}")
+            if any(len(img) != len(image[0]) for img in image):
+                raise ValueError("All conditioning frame batches for multicontrolnet must be the same size")
         else:
             assert False
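In effect, `check_inputs` now validates shape rather than individual images: a single controlnet wants a flat list of `num_frames` frames, a multicontrolnet wants equally-sized sublists, one per controlnet. Restated as a standalone helper (hypothetical, mirroring the checks above):

```python
def check_conditioning_frames(image, num_frames: int, multi: bool) -> None:
    # Hypothetical standalone restatement of the shape checks above.
    if not multi:
        if not isinstance(image, list):
            raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(image)}")
        if len(image) != num_frames:
            raise ValueError(f"Expected image to have length {num_frames} but got {len(image)}")
    else:
        if not isinstance(image, list) or not isinstance(image[0], list):
            raise TypeError(f"For multiple controlnets: `image` must be a list of lists but got {type(image)}")
        if len(image[0]) != num_frames:
            raise ValueError(f"Expected each sublist to have length {num_frames} but got {len(image[0])}")
        if any(len(sub) != len(image[0]) for sub in image):
            raise ValueError("All conditioning frame batches for multicontrolnet must be the same size")
```

For example, `check_conditioning_frames([cn1, cn2], num_frames=16, multi=True)` passes only when both processed frame lists are 16 entries long.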

@@ -913,6 +907,7 @@ def __call__(
             prompt=prompt,
             height=height,
             width=width,
+            num_frames=num_frames,
             callback_steps=callback_steps,
             negative_prompt=negative_prompt,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
@@ -1000,9 +995,7 @@ def __call__(
                     do_classifier_free_guidance=self.do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )
-
                 cond_prepared_frames.append(prepared_frame)
-
             conditioning_frames = cond_prepared_frames
         else:
             assert False
