|
| 1 | +from dataclasses import dataclass |
| 2 | +from typing import Optional, Union |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn.functional as F |
| 6 | +from torch import nn |
| 7 | + |
| 8 | +from ..configuration_utils import ConfigMixin, register_to_config |
| 9 | +from ..modeling_utils import ModelMixin |
| 10 | +from ..utils import BaseOutput |
| 11 | +from .attention import BasicTransformerBlock |
| 12 | +from .embeddings import TimestepEmbedding, Timesteps |
| 13 | + |
| 14 | + |
| 15 | +@dataclass |
| 16 | +class PriorTransformerOutput(BaseOutput): |
| 17 | + """ |
| 18 | + Args: |
| 19 | + predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): |
| 20 | + The predicted CLIP image embedding conditioned on the CLIP text embedding input. |
| 21 | + """ |
| 22 | + |
| 23 | + predicted_image_embedding: torch.FloatTensor |
| 24 | + |
| 25 | + |
| 26 | +class PriorTransformer(ModelMixin, ConfigMixin): |
| 27 | + """ |
| 28 | + The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the |
| 29 | + transformer predicts the image embeddings through a denoising diffusion process. |
| 30 | +
|
| 31 | + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library |
| 32 | + implements for all the models (such as downloading or saving, etc.) |
| 33 | +
|
| 34 | + For more details, see the original paper: https://arxiv.org/abs/2204.06125 |
| 35 | +
|
| 36 | + Parameters: |
| 37 | + num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. |
| 38 | + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. |
| 39 | + num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. |
| 40 | + embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP |
| 41 | + image embeddings and text embeddings are both the same dimension. |
| 42 | + num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the |
| 43 | + length of the prompt after it has been tokenized. |
| 44 | + additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the |
| 45 | + projected hidden_states. The actual length of the used hidden_states is `num_embeddings + |
| 46 | + additional_embeddings`. |
| 47 | + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. |
| 48 | +
|
| 49 | + """ |
| 50 | + |
| 51 | + @register_to_config |
| 52 | + def __init__( |
| 53 | + self, |
| 54 | + num_attention_heads: int = 32, |
| 55 | + attention_head_dim: int = 64, |
| 56 | + num_layers: int = 20, |
| 57 | + embedding_dim: int = 768, |
| 58 | + num_embeddings=77, |
| 59 | + additional_embeddings=4, |
| 60 | + dropout: float = 0.0, |
| 61 | + ): |
| 62 | + super().__init__() |
| 63 | + self.num_attention_heads = num_attention_heads |
| 64 | + self.attention_head_dim = attention_head_dim |
| 65 | + inner_dim = num_attention_heads * attention_head_dim |
| 66 | + self.additional_embeddings = additional_embeddings |
| 67 | + |
| 68 | + self.time_proj = Timesteps(inner_dim, True, 0) |
| 69 | + self.time_embedding = TimestepEmbedding(inner_dim, inner_dim) |
| 70 | + |
| 71 | + self.proj_in = nn.Linear(embedding_dim, inner_dim) |
| 72 | + |
| 73 | + self.embedding_proj = nn.Linear(embedding_dim, inner_dim) |
| 74 | + self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) |
| 75 | + |
| 76 | + self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim)) |
| 77 | + |
| 78 | + self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) |
| 79 | + |
| 80 | + self.transformer_blocks = nn.ModuleList( |
| 81 | + [ |
| 82 | + BasicTransformerBlock( |
| 83 | + inner_dim, |
| 84 | + num_attention_heads, |
| 85 | + attention_head_dim, |
| 86 | + dropout=dropout, |
| 87 | + activation_fn="gelu", |
| 88 | + attention_bias=True, |
| 89 | + ) |
| 90 | + for d in range(num_layers) |
| 91 | + ] |
| 92 | + ) |
| 93 | + |
| 94 | + self.norm_out = nn.LayerNorm(inner_dim) |
| 95 | + self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim) |
| 96 | + |
| 97 | + causal_attention_mask = torch.full( |
| 98 | + [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf") |
| 99 | + ) |
| 100 | + causal_attention_mask.triu_(1) |
| 101 | + causal_attention_mask = causal_attention_mask[None, ...] |
| 102 | + self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False) |
| 103 | + |
| 104 | + self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim)) |
| 105 | + self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim)) |
| 106 | + |
| 107 | + def forward( |
| 108 | + self, |
| 109 | + hidden_states, |
| 110 | + timestep: Union[torch.Tensor, float, int], |
| 111 | + proj_embedding: torch.FloatTensor, |
| 112 | + encoder_hidden_states: torch.FloatTensor, |
| 113 | + attention_mask: Optional[torch.BoolTensor] = None, |
| 114 | + return_dict: bool = True, |
| 115 | + ): |
| 116 | + """ |
| 117 | + Args: |
| 118 | + hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): |
| 119 | + x_t, the currently predicted image embeddings. |
| 120 | + timestep (`torch.long`): |
| 121 | + Current denoising step. |
| 122 | + proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): |
| 123 | + Projected embedding vector the denoising process is conditioned on. |
| 124 | + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`): |
| 125 | + Hidden states of the text embeddings the denoising process is conditioned on. |
| 126 | + attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): |
| 127 | + Text mask for the text embeddings. |
| 128 | + return_dict (`bool`, *optional*, defaults to `True`): |
| 129 | + Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain |
| 130 | + tuple. |
| 131 | +
|
| 132 | + Returns: |
| 133 | + [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`: |
| 134 | + [`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When |
| 135 | + returning a tuple, the first element is the sample tensor. |
| 136 | + """ |
| 137 | + batch_size = hidden_states.shape[0] |
| 138 | + |
| 139 | + timesteps = timestep |
| 140 | + if not torch.is_tensor(timesteps): |
| 141 | + timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device) |
| 142 | + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: |
| 143 | + timesteps = timesteps[None].to(hidden_states.device) |
| 144 | + |
| 145 | + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML |
| 146 | + timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device) |
| 147 | + |
| 148 | + timesteps_projected = self.time_proj(timesteps) |
| 149 | + |
| 150 | + # timesteps does not contain any weights and will always return f32 tensors |
| 151 | + # but time_embedding might be fp16, so we need to cast here. |
| 152 | + timesteps_projected = timesteps_projected.to(dtype=self.dtype) |
| 153 | + time_embeddings = self.time_embedding(timesteps_projected) |
| 154 | + |
| 155 | + proj_embeddings = self.embedding_proj(proj_embedding) |
| 156 | + encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) |
| 157 | + hidden_states = self.proj_in(hidden_states) |
| 158 | + prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1) |
| 159 | + positional_embeddings = self.positional_embedding.to(hidden_states.dtype) |
| 160 | + |
| 161 | + hidden_states = torch.cat( |
| 162 | + [ |
| 163 | + encoder_hidden_states, |
| 164 | + proj_embeddings[:, None, :], |
| 165 | + time_embeddings[:, None, :], |
| 166 | + hidden_states[:, None, :], |
| 167 | + prd_embedding, |
| 168 | + ], |
| 169 | + dim=1, |
| 170 | + ) |
| 171 | + |
| 172 | + hidden_states = hidden_states + positional_embeddings |
| 173 | + |
| 174 | + if attention_mask is not None: |
| 175 | + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 |
| 176 | + attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0) |
| 177 | + attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) |
| 178 | + attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0) |
| 179 | + |
| 180 | + for block in self.transformer_blocks: |
| 181 | + hidden_states = block(hidden_states, attention_mask=attention_mask) |
| 182 | + |
| 183 | + hidden_states = self.norm_out(hidden_states) |
| 184 | + hidden_states = hidden_states[:, -1] |
| 185 | + predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) |
| 186 | + |
| 187 | + if not return_dict: |
| 188 | + return (predicted_image_embedding,) |
| 189 | + |
| 190 | + return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding) |
| 191 | + |
| 192 | + def post_process_latents(self, prior_latents): |
| 193 | + prior_latents = (prior_latents * self.clip_std) + self.clip_mean |
| 194 | + return prior_latents |
0 commit comments