kakaobrain unCLIP #1428
Merged: williamberman merged 43 commits into huggingface:main from williamberman:kakaobrain_unclip on Dec 18, 2022
Changes from all commits
Commits (43):
1b7f73c  [wip] attention block updates (williamberman)
ef52835  [wip] unCLIP unet decoder and super res (williamberman)
3acc487  [wip] unCLIP prior transformer (williamberman)
6087a45  [wip] scheduler changes (williamberman)
2cc5aea  [wip] text proj utility class (williamberman)
06ae666  [wip] UnCLIPPipeline (williamberman)
c911fe6  [wip] kakaobrain unCLIP convert script (williamberman)
f4d35a0  Merge branch 'main' into kakaobrain_unclip (patrickvonplaten)
de7a56f  [unCLIP pipeline] fixes re: @patrickvonplaten (williamberman)
710f648  UNCLIPScheduler re: @patrickvonplaten (williamberman)
2b30fbc  mask -> attention_mask re: @patrickvonplaten (williamberman)
5b20823  [DDPMScheduler] remove leftover change (williamberman)
d8957b2  [docs] PriorTransformer (williamberman)
301c01e  [docs] UNet2DConditionModel and UNet2DModel (williamberman)
e64a9c5  [nit] UNCLIPScheduler -> UnCLIPScheduler (williamberman)
a866ded  [docs] SchedulingUnCLIP (williamberman)
a8207aa  [docs] UnCLIPTextProjModel (williamberman)
344a1e8  refactor (patrickvonplaten)
5428daf  finish licenses (patrickvonplaten)
8c803b4  rename all to attention_mask and prep in models (patrickvonplaten)
6fbddc8  more renaming (patrickvonplaten)
03d2e79  don't expose unused configs (patrickvonplaten)
7121611  final renaming fixes (patrickvonplaten)
a5e5d2a  remove x attn mask when not necessary (patrickvonplaten)
669118d  configure kakao script to use new class embedding config (williamberman)
f79f8f2  fix copies (williamberman)
226fca4  [tests] UnCLIPScheduler (williamberman)
b3c9b26  finish x attn (patrickvonplaten)
6124048  Merge branch 'kakaobrain_unclip' of https://github.com/williamberman/… (patrickvonplaten)
2853c1f  finish (patrickvonplaten)
1c7fae6  remove more (patrickvonplaten)
5d7f3d4  rename condition blocks (patrickvonplaten)
0080668  clean more (patrickvonplaten)
ecaf203  Apply suggestions from code review (patrickvonplaten)
285520b  up (patrickvonplaten)
9d162d9  Merge branch 'main' of https://github.com/huggingface/diffusers into … (patrickvonplaten)
cbc2ba0  fix (patrickvonplaten)
8cb29fc  [tests] UnCLIPPipelineFastTests (williamberman)
231d4e6  remove unused imports (williamberman)
168b492  [tests] UnCLIPPipelineIntegrationTests (williamberman)
ec64638  correct (patrickvonplaten)
90c383d  Merge branch 'kakaobrain_unclip' of https://github.com/williamberman/… (patrickvonplaten)
51f3af3  make style (patrickvonplaten)
scripts/convert_kakao_brain_unclip_to_diffusers.py: 1,159 additions, 0 deletions (large diff not rendered by default)
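Once the original kakaobrain checkpoints have been converted by this script, the weights can be loaded through the UnCLIPPipeline this PR introduces. The snippet below is a minimal usage sketch, not part of the diff; the hub id "kakaobrain/karlo-v1-alpha" and the fp16/CUDA settings are assumptions for illustration.

import torch
from diffusers import UnCLIPPipeline

# Load a converted unCLIP (Karlo) checkpoint. The model id is an assumption;
# point this at the output of the conversion script or the published checkpoint.
pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Text-to-image: the prior predicts a CLIP image embedding, then the decoder
# and super-resolution UNets turn it into an image.
image = pipe("a photo of a red panda playing a guitar").images[0]
image.save("red_panda.png")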
@@ -0,0 +1,194 @@
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .embeddings import TimestepEmbedding, Timesteps


@dataclass
class PriorTransformerOutput(BaseOutput):
    """
    Args:
        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
            The predicted CLIP image embedding conditioned on the CLIP text embedding input.
    """

    predicted_image_embedding: torch.FloatTensor


class PriorTransformer(ModelMixin, ConfigMixin):
    """
    The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the
    transformer predicts the image embeddings through a denoising diffusion process.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the models (such as downloading or saving, etc.)

    For more details, see the original paper: https://arxiv.org/abs/2204.06125

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP
            image embeddings and text embeddings are both the same dimension.
        num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the
            length of the prompt after it has been tokenized.
        additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
            projected hidden_states. The actual length of the used hidden_states is `num_embeddings +
            additional_embeddings`.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.

    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 32,
        attention_head_dim: int = 64,
        num_layers: int = 20,
        embedding_dim: int = 768,
        num_embeddings=77,
        additional_embeddings=4,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.additional_embeddings = additional_embeddings

        self.time_proj = Timesteps(inner_dim, True, 0)
        self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)

        self.proj_in = nn.Linear(embedding_dim, inner_dim)

        self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
        self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)

        self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))

        self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    activation_fn="gelu",
                    attention_bias=True,
                )
                for d in range(num_layers)
            ]
        )

        self.norm_out = nn.LayerNorm(inner_dim)
        self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)

        causal_attention_mask = torch.full(
            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf")
        )
        causal_attention_mask.triu_(1)
        causal_attention_mask = causal_attention_mask[None, ...]
        self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)

        self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
        self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))

    def forward(
        self,
        hidden_states,
        timestep: Union[torch.Tensor, float, int],
        proj_embedding: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                x_t, the currently predicted image embeddings.
            timestep (`torch.long`):
                Current denoising step.
            proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                Projected embedding vector the denoising process is conditioned on.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
                Hidden states of the text embeddings the denoising process is conditioned on.
            attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                Text mask for the text embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
                [`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When
                returning a tuple, the first element is the sample tensor.
        """
        batch_size = hidden_states.shape[0]

        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(hidden_states.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)

        timesteps_projected = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might be fp16, so we need to cast here.
        timesteps_projected = timesteps_projected.to(dtype=self.dtype)
        time_embeddings = self.time_embedding(timesteps_projected)

        proj_embeddings = self.embedding_proj(proj_embedding)
        encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
        hidden_states = self.proj_in(hidden_states)
        prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
        positional_embeddings = self.positional_embedding.to(hidden_states.dtype)

        hidden_states = torch.cat(
            [
                encoder_hidden_states,
                proj_embeddings[:, None, :],
                time_embeddings[:, None, :],
                hidden_states[:, None, :],
                prd_embedding,
            ],
            dim=1,
        )

        hidden_states = hidden_states + positional_embeddings

        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
            attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask)

        hidden_states = self.norm_out(hidden_states)
        hidden_states = hidden_states[:, -1]
        predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)

        if not return_dict:
            return (predicted_image_embedding,)

        return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)

    def post_process_latents(self, prior_latents):
        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
        return prior_latents
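For orientation, here is a minimal sketch (not part of the diff) that exercises the new prior_transformer module directly with its default config. The random tensors stand in for real CLIP text embeddings, and the top-level import is assumed to be available once the PR's __init__ exports are in place.

import torch
from diffusers import PriorTransformer  # assumed to be exported by this PR

# Default config: 32 heads x 64 dims (inner_dim 2048), 20 layers, embedding_dim 768, 77 text tokens.
prior = PriorTransformer()

batch_size, embedding_dim, num_embeddings = 2, 768, 77
hidden_states = torch.randn(batch_size, embedding_dim)  # x_t, the noisy CLIP image embedding
timestep = torch.tensor(10)  # current denoising step
proj_embedding = torch.randn(batch_size, embedding_dim)  # pooled CLIP text embedding (stand-in)
encoder_hidden_states = torch.randn(batch_size, num_embeddings, embedding_dim)  # per-token text states (stand-in)
attention_mask = torch.ones(batch_size, num_embeddings, dtype=torch.bool)  # no padding in this toy example

out = prior(
    hidden_states,
    timestep,
    proj_embedding=proj_embedding,
    encoder_hidden_states=encoder_hidden_states,
    attention_mask=attention_mask,
)
print(out.predicted_image_embedding.shape)  # torch.Size([2, 768])

# In the pipeline, the final prior latents are un-normalized with the stored
# CLIP statistics before being passed to the decoder:
image_embeddings = prior.post_process_latents(out.predicted_image_embedding)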