
Commit 0d1d267

a-r-r-o-w, hyang0511, yiyixuxu, and stevhliu authored

[core] Allegro T2V (#9736)
* update
* refactor transformer part 1
* refactor part 2
* refactor part 3
* make style
* refactor part 4; modeling tests
* make style
* refactor part 5
* refactor part 6
* gradient checkpointing
* pipeline tests (broken atm)
* update
* add coauthor

  Co-Authored-By: Huan Yang <[email protected]>
* refactor part 7
* add docs
* make style
* add coauthor

  Co-Authored-By: YiYi Xu <[email protected]>
* make fix-copies
* undo unrelated change
* revert changes to embeddings, normalization, transformer
* refactor part 8
* make style
* refactor part 9
* make style
* fix
* apply suggestions from review
* Apply suggestions from code review

  Co-authored-by: Steven Liu <[email protected]>
* update example
* remove attention mask for self-attention
* update
* copied from
* update
* update

---------

Co-authored-by: Huan Yang <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
1 parent c5376c5 commit 0d1d267

23 files changed: +3300 −5 lines

docs/source/en/_toctree.yml

Lines changed: 6 additions & 0 deletions

```diff
@@ -252,6 +252,8 @@
         title: SparseControlNetModel
       title: ControlNets
   - sections:
+      - local: api/models/allegro_transformer3d
+        title: AllegroTransformer3DModel
       - local: api/models/aura_flow_transformer2d
         title: AuraFlowTransformer2DModel
       - local: api/models/cogvideox_transformer3d
@@ -300,6 +302,8 @@
   - sections:
       - local: api/models/autoencoderkl
        title: AutoencoderKL
+      - local: api/models/autoencoderkl_allegro
+        title: AutoencoderKLAllegro
       - local: api/models/autoencoderkl_cogvideox
        title: AutoencoderKLCogVideoX
       - local: api/models/asymmetricautoencoderkl
@@ -318,6 +322,8 @@
   sections:
   - local: api/pipelines/overview
     title: Overview
+  - local: api/pipelines/allegro
+    title: Allegro
   - local: api/pipelines/amused
     title: aMUSEd
   - local: api/pipelines/animatediff
```
docs/source/en/api/models/allegro_transformer3d.md

Lines changed: 30 additions & 0 deletions

<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AllegroTransformer3DModel

A Diffusion Transformer model for 3D data from [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.

The model can be loaded with the following code snippet.

```python
import torch
from diffusers import AllegroTransformer3DModel

transformer = AllegroTransformer3DModel.from_pretrained("rhymes-ai/Allegro", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```

## AllegroTransformer3DModel

[[autodoc]] AllegroTransformer3DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
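For orientation, a component loaded this way can be dropped into the text-to-video pipeline added later in this commit. The sketch below is illustrative and not part of the diff; it assumes `AllegroPipeline` accepts a `transformer` override the way other diffusers pipelines do.

```python
# Hedged sketch: override the pipeline's transformer with an explicitly
# loaded copy. The component-override keyword is the standard diffusers
# pattern; its use with AllegroPipeline here is an assumption.
import torch
from diffusers import AllegroPipeline, AllegroTransformer3DModel

transformer = AllegroTransformer3DModel.from_pretrained(
    "rhymes-ai/Allegro", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = AllegroPipeline.from_pretrained(
    "rhymes-ai/Allegro", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
```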
docs/source/en/api/models/autoencoderkl_allegro.md

Lines changed: 37 additions & 0 deletions

<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLAllegro

The 3D variational autoencoder (VAE) model with KL loss used in [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.

The model can be loaded with the following code snippet.

```python
import torch
from diffusers import AutoencoderKLAllegro

vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```

## AutoencoderKLAllegro

[[autodoc]] AutoencoderKLAllegro
    - decode
    - encode
    - all

## AutoencoderKLOutput

[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
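As a usage sketch, the `encode` and `decode` entry points documented above follow the usual diffusers VAE round trip. Everything below — the `(batch, channels, frames, height, width)` layout, `enable_tiling`, and the `latent_dist`/`sample` accessors — is the convention of other diffusers video VAEs, assumed here rather than guaranteed by this commit.

```python
import torch
from diffusers import AutoencoderKLAllegro

vae = AutoencoderKLAllegro.from_pretrained(
    "rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32
).to("cuda")
vae.enable_tiling()  # assumed tiling helper; keeps memory manageable for video

video = torch.randn(1, 3, 8, 64, 64, device="cuda")  # toy (B, C, F, H, W) input
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()  # sample the posterior
    reconstruction = vae.decode(latents).sample       # back to pixel space
```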
docs/source/en/api/pipelines/allegro.md

Lines changed: 34 additions & 0 deletions

<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# Allegro

[Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) from RhymesAI, by Yuan Zhou, Qiuyue Wang, Yuxuan Cai, and Huan Yang.

The abstract from the paper is:

*Significant advancements have been made in the field of video generation, with the open-source community contributing a wealth of research papers and tools for training high-quality models. However, despite these efforts, the available information and resources remain insufficient for achieving commercial-level performance. In this report, we open the black box and introduce Allegro, an advanced video generation model that excels in both quality and temporal consistency. We also highlight the current limitations in the field and present a comprehensive methodology for training high-performance, commercial-level video generation models, addressing key aspects such as data, model architecture, training pipeline, and evaluation. Our user study shows that Allegro surpasses existing open-source models and most commercial models, ranking just behind Hailuo and Kling. Code: https://github.com/rhymes-ai/Allegro , Model: https://huggingface.co/rhymes-ai/Allegro , Gallery: https://rhymes.ai/allegro_gallery .*

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

## AllegroPipeline

[[autodoc]] AllegroPipeline
    - all
    - __call__

## AllegroPipelineOutput

[[autodoc]] pipelines.allegro.pipeline_output.AllegroPipelineOutput
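The page above stops at the autodoc stubs, so here is a hedged end-to-end sketch. It assumes `AllegroPipeline` follows the standard diffusers text-to-video calling convention (a `.frames` output plus the `export_to_video` utility); the prompt and sampler settings are illustrative only.

```python
import torch
from diffusers import AllegroPipeline
from diffusers.utils import export_to_video

pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A seaside harbor at dusk, fishing boats bobbing on gentle waves."
video = pipe(prompt, guidance_scale=7.5, num_inference_steps=100).frames[0]
export_to_video(video, "output.mp4", fps=15)
```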

src/diffusers/__init__.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -77,9 +77,11 @@
 else:
     _import_structure["models"].extend(
         [
+            "AllegroTransformer3DModel",
             "AsymmetricAutoencoderKL",
             "AuraFlowTransformer2DModel",
             "AutoencoderKL",
+            "AutoencoderKLAllegro",
             "AutoencoderKLCogVideoX",
             "AutoencoderKLTemporalDecoder",
             "AutoencoderOobleck",
@@ -237,6 +239,7 @@
 else:
     _import_structure["pipelines"].extend(
         [
+            "AllegroPipeline",
             "AltDiffusionImg2ImgPipeline",
             "AltDiffusionPipeline",
             "AmusedImg2ImgPipeline",
@@ -556,9 +559,11 @@
     from .utils.dummy_pt_objects import *  # noqa F403
 else:
     from .models import (
+        AllegroTransformer3DModel,
         AsymmetricAutoencoderKL,
         AuraFlowTransformer2DModel,
         AutoencoderKL,
+        AutoencoderKLAllegro,
         AutoencoderKLCogVideoX,
         AutoencoderKLTemporalDecoder,
         AutoencoderOobleck,
@@ -697,6 +702,7 @@
     from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
 else:
     from .pipelines import (
+        AllegroPipeline,
         AltDiffusionImg2ImgPipeline,
         AltDiffusionPipeline,
         AmusedImg2ImgPipeline,
```
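These registrations touch both the `_import_structure` tables (consumed by diffusers' lazy-module machinery) and the eager import branches, so the new names resolve from the package root either way. A quick check, assuming a build with the PyTorch backend available:

```python
# The new public names are importable from the top level after this change;
# the _import_structure path defers the heavy module loads until first access.
from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro
```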

src/diffusers/models/__init__.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -28,6 +28,7 @@
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
     _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
     _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
     _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
@@ -54,6 +55,7 @@
     _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+    _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
     _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
@@ -81,6 +83,7 @@
     from .autoencoders import (
         AsymmetricAutoencoderKL,
         AutoencoderKL,
+        AutoencoderKLAllegro,
         AutoencoderKLCogVideoX,
         AutoencoderKLTemporalDecoder,
         AutoencoderOobleck,
@@ -97,6 +100,7 @@
     from .embeddings import ImageProjection
     from .modeling_utils import ModelMixin
     from .transformers import (
+        AllegroTransformer3DModel,
         AuraFlowTransformer2DModel,
         CogVideoXTransformer3DModel,
         CogView3PlusTransformer2DModel,
```

src/diffusers/models/attention_processor.py

Lines changed: 94 additions & 0 deletions

```diff
@@ -1521,6 +1521,100 @@ def __call__(
         return hidden_states, encoder_hidden_states


+class AllegroAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vector.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None and not attn.is_cross_attention:
+            from .embeddings import apply_rotary_emb_allegro
+
+            query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1])
+            key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1])
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class AuraFlowAttnProcessor2_0:
     """Attention processor used typically in processing Aura Flow."""
```
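The docstring above summarizes the processor's job; the fragment below is a standalone sketch of its core path (multi-head layout, rotary embedding on query/key, then PyTorch 2.0's fused attention). The rotate-half RoPE form is the common convention and stands in for `apply_rotary_emb_allegro`, whose exact layout this diff does not show.

```python
import torch
import torch.nn.functional as F

# Toy shapes: (batch, heads, seq_len, head_dim), matching the layout the
# processor builds before calling scaled_dot_product_attention.
batch, heads, seq_len, head_dim = 2, 4, 16, 32
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)

# Stand-in rotary tables; the real ones are position-dependent sinusoids.
cos = torch.randn(seq_len, head_dim)
sin = torch.randn(seq_len, head_dim)

# Common RoPE form, used here as a proxy for apply_rotary_emb_allegro.
query = query * cos + rotate_half(query) * sin
key = key * cos + rotate_half(key) * sin

out = F.scaled_dot_product_attention(query, key, value)  # (B, H, S, D)
out = out.transpose(1, 2).reshape(batch, seq_len, heads * head_dim)  # merge heads
```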

src/diffusers/models/autoencoders/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -1,5 +1,6 @@
 from .autoencoder_asym_kl import AsymmetricAutoencoderKL
 from .autoencoder_kl import AutoencoderKL
+from .autoencoder_kl_allegro import AutoencoderKLAllegro
 from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
 from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
 from .autoencoder_oobleck import AutoencoderOobleck
```
