
[Pipeline] AnimateDiff SDXL #6721

Merged
50 commits merged on May 8, 2024

Commits
56ba44b
update conversion script to handle motion adapter sdxl checkpoint
a-r-r-o-w Jan 26, 2024
7ae7bc8
add animatediff xl
a-r-r-o-w Jan 26, 2024
01f5978
handle addition_embed_type
a-r-r-o-w Jan 26, 2024
2562500
fix output
a-r-r-o-w Jan 26, 2024
736a224
update
a-r-r-o-w Jan 26, 2024
4a2b9de
add imports
a-r-r-o-w Jan 26, 2024
eb060e0
make fix-copies
a-r-r-o-w Jan 26, 2024
54cd75c
Merge branch 'main' into re-animatediff-sdxl
a-r-r-o-w Jan 26, 2024
c01d2c2
add decode latents
a-r-r-o-w Jan 26, 2024
3d45dc1
Merge branch 'main' into re-animatediff-sdxl
a-r-r-o-w Jan 26, 2024
60364ea
update docstrings
a-r-r-o-w Jan 27, 2024
0db8340
add animatediff sdxl to docs
a-r-r-o-w Jan 27, 2024
bf2cd49
remove unnecessary lines
a-r-r-o-w Jan 27, 2024
389adaa
update example
a-r-r-o-w Jan 27, 2024
f471e3c
Merge branch 'main' into re-animatediff-sdxl
a-r-r-o-w Jan 27, 2024
ba4f9f4
add test
a-r-r-o-w Jan 27, 2024
504e958
Merge branch 'main' into re-animatediff-sdxl
a-r-r-o-w Feb 8, 2024
a3fb232
Merge branch 'main' into re-animatediff-sdxl
a-r-r-o-w Feb 21, 2024
512d346
Merge branch 'main' into re-animatediff-sdxl
a-r-r-o-w Feb 25, 2024
21e6af1
Merge branch 'main' into re-animatediff-sdxl
a-r-r-o-w Feb 26, 2024
1fa606f
revert conv_in conv_out kernel param
a-r-r-o-w Feb 26, 2024
93fc848
remove unused param addition_embed_type_num_heads
a-r-r-o-w Feb 26, 2024
5bbd8ef
latest IPAdapter impl
a-r-r-o-w Feb 26, 2024
9f69127
make fix-copies
a-r-r-o-w Feb 26, 2024
5919f37
fix return
a-r-r-o-w Feb 26, 2024
a09355a
add IPAdapterTesterMixin to tests
a-r-r-o-w Feb 26, 2024
306902f
fix return
a-r-r-o-w Feb 26, 2024
5ba4383
Merge branch 'main' into re-animatediff-sdxl
a-r-r-o-w Mar 24, 2024
feb458a
revert based on suggestion
a-r-r-o-w Mar 24, 2024
7c3807d
add freeinit
a-r-r-o-w Mar 24, 2024
dd996ad
fix test_to_dtype test
a-r-r-o-w Mar 24, 2024
53f5815
use StableDiffusionMixin instead of different helper methods
a-r-r-o-w Mar 24, 2024
971852f
fix progress bar iterations
a-r-r-o-w Mar 24, 2024
dc0fd88
apply suggestions from review
a-r-r-o-w Mar 26, 2024
75bd4e8
hardcode flip_sin_to_cos and freq_shift
a-r-r-o-w Mar 26, 2024
2b618ea
Merge branch 'main' into re-animatediff-sdxl
a-r-r-o-w Mar 26, 2024
124d1c9
make fix-copies
a-r-r-o-w Mar 26, 2024
1081788
fix ip adapter implementation
a-r-r-o-w Mar 28, 2024
f887625
Merge branch 'main' into re-animatediff-sdxl
a-r-r-o-w Mar 28, 2024
4772b0d
fix last failing test
a-r-r-o-w Mar 28, 2024
7df8fab
make style
a-r-r-o-w Mar 28, 2024
b68037d
Merge branch 'main' into re-animatediff-sdxl
a-r-r-o-w Mar 29, 2024
90ebe25
Merge branch 'main' into re-animatediff-sdxl
a-r-r-o-w Apr 2, 2024
3c662ba
Merge branch 'main' into re-animatediff-sdxl
a-r-r-o-w Apr 14, 2024
c8b9d73
Merge branch 'main' into re-animatediff-sdxl
a-r-r-o-w Apr 20, 2024
33d5a18
Update docs/source/en/api/pipelines/animatediff.md
a-r-r-o-w Apr 30, 2024
43e873f
remove todo
a-r-r-o-w Apr 30, 2024
3a9cd0f
Merge branch 'main' into re-animatediff-sdxl
a-r-r-o-w Apr 30, 2024
cf3ebc9
fix doc-builder errors
a-r-r-o-w Apr 30, 2024
1a8d76c
Merge branch 'main' into re-animatediff-sdxl
DN6 May 8, 2024
53 changes: 53 additions & 0 deletions docs/source/en/api/pipelines/animatediff.md
@@ -101,6 +101,53 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you

</Tip>

### AnimateDiffSDXLPipeline

AnimateDiff can also be used with SDXL models. This is currently an experimental feature as only a beta release of the motion adapter checkpoint is available.

```python
import torch
from diffusers.models import MotionAdapter
from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16)

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
scheduler = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
clip_sample=False,
timestep_spacing="linspace",
beta_schedule="linear",
steps_offset=1,
)
pipe = AnimateDiffSDXLPipeline.from_pretrained(
model_id,
motion_adapter=adapter,
scheduler=scheduler,
torch_dtype=torch.float16,
variant="fp16",
).to("cuda")

# enable memory savings
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

output = pipe(
prompt="a panda surfing in the ocean, realistic, high quality",
negative_prompt="low quality, worst quality",
num_inference_steps=20,
guidance_scale=8,
width=1024,
height=1024,
num_frames=16,
)

frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```

### AnimateDiffVideoToVideoPipeline

AnimateDiff can also be used to generate visually similar videos or enable style/character/background or other edits starting from an initial video, allowing you to seamlessly explore creative possibilities.
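
A minimal usage sketch (not part of this diff hunk): the checkpoint names and the GIF-loading step below are assumptions, while the `video` and `strength` arguments come from `AnimateDiffVideoToVideoPipeline` itself.

```python
import imageio.v3 as iio
import torch
from PIL import Image

from diffusers import AnimateDiffVideoToVideoPipeline, MotionAdapter

# Assumed checkpoints; any SD 1.5-compatible base model plus a v1.5 motion adapter should work.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",
    motion_adapter=adapter,
    torch_dtype=torch.float16,
).to("cuda")

# Load the source clip as a list of PIL images.
video = [Image.fromarray(frame) for frame in iio.imread("input.gif", index=None)]

output = pipe(
    video=video,
    prompt="a panda dancing, high quality",
    negative_prompt="low quality, worst quality",
    strength=0.6,  # lower values stay closer to the source video
    guidance_scale=7.5,
    num_inference_steps=25,
)
frames = output.frames[0]
```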
@@ -522,6 +569,12 @@ export_to_gif(frames, "animatelcm-motion-lora.gif")
- all
- __call__

## AnimateDiffSDXLPipeline

[[autodoc]] AnimateDiffSDXLPipeline
- all
- __call__

## AnimateDiffVideoToVideoPipeline

[[autodoc]] AnimateDiffVideoToVideoPipeline
7 changes: 5 additions & 2 deletions scripts/convert_animatediff_motion_module_to_diffusers.py
@@ -31,6 +31,7 @@ def get_args():
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument("--use_motion_mid_block", action="store_true")
parser.add_argument("--motion_max_seq_length", type=int, default=32)
parser.add_argument("--block_out_channels", nargs="+", default=[320, 640, 1280, 1280], type=int)
parser.add_argument("--save_fp16", action="store_true")

return parser.parse_args()
@@ -49,11 +50,13 @@ def get_args():

conv_state_dict = convert_motion_module(state_dict)
adapter = MotionAdapter(
use_motion_mid_block=args.use_motion_mid_block, motion_max_seq_length=args.motion_max_seq_length
block_out_channels=args.block_out_channels,
use_motion_mid_block=args.use_motion_mid_block,
motion_max_seq_length=args.motion_max_seq_length,
)
# skip loading position embeddings
adapter.load_state_dict(conv_state_dict, strict=False)
adapter.save_pretrained(args.output_path)

if args.save_fp16:
adapter.to(torch.float16).save_pretrained(args.output_path, variant="fp16")
adapter.to(dtype=torch.float16).save_pretrained(args.output_path, variant="fp16")
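
The new `--block_out_channels` argument lets the converted adapter match the width and number of stages of the target UNet (SDXL uses three stages instead of the SD1.5 default of four). A rough sketch of what an SDXL-sized adapter construction might look like — the concrete values and the `use_motion_mid_block` setting are assumptions, not read from this PR:

```python
import torch
from diffusers.models import MotionAdapter

# Assumed SDXL-like configuration; pass whatever matches your checkpoint.
adapter = MotionAdapter(
    block_out_channels=(320, 640, 1280),  # three stages, mirroring the SDXL UNet
    motion_max_seq_length=32,
    use_motion_mid_block=False,           # assumption; depends on the checkpoint
)

# Mirroring the script: load a converted state dict (position embeddings skipped)
# and save an fp16 variant alongside the fp32 weights.
# adapter.load_state_dict(converted_state_dict, strict=False)
adapter.save_pretrained("motion-adapter-sdxl")
adapter.to(dtype=torch.float16).save_pretrained("motion-adapter-sdxl", variant="fp16")
```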
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -216,6 +216,7 @@
"AmusedInpaintPipeline",
"AmusedPipeline",
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffVideoToVideoPipeline",
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
@@ -595,6 +596,7 @@
AmusedInpaintPipeline,
AmusedPipeline,
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffVideoToVideoPipeline,
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
2 changes: 2 additions & 0 deletions src/diffusers/models/unets/unet_3d_blocks.py
@@ -121,6 +121,7 @@ def get_down_block(
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
return CrossAttnDownBlockMotion(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -255,6 +256,7 @@ def get_up_block(
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
return CrossAttnUpBlockMotion(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
74 changes: 68 additions & 6 deletions src/diffusers/models/unets/unet_motion_model.py
@@ -211,13 +211,18 @@ def __init__(
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
use_linear_projection: bool = False,
num_attention_heads: Union[int, Tuple[int, ...]] = 8,
motion_max_seq_length: int = 32,
motion_num_attention_heads: int = 8,
use_motion_mid_block: int = True,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
projection_class_embeddings_input_dim: Optional[int] = None,
time_cond_proj_dim: Optional[int] = None,
):
super().__init__()
@@ -240,6 +245,21 @@ def __init__(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)

if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)

if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)

if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
for layer_number_per_block in transformer_layers_per_block:
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

# input
conv_in_kernel = 3
conv_out_kernel = 3
@@ -260,13 +280,26 @@ def __init__(
if encoder_hid_dim_type is None:
self.encoder_hid_proj = None

if addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, True, 0)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

Review comment:
Could you also load add_embedding in from_unet2d? Something like:

if hasattr(model, "add_embedding"):
    model.add_embedding.load_state_dict(unet.add_embedding.state_dict())

a-r-r-o-w (Member, Author) replied on Apr 30, 2024:

If we take a look at the current unet_2d_condition.py modelling code, the team has refactored these changes out into separate helper functions. Since unet_motion_model.py is mostly a copy of that, we can adapt those changes here and thereby get all the functionality one would need. We can take it up in a future PR in my opinion (also, I'm afraid I will not have time to test things thoroughly if we do it here).


# class embedding
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])

if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)

if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

if isinstance(layers_per_block, int):
layers_per_block = [layers_per_block] * len(down_block_types)

if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
@@ -276,21 +309,22 @@ def __init__(

down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
num_layers=layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[i],
num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
dual_cross_attention=False,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=transformer_layers_per_block[i],
)
self.down_blocks.append(down_block)

@@ -302,13 +336,14 @@ def __init__(
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=transformer_layers_per_block[-1],
)

else:
@@ -318,11 +353,12 @@ def __init__(
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
transformer_layers_per_block=transformer_layers_per_block[-1],
)

# count how many layers upsample the images
@@ -331,6 +367,9 @@ def __init__(
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))

output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
@@ -349,7 +388,7 @@ def __init__(

up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
num_layers=reversed_layers_per_block[i] + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
@@ -358,13 +397,14 @@ def __init__(
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=False,
resolution_idx=i,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -835,6 +875,28 @@ def forward(
t_emb = t_emb.to(dtype=self.dtype)

emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None

if self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)

text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)

emb = emb if aug_emb is None else emb + aug_emb
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)

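The `text_time` branch added above mirrors SDXL's micro-conditioning: pooled text embeddings and the flattened `time_ids` are projected, concatenated, and added to the time embedding before everything is repeated per frame. A minimal, self-contained sketch of that computation with assumed SDXL-like dimensions (the 1280/256 sizes and the example `time_ids` are assumptions, not values read from this PR):

```python
import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

# Assumed SDXL-like sizes; not taken verbatim from the PR.
addition_time_embed_dim = 256
pooled_text_dim = 1280
time_embed_dim = 1280
batch, num_frames = 2, 16

add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
add_embedding = TimestepEmbedding(pooled_text_dim + 6 * addition_time_embed_dim, time_embed_dim)

text_embeds = torch.randn(batch, pooled_text_dim)  # pooled prompt embeddings
# (orig_h, orig_w, crop_top, crop_left, target_h, target_w) per sample
time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * batch)

time_embeds = add_time_proj(time_ids.flatten()).reshape((batch, -1))  # (batch, 6 * 256)
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)         # (batch, 2816)
aug_emb = add_embedding(add_embeds)                                    # (batch, time_embed_dim)

emb = torch.randn(batch, time_embed_dim) + aug_emb      # stand-in for the timestep embedding
emb = emb.repeat_interleave(repeats=num_frames, dim=0)  # one copy per frame, as in the forward pass
```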
3 changes: 2 additions & 1 deletion src/diffusers/pipelines/__init__.py
@@ -114,6 +114,7 @@
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
_import_structure["animatediff"] = [
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffVideoToVideoPipeline",
]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
@@ -367,7 +368,7 @@
from ..utils.dummy_torch_and_transformers_objects import *
else:
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
from .animatediff import AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline
from .animatediff import AnimateDiffPipeline, AnimateDiffSDXLPipeline, AnimateDiffVideoToVideoPipeline
from .audioldm import AudioLDMPipeline
from .audioldm2 import (
AudioLDM2Pipeline,
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/animatediff/__init__.py
@@ -22,6 +22,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
_import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
_import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -33,6 +34,7 @@

else:
from .pipeline_animatediff import AnimateDiffPipeline
from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
from .pipeline_output import AnimateDiffPipelineOutput

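Taken together, the `__init__.py` changes expose the new pipeline from both the top-level package and the `animatediff` subpackage. A quick sanity-check sketch (not part of the PR's test suite):

```python
from diffusers import AnimateDiffSDXLPipeline
from diffusers.pipelines.animatediff import AnimateDiffSDXLPipeline as AnimateDiffSDXLPipelineAlias

# Both import paths should resolve to the same class after this PR.
assert AnimateDiffSDXLPipeline is AnimateDiffSDXLPipelineAlias
```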