
Commit 09b7bfc

[Core] move transformer scripts to transformers modules (#6747)
* move transformer scripts to transformers modules
* move transformer model test
* move prior transformer test to directory
* fix doc path
* correct doc path
* add: __init__.py
1 parent 5d8b198 commit 09b7bfc

28 files changed, +1925 −1754 lines changed
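The practical effect for downstream code is a change of submodule path. A minimal sketch of the before/after imports (assuming a diffusers release that includes this commit):

```python
# Before this commit, the transformer modules lived directly under diffusers.models:
# from diffusers.models.prior_transformer import PriorTransformer

# After this commit they live under diffusers.models.transformers:
from diffusers.models.transformers.prior_transformer import PriorTransformer

# The top-level re-export is untouched, so this continues to work in both cases:
from diffusers.models import PriorTransformer
```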

docs/source/en/api/models/prior_transformer.md

Lines changed: 1 addition & 1 deletion
@@ -24,4 +24,4 @@ The abstract from the paper is:

 ## PriorTransformerOutput

-[[autodoc]] models.prior_transformer.PriorTransformerOutput
+[[autodoc]] models.transformers.prior_transformer.PriorTransformerOutput

docs/source/en/api/models/transformer2d.md

Lines changed: 1 addition & 1 deletion
@@ -38,4 +38,4 @@ It is assumed one of the input classes is the masked latent pixel. The predicted

 ## Transformer2DModelOutput

-[[autodoc]] models.transformer_2d.Transformer2DModelOutput
+[[autodoc]] models.transformers.transformer_2d.Transformer2DModelOutput

docs/source/en/api/models/transformer_temporal.md

Lines changed: 2 additions & 2 deletions
@@ -16,8 +16,8 @@ A Transformer model for video-like data.

 ## TransformerTemporalModel

-[[autodoc]] models.transformer_temporal.TransformerTemporalModel
+[[autodoc]] models.transformers.transformer_temporal.TransformerTemporalModel

 ## TransformerTemporalModelOutput

-[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput
+[[autodoc]] models.transformers.transformer_temporal.TransformerTemporalModelOutput
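Each `[[autodoc]]` path above is resolved relative to the `diffusers` package, so the updated entries should point at importable objects. A small sanity-check sketch (assuming a diffusers install that contains this commit):

```python
import importlib

# Dotted paths mirror the updated [[autodoc]] directives, rooted at `diffusers`.
for dotted in (
    "diffusers.models.transformers.prior_transformer.PriorTransformerOutput",
    "diffusers.models.transformers.transformer_2d.Transformer2DModelOutput",
    "diffusers.models.transformers.transformer_temporal.TransformerTemporalModel",
):
    module_path, attr = dotted.rsplit(".", 1)
    obj = getattr(importlib.import_module(module_path), attr)
    print(f"{dotted} -> {obj}")
```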

scripts/convert_kakao_brain_unclip_to_diffusers.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from transformers import CLIPTextModelWithProjection, CLIPTokenizer

 from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel
-from diffusers.models.prior_transformer import PriorTransformer
+from diffusers.models.transformers.prior_transformer import PriorTransformer
 from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
 from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler
scripts/convert_kandinsky_to_diffusers.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from accelerate import load_checkpoint_and_dispatch

 from diffusers import UNet2DConditionModel
-from diffusers.models.prior_transformer import PriorTransformer
+from diffusers.models.transformers.prior_transformer import PriorTransformer
 from diffusers.models.vq_model import VQModel
scripts/convert_shap_e_to_diffusers.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 import torch
 from accelerate import load_checkpoint_and_dispatch

-from diffusers.models.prior_transformer import PriorTransformer
+from diffusers.models.transformers.prior_transformer import PriorTransformer
 from diffusers.pipelines.shap_e import ShapERenderer
src/diffusers/models/__init__.py

Lines changed: 11 additions & 9 deletions
@@ -35,10 +35,10 @@
 _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
 _import_structure["embeddings"] = ["ImageProjection"]
 _import_structure["modeling_utils"] = ["ModelMixin"]
-_import_structure["prior_transformer"] = ["PriorTransformer"]
-_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
-_import_structure["transformer_2d"] = ["Transformer2DModel"]
-_import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
+_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
+_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
+_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
 _import_structure["unets.unet_1d"] = ["UNet1DModel"]
 _import_structure["unets.unet_2d"] = ["UNet2DModel"]
 _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -66,13 +66,15 @@
     ConsistencyDecoderVAE,
 )
 from .controlnet import ControlNetModel
-from .dual_transformer_2d import DualTransformer2DModel
 from .embeddings import ImageProjection
 from .modeling_utils import ModelMixin
-from .prior_transformer import PriorTransformer
-from .t5_film_transformer import T5FilmDecoder
-from .transformer_2d import Transformer2DModel
-from .transformer_temporal import TransformerTemporalModel
+from .transformers import (
+    DualTransformer2DModel,
+    PriorTransformer,
+    T5FilmDecoder,
+    Transformer2DModel,
+    TransformerTemporalModel,
+)
 from .unets import (
     Kandinsky3UNet,
     MotionAdapter,
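The keys in `_import_structure` now carry the `transformers.` prefix, which tells the package's lazy-import machinery which submodule to load when an attribute such as `PriorTransformer` is first accessed. A minimal sketch of that idea (not the actual diffusers `_LazyModule` implementation):

```python
import importlib

# Submodule path (relative to the package) -> names it exports.
_import_structure = {
    "transformers.prior_transformer": ["PriorTransformer"],
    "transformers.transformer_2d": ["Transformer2DModel"],
}

def lazy_resolve(package: str, name: str):
    """Import the submodule that exports `name` and return the attribute."""
    for submodule, exported in _import_structure.items():
        if name in exported:
            module = importlib.import_module(f"{package}.{submodule}")
            return getattr(module, name)
    raise AttributeError(f"{package} has no attribute {name}")

# e.g. lazy_resolve("diffusers.models", "PriorTransformer")
```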

src/diffusers/models/dual_transformer_2d.py

Lines changed: 5 additions & 140 deletions
@@ -11,145 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from ..utils import deprecate
+from .transformers.dual_transformer_2d import DualTransformer2DModel

-from torch import nn

-from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
-
-
-class DualTransformer2DModel(nn.Module):
-    """
-    Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
-
-    Parameters:
-        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
-        in_channels (`int`, *optional*):
-            Pass if the input is continuous. The number of channels in the input and output.
-        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
-        dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
-        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
-        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
-            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
-            `ImagePositionalEmbeddings`.
-        num_vector_embeds (`int`, *optional*):
-            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
-            Includes the class for the masked latent pixel.
-        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
-        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
-            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
-            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
-            up to but not more than steps than `num_embeds_ada_norm`.
-        attention_bias (`bool`, *optional*):
-            Configure if the TransformerBlocks' attention should contain a bias parameter.
-    """
-
-    def __init__(
-        self,
-        num_attention_heads: int = 16,
-        attention_head_dim: int = 88,
-        in_channels: Optional[int] = None,
-        num_layers: int = 1,
-        dropout: float = 0.0,
-        norm_num_groups: int = 32,
-        cross_attention_dim: Optional[int] = None,
-        attention_bias: bool = False,
-        sample_size: Optional[int] = None,
-        num_vector_embeds: Optional[int] = None,
-        activation_fn: str = "geglu",
-        num_embeds_ada_norm: Optional[int] = None,
-    ):
-        super().__init__()
-        self.transformers = nn.ModuleList(
-            [
-                Transformer2DModel(
-                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=attention_head_dim,
-                    in_channels=in_channels,
-                    num_layers=num_layers,
-                    dropout=dropout,
-                    norm_num_groups=norm_num_groups,
-                    cross_attention_dim=cross_attention_dim,
-                    attention_bias=attention_bias,
-                    sample_size=sample_size,
-                    num_vector_embeds=num_vector_embeds,
-                    activation_fn=activation_fn,
-                    num_embeds_ada_norm=num_embeds_ada_norm,
-                )
-                for _ in range(2)
-            ]
-        )
-
-        # Variables that can be set by a pipeline:
-
-        # The ratio of transformer1 to transformer2's output states to be combined during inference
-        self.mix_ratio = 0.5
-
-        # The shape of `encoder_hidden_states` is expected to be
-        # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
-        self.condition_lengths = [77, 257]
-
-        # Which transformer to use to encode which condition.
-        # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
-        self.transformer_index_for_condition = [1, 0]
-
-    def forward(
-        self,
-        hidden_states,
-        encoder_hidden_states,
-        timestep=None,
-        attention_mask=None,
-        cross_attention_kwargs=None,
-        return_dict: bool = True,
-    ):
-        """
-        Args:
-            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
-                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
-                hidden_states.
-            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
-                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
-                self-attention.
-            timestep ( `torch.long`, *optional*):
-                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
-            attention_mask (`torch.FloatTensor`, *optional*):
-                Optional attention mask to be applied in Attention.
-            cross_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
-
-        Returns:
-            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
-            [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
-            returning a tuple, the first element is the sample tensor.
-        """
-        input_states = hidden_states
-
-        encoded_states = []
-        tokens_start = 0
-        # attention_mask is not used yet
-        for i in range(2):
-            # for each of the two transformers, pass the corresponding condition tokens
-            condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
-            transformer_index = self.transformer_index_for_condition[i]
-            encoded_state = self.transformers[transformer_index](
-                input_states,
-                encoder_hidden_states=condition_state,
-                timestep=timestep,
-                cross_attention_kwargs=cross_attention_kwargs,
-                return_dict=False,
-            )[0]
-            encoded_states.append(encoded_state - input_states)
-            tokens_start += self.condition_lengths[i]
-
-        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
-        output_states = output_states + input_states
-
-        if not return_dict:
-            return (output_states,)
-
-        return Transformer2DModelOutput(sample=output_states)
+class DualTransformer2DModel(DualTransformer2DModel):
+    deprecation_message = "Importing `DualTransformer2DModel` from `diffusers.models.dual_transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel`, instead."
+    deprecate("DualTransformer2DModel", "0.29", deprecation_message)
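For callers, the shim above means the old module path keeps working for now but is slated for removal; a brief sketch of both paths (the warning behaviour comes from diffusers' `deprecate` helper shown in the diff):

```python
# Deprecated path: still resolves, but importing the module triggers the
# deprecation notice defined above (removal is scheduled for 0.29).
from diffusers.models.dual_transformer_2d import DualTransformer2DModel  # noqa: F401

# New canonical path introduced by this commit:
from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel
```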
