# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ #
+ # Note:
+ # This pipeline relies on a "hack" discovered by the community that allows
+ # the generation of videos given an input image with AnimateDiff. It works
+ # by creating a copy of the image `num_frames` times and progressively adding
+ # more noise to the image based on the strength and latent interpolation method.
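To make the note concrete, here is a minimal, hypothetical sketch of that latent preparation. It is not the pipeline's actual code: the helper name and the linear noise ramp are illustrative, and the real pipeline exposes the blending choice through `latent_interpolation_method` (e.g. `"slerp"` in the docstring example further down).

```python
import torch


def make_img2video_latents(image_latent: torch.Tensor, num_frames: int, strength: float) -> torch.Tensor:
    """Hypothetical illustration: repeat one image latent and blend each copy toward noise."""
    noise = torch.randn_like(image_latent)  # image_latent: [batch, channels, height, width]
    frames = []
    for i in range(num_frames):
        # Later frames receive progressively more noise; `strength` caps the blend amount.
        t = strength * i / max(num_frames - 1, 1)
        frames.append((1.0 - t) * image_latent + t * noise)  # plain linear interpolation
    return torch.stack(frames, dim=2)  # -> [batch, channels, frames, height, width]
```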
import inspect
- from dataclasses import dataclass
from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Union

from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
- from diffusers.models.unet_motion_model import MotionAdapter
+ from diffusers.models.unets.unet_motion_model import MotionAdapter
+ from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
    DDIMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import randn_tensor
>>> from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
>>> from diffusers.utils import export_to_gif, load_image
+ >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
>>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
>>> pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
- >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
+ >>> pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", beta_schedule="linear", steps_offset=1)
>>> image = load_image("snail.png")
>>> output = pipe(image=image, prompt="A snail moving on the ground", strength=0.8, latent_interpolation_method="slerp")
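>>> # The example presumably continues by exporting the frames with the helper imported above:
>>> frames = output.frames[0]
>>> export_to_gif(frames, "animation.gif")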
@@ -225,14 +232,9 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps
- @dataclass
- class AnimateDiffImgToVideoPipelineOutput(BaseOutput):
-     frames: Union[torch.Tensor, np.ndarray]
-
-
class AnimateDiffImgToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
    r"""
-     Pipeline for text-to-video generation.
+     Pipeline for image-to-video generation.
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -503,6 +505,41 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state

        return image_embeds, uncond_image_embeds

+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+     def prepare_ip_adapter_image_embeds(
+         self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+     ):
+         if ip_adapter_image_embeds is None:
+             if not isinstance(ip_adapter_image, list):
+                 ip_adapter_image = [ip_adapter_image]
+
+             if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                 raise ValueError(
+                     f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                 )
+
+             image_embeds = []
+             for single_ip_adapter_image, image_proj_layer in zip(
+                 ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+             ):
+                 output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                 single_image_embeds, single_negative_image_embeds = self.encode_image(
+                     single_ip_adapter_image, device, 1, output_hidden_state
+                 )
+                 single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                 single_negative_image_embeds = torch.stack(
+                     [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                 )
+
+                 if self.do_classifier_free_guidance:
+                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                     single_image_embeds = single_image_embeds.to(device)
+
+                 image_embeds.append(single_image_embeds)
+         else:
+             image_embeds = ip_adapter_image_embeds
+         return image_embeds
+

    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
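One practical payoff of the new `ip_adapter_image_embeds` path is caching: the embeddings can be computed once and reused across calls. A rough sketch under stated assumptions — an IP-Adapter is already loaded on `pipe`, `style_image` is a reference `PIL.Image`, and the pipeline exposes `do_classifier_free_guidance` when the method is called directly; none of this comes from the diff itself:

```python
# Encode the IP-Adapter reference image once...
image_embeds = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=[style_image],
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
)

# ...then reuse the cached embeddings across several generations.
for prompt in ["A snail moving on the ground", "A snail resting on a leaf"]:
    output = pipe(image=image, prompt=prompt, ip_adapter_image_embeds=image_embeds)
```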
@@ -765,6 +802,7 @@ def __call__(
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
+         ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -818,6 +856,9 @@ def __call__(
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image (`PipelineImageInput`, *optional*):
                Optional image input to work with IP Adapters.
+             ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                 Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the
+                 `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
                `np.array`.
@@ -842,8 +883,8 @@ def __call__(
        Examples:

        Returns:
-             [`AnimateDiffImgToVideoPipelineOutput`] or `tuple`:
-                 If `return_dict` is `True`, [`AnimateDiffImgToVideoPipelineOutput`] is
+             [`AnimateDiffPipelineOutput`] or `tuple`:
+                 If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
        """
        # 0. Default height and width to unet
@@ -902,12 +943,9 @@ def __call__(
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
-             output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-             image_embeds, negative_image_embeds = self.encode_image(
-                 ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
+             image_embeds = self.prepare_ip_adapter_image_embeds(
+                 ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
            )
-             if do_classifier_free_guidance:
-                 image_embeds = torch.cat([negative_image_embeds, image_embeds])

        # 4. Preprocess image
        image = self.image_processor.preprocess(image, height=height, width=width)
@@ -936,7 +974,11 @@ def __call__(
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Add image embeds for IP-Adapter
-         added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+         added_cond_kwargs = (
+             {"image_embeds": image_embeds}
+             if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+             else None
+         )

        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
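For context, `added_cond_kwargs` is what each UNet forward pass receives inside the denoising loop; a simplified sketch of that call site (paraphrased, not the exact code in this file):

```python
# Simplified: the IP-Adapter embeddings ride along with every denoising step.
noise_pred = self.unet(
    latent_model_input,
    t,
    encoder_hidden_states=prompt_embeds,
    cross_attention_kwargs=cross_attention_kwargs,
    added_cond_kwargs=added_cond_kwargs,  # {"image_embeds": ...} or None
).sample
```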
@@ -970,7 +1012,7 @@ def __call__(
                        callback(i, t, latents)

        if output_type == "latent":
-             return AnimateDiffImgToVideoPipelineOutput(frames=latents)
+             return AnimateDiffPipelineOutput(frames=latents)

        # 10. Post-processing
        video_tensor = self.decode_latents(latents)
@@ -986,4 +1028,4 @@ def __call__(
        if not return_dict:
            return (video,)

-         return AnimateDiffImgToVideoPipelineOutput(frames=video)
+         return AnimateDiffPipelineOutput(frames=video)