Commit 3a66113

[Community] Bug fix + Latest IP-Adapter impl. for AnimateDiff img2vid/controlnet (#7086)
* fix img2vid; update to latest ip-adapter impl
* update README
* update animatediff controlnet to latest impl
1 parent 7f16187 commit 3a66113

File tree: 3 files changed, +115 −27 lines

examples/community/README.md

Lines changed: 5 additions & 2 deletions
@@ -3561,14 +3561,17 @@ pipe.disable_style_aligned()
 
 This pipeline adds experimental support for the image-to-video task using AnimateDiff. Refer to [this](https://github.com/huggingface/diffusers/pull/6328) PR for more examples and results.
 
+This pipeline relies on a "hack" discovered by the community that allows the generation of videos given an input image with AnimateDiff. It works by creating a copy of the image `num_frames` times and progressively adding more noise to the image based on the strength and latent interpolation method.
+
 ```py
 import torch
 from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
 from diffusers.utils import export_to_gif, load_image
 
+model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
 adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
-pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
-pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
+pipe = DiffusionPipeline.from_pretrained(model_id, motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
+pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", beta_schedule="linear", steps_offset=1)
 
 image = load_image("snail.png")
 output = pipe(
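
For readers curious how the "hack" described in the new README note works, here is a rough standalone sketch, assuming a hypothetical helper (this is not the pipeline's actual implementation): the image latent is repeated `num_frames` times and blended with progressively more noise, scaled by `strength`.

```py
# Rough sketch of the img2vid "hack": replicate the image latent per frame and mix in
# increasing amounts of noise (hypothetical helper; the community pipeline additionally
# supports a "slerp" interpolation between the latent and the noise).
import torch

def make_video_latents(image_latent: torch.Tensor, num_frames: int, strength: float) -> torch.Tensor:
    # image_latent: (1, C, H, W) VAE latent of the input image
    latents = image_latent.repeat(num_frames, 1, 1, 1)
    noise = torch.randn_like(latents)
    # later frames receive a larger noise fraction, capped at `strength`
    weights = torch.linspace(0.0, strength, num_frames).view(-1, 1, 1, 1)
    return (1.0 - weights) * latents + weights * noise
```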

examples/community/pipeline_animatediff_controlnet.py

Lines changed: 48 additions & 5 deletions
@@ -24,7 +24,7 @@
 
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel
+from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.models.unets.unet_motion_model import MotionAdapter
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
@@ -382,6 +382,41 @@ def encode_image(self, image, device, num_images_per_prompt):
         uncond_image_embeds = torch.zeros_like(image_embeds)
         return image_embeds, uncond_image_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if self.do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            image_embeds = ip_adapter_image_embeds
+        return image_embeds
+
     # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
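
A minimal shape sketch, using dummy tensors rather than real CLIP features, of what the copied helper produces for a single IP-Adapter: the per-image embedding is replicated `num_images_per_prompt` times and, under classifier-free guidance, the negative embeddings are concatenated in front, doubling the leading batch dimension.

```py
# Dummy-tensor sketch of the stacking/concatenation done per IP-Adapter above
# (assumed shapes for illustration only).
import torch

num_images_per_prompt = 2
single_image_embeds = torch.randn(1, 768)                       # stand-in for one encoded image
single_negative_image_embeds = torch.zeros_like(single_image_embeds)

single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)

# classifier-free guidance: negative embeds first, positive embeds second
image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
print(image_embeds.shape)  # torch.Size([4, 1, 768])
```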
@@ -767,6 +802,7 @@ def __call__(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[PipelineImageInput] = None,
         conditioning_frames: Optional[List[PipelineImageInput]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -821,6 +857,9 @@ def __call__(
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
             ip_adapter_image (`PipelineImageInput`, *optional*):
                 Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             conditioning_frames (`List[PipelineImageInput]`, *optional*):
                 The ControlNet input condition to provide guidance to the `unet` for generation. If multiple ControlNets
                 are specified, images must be passed as a list such that each element of the list can be correctly
@@ -965,9 +1004,9 @@ def __call__(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt)
-            if self.do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
+            )
 
         if isinstance(controlnet, ControlNetModel):
             conditioning_frames = self.prepare_image(
@@ -1023,7 +1062,11 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 7. Add image embeds for IP-Adapter
-        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+        added_cond_kwargs = (
+            {"image_embeds": image_embeds}
+            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+            else None
+        )
 
         # 7.1 Create tensor stating which controlnets to keep
         controlnet_keep = []
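
As the updated signature and docstring show, the ControlNet variant now accepts either `ip_adapter_image` or pre-generated `ip_adapter_image_embeds`. A minimal usage sketch, assuming placeholder conditioning frames and reference image (the ControlNet and IP-Adapter checkpoint names below are illustrative assumptions, not prescribed by this commit):

```py
# Minimal usage sketch for the updated ControlNet community pipeline
# (placeholder images; checkpoint names are illustrative assumptions).
import torch
from PIL import Image
from diffusers import ControlNetModel, DiffusionPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose")

pipe = DiffusionPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",
    motion_adapter=adapter,
    controlnet=controlnet,
    custom_pipeline="pipeline_animatediff_controlnet",
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

pose_frames = [Image.new("RGB", (512, 512)) for _ in range(16)]  # placeholder conditioning frames
reference_image = Image.new("RGB", (512, 512))                   # placeholder IP-Adapter image

output = pipe(
    prompt="a panda dancing",
    num_frames=16,
    conditioning_frames=pose_frames,
    ip_adapter_image=reference_image,  # or pass pre-computed ip_adapter_image_embeds instead
)
```

When `ip_adapter_image_embeds` is supplied, the new code path forwards the tensors as-is and skips `encode_image` entirely, so they must already be laid out the way `prepare_ip_adapter_image_embeds` produces them (negative embeddings concatenated in front when classifier-free guidance is enabled).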

examples/community/pipeline_animatediff_img2video.py

Lines changed: 62 additions & 20 deletions
@@ -11,9 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+#
+# Note:
+# This pipeline relies on a "hack" discovered by the community that allows
+# the generation of videos given an input image with AnimateDiff. It works
+# by creating a copy of the image `num_frames` times and progressively adding
+# more noise to the image based on the strength and latent interpolation method.
 
 import inspect
-from dataclasses import dataclass
 from types import FunctionType
 from typing import Any, Callable, Dict, List, Optional, Union

@@ -25,7 +30,8 @@
 from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
 from diffusers.models.lora import adjust_lora_scale_text_encoder
-from diffusers.models.unet_motion_model import MotionAdapter
+from diffusers.models.unets.unet_motion_model import MotionAdapter
+from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers import (
     DDIMScheduler,
@@ -35,7 +41,7 @@
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from diffusers.utils.torch_utils import randn_tensor
 

@@ -48,9 +54,10 @@
         >>> from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
         >>> from diffusers.utils import export_to_gif, load_image
 
+        >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
         >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
         >>> pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
-        >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
+        >>> pipe.scheduler = pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", beta_schedule="linear", steps_offset=1)
 
         >>> image = load_image("snail.png")
         >>> output = pipe(image=image, prompt="A snail moving on the ground", strength=0.8, latent_interpolation_method="slerp")
@@ -225,14 +232,9 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-@dataclass
-class AnimateDiffImgToVideoPipelineOutput(BaseOutput):
-    frames: Union[torch.Tensor, np.ndarray]
-
-
 class AnimateDiffImgToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
     r"""
-    Pipeline for text-to-video generation.
+    Pipeline for image-to-video generation.
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
     implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -503,6 +505,41 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
 
         return image_embeds, uncond_image_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if self.do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            image_embeds = ip_adapter_image_embeds
+        return image_embeds
+
     # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
@@ -765,6 +802,7 @@ def __call__(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -818,6 +856,9 @@ def __call__(
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*):
                 Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
                 `np.array`.
@@ -842,8 +883,8 @@ def __call__(
         Examples:
 
         Returns:
-            [`AnimateDiffImgToVideoPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`AnimateDiffImgToVideoPipelineOutput`] is
+            [`AnimateDiffPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
         # 0. Default height and width to unet
@@ -902,12 +943,9 @@ def __call__(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
             )
-            if do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
 
         # 4. Preprocess image
         image = self.image_processor.preprocess(image, height=height, width=width)
@@ -936,7 +974,11 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 8. Add image embeds for IP-Adapter
-        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+        added_cond_kwargs = (
+            {"image_embeds": image_embeds}
+            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+            else None
+        )
 
         # 9. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -970,7 +1012,7 @@ def __call__(
                         callback(i, t, latents)
 
         if output_type == "latent":
-            return AnimateDiffImgToVideoPipelineOutput(frames=latents)
+            return AnimateDiffPipelineOutput(frames=latents)
 
         # 10. Post-processing
         video_tensor = self.decode_latents(latents)
@@ -986,4 +1028,4 @@ def __call__(
         if not return_dict:
             return (video,)
 
-        return AnimateDiffImgToVideoPipelineOutput(frames=video)
+        return AnimateDiffPipelineOutput(frames=video)
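
Putting the img2video changes together, a hedged end-to-end sketch ("snail.png" and "style.png" are placeholder paths; the IP-Adapter checkpoint is an assumption, not part of this commit):

```py
# End-to-end sketch of the updated img2video pipeline with an optional IP-Adapter
# (placeholder image paths; the IP-Adapter checkpoint name is assumed).
import torch
from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
from diffusers.utils import export_to_gif, load_image

model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = DiffusionPipeline.from_pretrained(
    model_id, motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video"
).to("cuda")
pipe.scheduler = DDIMScheduler.from_pretrained(
    model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace",
    beta_schedule="linear", steps_offset=1,
)
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

image = load_image("snail.png")      # image to animate (placeholder path)
ip_image = load_image("style.png")   # IP-Adapter reference image (placeholder path)

output = pipe(
    image=image,
    prompt="A snail moving on the ground",
    strength=0.8,
    latent_interpolation_method="slerp",
    ip_adapter_image=ip_image,
)
export_to_gif(output.frames[0], "animation.gif")
```

Because the pipeline now returns the shared `AnimateDiffPipelineOutput`, `output.frames[0]` holds the list of generated frames for the first prompt.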
