
Commit 818f760

a-r-r-o-w and DN6 authored
[Pipeline] AnimateDiff SDXL (#6721)
* update conversion script to handle motion adapter sdxl checkpoint
* add animatediff xl
* handle addition_embed_type
* fix output
* update
* add imports
* make fix-copies
* add decode latents
* update docstrings
* add animatediff sdxl to docs
* remove unnecessary lines
* update example
* add test
* revert conv_in conv_out kernel param
* remove unused param addition_embed_type_num_heads
* latest IPAdapter impl
* make fix-copies
* fix return
* add IPAdapterTesterMixin to tests
* fix return
* revert based on suggestion
* add freeinit
* fix test_to_dtype test
* use StableDiffusionMixin instead of different helper methods
* fix progress bar iterations
* apply suggestions from review
* hardcode flip_sin_to_cos and freq_shift
* make fix-copies
* fix ip adapter implementation
* fix last failing test
* make style
* Update docs/source/en/api/pipelines/animatediff.md (Co-authored-by: Dhruv Nair <[email protected]>)
* remove todo
* fix doc-builder errors

---------

Co-authored-by: Dhruv Nair <[email protected]>
1 parent f29b934 commit 818f760

10 files changed: +1740 -9 lines changed


docs/source/en/api/pipelines/animatediff.md

Lines changed: 53 additions & 0 deletions
@@ -101,6 +101,53 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you

 </Tip>

+### AnimateDiffSDXLPipeline
+
+AnimateDiff can also be used with SDXL models. This is currently an experimental feature as only a beta release of the motion adapter checkpoint is available.
+
+```python
+import torch
+from diffusers.models import MotionAdapter
+from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler
+from diffusers.utils import export_to_gif
+
+adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16)
+
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+scheduler = DDIMScheduler.from_pretrained(
+    model_id,
+    subfolder="scheduler",
+    clip_sample=False,
+    timestep_spacing="linspace",
+    beta_schedule="linear",
+    steps_offset=1,
+)
+pipe = AnimateDiffSDXLPipeline.from_pretrained(
+    model_id,
+    motion_adapter=adapter,
+    scheduler=scheduler,
+    torch_dtype=torch.float16,
+    variant="fp16",
+).to("cuda")
+
+# enable memory savings
+pipe.enable_vae_slicing()
+pipe.enable_vae_tiling()
+
+output = pipe(
+    prompt="a panda surfing in the ocean, realistic, high quality",
+    negative_prompt="low quality, worst quality",
+    num_inference_steps=20,
+    guidance_scale=8,
+    width=1024,
+    height=1024,
+    num_frames=16,
+)
+
+frames = output.frames[0]
+export_to_gif(frames, "animation.gif")
+```
+
 ### AnimateDiffVideoToVideoPipeline

 AnimateDiff can also be used to generate visually similar videos or enable style/character/background or other edits starting from an initial video, allowing you to seamlessly explore creative possibilities.
@@ -522,6 +569,12 @@ export_to_gif(frames, "animatelcm-motion-lora.gif")
 - all
 - __call__

+## AnimateDiffSDXLPipeline
+
+[[autodoc]] AnimateDiffSDXLPipeline
+- all
+- __call__
+
 ## AnimateDiffVideoToVideoPipeline

 [[autodoc]] AnimateDiffVideoToVideoPipeline

scripts/convert_animatediff_motion_module_to_diffusers.py

Lines changed: 5 additions & 2 deletions
@@ -31,6 +31,7 @@ def get_args():
     parser.add_argument("--output_path", type=str, required=True)
     parser.add_argument("--use_motion_mid_block", action="store_true")
     parser.add_argument("--motion_max_seq_length", type=int, default=32)
+    parser.add_argument("--block_out_channels", nargs="+", default=[320, 640, 1280, 1280], type=int)
     parser.add_argument("--save_fp16", action="store_true")

     return parser.parse_args()
@@ -49,11 +50,13 @@ def get_args():

     conv_state_dict = convert_motion_module(state_dict)
     adapter = MotionAdapter(
-        use_motion_mid_block=args.use_motion_mid_block, motion_max_seq_length=args.motion_max_seq_length
+        block_out_channels=args.block_out_channels,
+        use_motion_mid_block=args.use_motion_mid_block,
+        motion_max_seq_length=args.motion_max_seq_length,
     )
     # skip loading position embeddings
     adapter.load_state_dict(conv_state_dict, strict=False)
     adapter.save_pretrained(args.output_path)

     if args.save_fp16:
-        adapter.to(torch.float16).save_pretrained(args.output_path, variant="fp16")
+        adapter.to(dtype=torch.float16).save_pretrained(args.output_path, variant="fp16")
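The new `--block_out_channels` argument lets the script build adapters for UNets whose block widths differ from the SD 1.5 layout it previously assumed, which is what the SDXL motion module needs. Below is a minimal sketch of the adapter construction this enables, using the script's default widths rather than any particular checkpoint's configuration; the flag values are illustrative only.

```python
# Hedged sketch of the adapter construction above; values are the script
# defaults, not a specific checkpoint's configuration.
import torch

from diffusers.models import MotionAdapter

adapter = MotionAdapter(
    block_out_channels=(320, 640, 1280, 1280),  # must match the base UNet's block widths
    use_motion_mid_block=False,                  # True when the checkpoint ships a motion mid block
    motion_max_seq_length=32,
)

# Optional half-precision copy, analogous to the --save_fp16 branch above.
adapter.to(dtype=torch.float16)
```

As in the script, position embeddings can be skipped by loading the converted state dict with `strict=False`.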

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -216,6 +216,7 @@
         "AmusedInpaintPipeline",
         "AmusedPipeline",
         "AnimateDiffPipeline",
+        "AnimateDiffSDXLPipeline",
         "AnimateDiffVideoToVideoPipeline",
         "AudioLDM2Pipeline",
         "AudioLDM2ProjectionModel",
@@ -595,6 +596,7 @@
         AmusedInpaintPipeline,
         AmusedPipeline,
         AnimateDiffPipeline,
+        AnimateDiffSDXLPipeline,
         AnimateDiffVideoToVideoPipeline,
         AudioLDM2Pipeline,
         AudioLDM2ProjectionModel,

src/diffusers/models/unets/unet_3d_blocks.py

Lines changed: 2 additions & 0 deletions
@@ -121,6 +121,7 @@ def get_down_block(
             raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
         return CrossAttnDownBlockMotion(
             num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
             in_channels=in_channels,
             out_channels=out_channels,
             temb_channels=temb_channels,
@@ -255,6 +256,7 @@ def get_up_block(
             raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
         return CrossAttnUpBlockMotion(
             num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
             in_channels=in_channels,
             out_channels=out_channels,
             prev_output_channel=prev_output_channel,

src/diffusers/models/unets/unet_motion_model.py

Lines changed: 68 additions & 6 deletions
@@ -211,13 +211,18 @@ def __init__(
         norm_num_groups: int = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
         use_linear_projection: bool = False,
         num_attention_heads: Union[int, Tuple[int, ...]] = 8,
         motion_max_seq_length: int = 32,
         motion_num_attention_heads: int = 8,
         use_motion_mid_block: int = True,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
+        projection_class_embeddings_input_dim: Optional[int] = None,
         time_cond_proj_dim: Optional[int] = None,
     ):
         super().__init__()
@@ -240,6 +245,21 @@ def __init__(
                 f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
             )

+        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+            )
+
+        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+            for layer_number_per_block in transformer_layers_per_block:
+                if isinstance(layer_number_per_block, list):
+                    raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
         # input
         conv_in_kernel = 3
         conv_out_kernel = 3
@@ -260,13 +280,26 @@ def __init__(
         if encoder_hid_dim_type is None:
             self.encoder_hid_proj = None

+        if addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, True, 0)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
         # class embedding
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])

         if isinstance(num_attention_heads, int):
             num_attention_heads = (num_attention_heads,) * len(down_block_types)

+        if isinstance(cross_attention_dim, int):
+            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+        if isinstance(layers_per_block, int):
+            layers_per_block = [layers_per_block] * len(down_block_types)
+
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
         # down
         output_channel = block_out_channels[0]
         for i, down_block_type in enumerate(down_block_types):
@@ -276,21 +309,22 @@ def __init__(

             down_block = get_down_block(
                 down_block_type,
-                num_layers=layers_per_block,
+                num_layers=layers_per_block[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 temb_channels=time_embed_dim,
                 add_downsample=not is_final_block,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[i],
                 num_attention_heads=num_attention_heads[i],
                 downsample_padding=downsample_padding,
                 use_linear_projection=use_linear_projection,
                 dual_cross_attention=False,
                 temporal_num_attention_heads=motion_num_attention_heads,
                 temporal_max_seq_length=motion_max_seq_length,
+                transformer_layers_per_block=transformer_layers_per_block[i],
             )
             self.down_blocks.append(down_block)

@@ -302,13 +336,14 @@ def __init__(
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 output_scale_factor=mid_block_scale_factor,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[-1],
                 num_attention_heads=num_attention_heads[-1],
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=False,
                 use_linear_projection=use_linear_projection,
                 temporal_num_attention_heads=motion_num_attention_heads,
                 temporal_max_seq_length=motion_max_seq_length,
+                transformer_layers_per_block=transformer_layers_per_block[-1],
             )

         else:
@@ -318,11 +353,12 @@ def __init__(
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 output_scale_factor=mid_block_scale_factor,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[-1],
                 num_attention_heads=num_attention_heads[-1],
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=False,
                 use_linear_projection=use_linear_projection,
+                transformer_layers_per_block=transformer_layers_per_block[-1],
             )

         # count how many layers upsample the images
@@ -331,6 +367,9 @@ def __init__(
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         reversed_num_attention_heads = list(reversed(num_attention_heads))
+        reversed_layers_per_block = list(reversed(layers_per_block))
+        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))

         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
@@ -349,7 +388,7 @@ def __init__(

             up_block = get_up_block(
                 up_block_type,
-                num_layers=layers_per_block + 1,
+                num_layers=reversed_layers_per_block[i] + 1,
                 in_channels=input_channel,
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
@@ -358,13 +397,14 @@ def __init__(
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=reversed_cross_attention_dim[i],
                 num_attention_heads=reversed_num_attention_heads[i],
                 dual_cross_attention=False,
                 resolution_idx=i,
                 use_linear_projection=use_linear_projection,
                 temporal_num_attention_heads=motion_num_attention_heads,
                 temporal_max_seq_length=motion_max_seq_length,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -835,6 +875,28 @@ def forward(
             t_emb = t_emb.to(dtype=self.dtype)

         emb = self.time_embedding(t_emb, timestep_cond)
+        aug_emb = None
+
+        if self.config.addition_embed_type == "text_time":
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
+            aug_emb = self.add_embedding(add_embeds)
+
+        emb = emb if aug_emb is None else emb + aug_emb
         emb = emb.repeat_interleave(repeats=num_frames, dim=0)
         encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
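The new `text_time` branch in `forward` mirrors SDXL-style micro-conditioning: pooled text embeddings and six `time_ids` from `added_cond_kwargs` are projected and added onto the time embedding. Below is a minimal, self-contained sketch of that projection; the sizes are illustrative SDXL-like assumptions, not values read from any checkpoint.

```python
# Hedged sketch of the "text_time" additional embedding above, with
# illustrative SDXL-like sizes (assumptions, not values read from a model).
import torch

from diffusers.models.embeddings import TimestepEmbedding, Timesteps

batch_size = 2
addition_time_embed_dim = 256   # sinusoidal width for each of the 6 time_ids
pooled_text_embed_dim = 1280    # pooled prompt embeds (SDXL's second text encoder)
time_embed_dim = 1280           # width of the UNet's time embedding

# projection input = pooled text dim + 6 micro-conditioning ids * sinusoidal width
projection_class_embeddings_input_dim = pooled_text_embed_dim + 6 * addition_time_embed_dim

add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

text_embeds = torch.randn(batch_size, pooled_text_embed_dim)
# (original_h, original_w, crop_top, crop_left, target_h, target_w)
time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * batch_size)

time_embeds = add_time_proj(time_ids.flatten())                # (batch * 6, addition_time_embed_dim)
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))  # (batch, 6 * addition_time_embed_dim)
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)  # (batch, projection input dim)
aug_emb = add_embedding(add_embeds)                            # (batch, time_embed_dim), added to the time embedding
```

At call time, these are the tensors the UNet expects under the `text_embeds` and `time_ids` keys of `added_cond_kwargs`, as the error messages in the diff spell out.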

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -114,6 +114,7 @@
     _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
     _import_structure["animatediff"] = [
         "AnimateDiffPipeline",
+        "AnimateDiffSDXLPipeline",
         "AnimateDiffVideoToVideoPipeline",
     ]
     _import_structure["audioldm"] = ["AudioLDMPipeline"]
@@ -367,7 +368,7 @@
         from ..utils.dummy_torch_and_transformers_objects import *
     else:
         from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
-        from .animatediff import AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline
+        from .animatediff import AnimateDiffPipeline, AnimateDiffSDXLPipeline, AnimateDiffVideoToVideoPipeline
         from .audioldm import AudioLDMPipeline
         from .audioldm2 import (
             AudioLDM2Pipeline,

src/diffusers/pipelines/animatediff/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
     _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
+    _import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
     _import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -33,6 +34,7 @@

     else:
         from .pipeline_animatediff import AnimateDiffPipeline
+        from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
         from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
         from .pipeline_output import AnimateDiffPipelineOutput
