kakaobrain unCLIP #1428

Merged: 43 commits, Dec 18, 2022

Commits
1b7f73c
[wip] attention block updates
williamberman Dec 14, 2022
ef52835
[wip] unCLIP unet decoder and super res
williamberman Dec 14, 2022
3acc487
[wip] unCLIP prior transformer
williamberman Dec 16, 2022
6087a45
[wip] scheduler changes
williamberman Dec 12, 2022
2cc5aea
[wip] text proj utility class
williamberman Dec 16, 2022
06ae666
[wip] UnCLIPPipeline
williamberman Nov 25, 2022
c911fe6
[wip] kakaobrain unCLIP convert script
williamberman Nov 25, 2022
f4d35a0
Merge branch 'main' into kakaobrain_unclip
patrickvonplaten Dec 17, 2022
de7a56f
[unCLIP pipeline] fixes re: @patrickvonplaten
williamberman Dec 17, 2022
710f648
UNCLIPScheduler re: @patrickvonplaten
williamberman Dec 17, 2022
2b30fbc
mask -> attention_mask re: @patrickvonplaten
williamberman Dec 17, 2022
5b20823
[DDPMScheduler] remove leftover change
williamberman Dec 17, 2022
d8957b2
[docs] PriorTransformer
williamberman Dec 18, 2022
301c01e
[docs] UNet2DConditionModel and UNet2DModel
williamberman Dec 18, 2022
e64a9c5
[nit] UNCLIPScheduler -> UnCLIPScheduler
williamberman Dec 18, 2022
a866ded
[docs] SchedulingUnCLIP
williamberman Dec 18, 2022
a8207aa
[docs] UnCLIPTextProjModel
williamberman Dec 18, 2022
344a1e8
refactor
patrickvonplaten Dec 18, 2022
5428daf
finish licenses
patrickvonplaten Dec 18, 2022
8c803b4
rename all to attention_mask and prep in models
patrickvonplaten Dec 18, 2022
6fbddc8
more renaming
patrickvonplaten Dec 18, 2022
03d2e79
don't expose unused configs
patrickvonplaten Dec 18, 2022
7121611
final renaming fixes
patrickvonplaten Dec 18, 2022
a5e5d2a
remove x attn mask when not necessary
patrickvonplaten Dec 18, 2022
669118d
configure kakao script to use new class embedding config
williamberman Dec 18, 2022
f79f8f2
fix copies
williamberman Dec 18, 2022
226fca4
[tests] UnCLIPScheduler
williamberman Dec 18, 2022
b3c9b26
finish x attn
patrickvonplaten Dec 18, 2022
6124048
Merge branch 'kakaobrain_unclip' of https://github.com/williamberman/…
patrickvonplaten Dec 18, 2022
2853c1f
finish
patrickvonplaten Dec 18, 2022
1c7fae6
remove more
patrickvonplaten Dec 18, 2022
5d7f3d4
rename condition blocks
patrickvonplaten Dec 18, 2022
0080668
clean more
patrickvonplaten Dec 18, 2022
ecaf203
Apply suggestions from code review
patrickvonplaten Dec 18, 2022
285520b
up
patrickvonplaten Dec 18, 2022
9d162d9
Merge branch 'main' of https://github.com/huggingface/diffusers into …
patrickvonplaten Dec 18, 2022
cbc2ba0
fix
patrickvonplaten Dec 18, 2022
8cb29fc
[tests] UnCLIPPipelineFastTests
williamberman Dec 18, 2022
231d4e6
remove unused imports
williamberman Dec 18, 2022
168b492
[tests] UnCLIPPipelineIntegrationTests
williamberman Dec 18, 2022
ec64638
correct
patrickvonplaten Dec 18, 2022
90c383d
Merge branch 'kakaobrain_unclip' of https://github.com/williamberman/…
patrickvonplaten Dec 18, 2022
51f3af3
make style
patrickvonplaten Dec 18, 2022

Files changed

1,159 changes: 1,159 additions & 0 deletions scripts/convert_kakao_brain_unclip_to_diffusers.py

Large diffs are not rendered by default.
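
The converted weights can then be loaded like any other diffusers pipeline. A minimal sketch, assuming the conversion output was saved to a local directory (the path and prompt below are placeholders, not values from this PR):

from diffusers import UnCLIPPipeline

# Placeholder path: point this at the directory written by
# convert_kakao_brain_unclip_to_diffusers.py.
pipe = UnCLIPPipeline.from_pretrained("./karlo-unclip-converted")
pipe = pipe.to("cuda")  # optional, if a GPU is available

# The pipeline chains the prior, decoder, and super-resolution stages.
image = pipe("a photo of a corgi wearing sunglasses").images[0]
image.save("unclip_sample.png")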

12 changes: 11 additions & 1 deletion src/diffusers/__init__.py
@@ -25,7 +25,15 @@
from .utils.dummy_pt_objects import * # noqa F403
else:
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
from .models import (
AutoencoderKL,
PriorTransformer,
Transformer2DModel,
UNet1DModel,
UNet2DConditionModel,
UNet2DModel,
VQModel,
)
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
@@ -63,6 +71,7 @@
RePaintScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
UnCLIPScheduler,
VQDiffusionScheduler,
)
from .training_utils import EMAModel
@@ -96,6 +105,7 @@
StableDiffusionPipeline,
StableDiffusionPipelineSafe,
StableDiffusionUpscalePipeline,
UnCLIPPipeline,
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
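
These additions make the unCLIP building blocks importable from the package root. A quick smoke test using only names introduced in this diff (assuming the scheduler's registered defaults are acceptable):

from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler

# The scheduler can be instantiated directly from its default config;
# the model and pipeline classes are exercised further below.
scheduler = UnCLIPScheduler()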
1 change: 1 addition & 0 deletions src/diffusers/models/__init__.py
@@ -17,6 +17,7 @@

if is_torch_available():
from .attention import Transformer2DModel
from .prior_transformer import PriorTransformer
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
124 changes: 94 additions & 30 deletions src/diffusers/models/attention.py

Large diffs are not rendered by default.

194 changes: 194 additions & 0 deletions src/diffusers/models/prior_transformer.py
@@ -0,0 +1,194 @@
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .embeddings import TimestepEmbedding, Timesteps


@dataclass
class PriorTransformerOutput(BaseOutput):
"""
Args:
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
"""

predicted_image_embedding: torch.FloatTensor


class PriorTransformer(ModelMixin, ConfigMixin):
"""
The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the
transformer predicts the image embeddings through a denoising diffusion process.

This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the models (such as downloading or saving, etc.)

For more details, see the original paper: https://arxiv.org/abs/2204.06125

Parameters:
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP
image embeddings and text embeddings are both the same dimension.
num_embeddings (`int`, *optional*, defaults to 77): The maximum number of CLIP embeddings allowed, i.e. the
length of the prompt after it has been tokenized.
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
projected hidden_states. The actual length of the used hidden_states is `num_embeddings +
additional_embeddings`.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.

"""

@register_to_config
def __init__(
self,
num_attention_heads: int = 32,
attention_head_dim: int = 64,
num_layers: int = 20,
embedding_dim: int = 768,
num_embeddings=77,
additional_embeddings=4,
dropout: float = 0.0,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.additional_embeddings = additional_embeddings

self.time_proj = Timesteps(inner_dim, True, 0)
self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)

self.proj_in = nn.Linear(embedding_dim, inner_dim)

self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)

self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))

self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))

self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
activation_fn="gelu",
attention_bias=True,
)
for d in range(num_layers)
]
)

self.norm_out = nn.LayerNorm(inner_dim)
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)

causal_attention_mask = torch.full(
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf")
)
causal_attention_mask.triu_(1)
causal_attention_mask = causal_attention_mask[None, ...]
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)

self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))

def forward(
self,
hidden_states,
timestep: Union[torch.Tensor, float, int],
proj_embedding: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.BoolTensor] = None,
return_dict: bool = True,
):
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
x_t, the currently predicted image embeddings.
timestep (`torch.long`):
Current denoising step.
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
Projected embedding vector the denoising process is conditioned on.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
Hidden states of the text embeddings the denoising process is conditioned on.
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
Text mask for the text embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain
tuple.

Returns:
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
[`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
batch_size = hidden_states.shape[0]

timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(hidden_states.device)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)

timesteps_projected = self.time_proj(timesteps)

# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might be fp16, so we need to cast here.
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
time_embeddings = self.time_embedding(timesteps_projected)

proj_embeddings = self.embedding_proj(proj_embedding)
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
hidden_states = self.proj_in(hidden_states)
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)

hidden_states = torch.cat(
[
encoder_hidden_states,
proj_embeddings[:, None, :],
time_embeddings[:, None, :],
hidden_states[:, None, :],
prd_embedding,
],
dim=1,
)

hidden_states = hidden_states + positional_embeddings

if attention_mask is not None:
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)

for block in self.transformer_blocks:
hidden_states = block(hidden_states, attention_mask=attention_mask)

hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states[:, -1]
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)

if not return_dict:
return (predicted_image_embedding,)

return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)

def post_process_latents(self, prior_latents):
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
return prior_latents
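
A minimal shape-check sketch for the new module: it builds a tiny prior (the real defaults are 32 heads of 64 dims, 20 layers, and embedding_dim=768) and runs a single forward pass on random inputs shaped as the docstring describes. The config values and tensors below are dummies for illustration only.

import torch
from diffusers import PriorTransformer

# Tiny config so the check runs instantly; see the docstring above for the real defaults.
prior = PriorTransformer(
    num_attention_heads=2,
    attention_head_dim=4,
    num_layers=2,
    embedding_dim=8,
    num_embeddings=7,
    additional_embeddings=4,
)

batch_size = 2
hidden_states = torch.randn(batch_size, 8)             # x_t, the noisy image embedding
proj_embedding = torch.randn(batch_size, 8)             # projected text embedding to condition on
encoder_hidden_states = torch.randn(batch_size, 7, 8)   # per-token text hidden states
attention_mask = torch.ones(batch_size, 7, dtype=torch.bool)

out = prior(
    hidden_states,
    timestep=1,
    proj_embedding=proj_embedding,
    encoder_hidden_states=encoder_hidden_states,
    attention_mask=attention_mask,
)
print(out.predicted_image_embedding.shape)  # torch.Size([2, 8])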
16 changes: 15 additions & 1 deletion src/diffusers/models/resnet.py
@@ -405,7 +405,14 @@ def __init__(
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

if temb_channels is not None:
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
if self.time_embedding_norm == "default":
time_emb_proj_out_channels = out_channels
elif self.time_embedding_norm == "scale_shift":
time_emb_proj_out_channels = out_channels * 2
else:
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")

self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
else:
self.time_emb_proj = None

@@ -465,9 +472,16 @@ def forward(self, input_tensor, temb):

if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]

if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb

hidden_states = self.norm2(hidden_states)

if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift

hidden_states = self.nonlinearity(hidden_states)

hidden_states = self.dropout(hidden_states)
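
The `scale_shift` branch is AdaGN-style conditioning: the time embedding is projected to twice the output channel count, chunked into a scale and a shift, and applied after the second group norm instead of being added before it. A standalone sketch of just that arithmetic (module and tensor names are illustrative; `F.silu` stands in for `self.nonlinearity`):

import torch
import torch.nn as nn
import torch.nn.functional as F

batch, channels, temb_channels = 2, 8, 32
hidden_states = torch.randn(batch, channels, 16, 16)
temb = torch.randn(batch, temb_channels)

# With time_embedding_norm == "scale_shift", the projection outputs 2 * out_channels.
time_emb_proj = nn.Linear(temb_channels, channels * 2)
norm2 = nn.GroupNorm(num_groups=4, num_channels=channels)

temb = time_emb_proj(F.silu(temb))[:, :, None, None]  # (batch, 2 * channels, 1, 1)
scale, shift = torch.chunk(temb, 2, dim=1)            # each (batch, channels, 1, 1)

hidden_states = norm2(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift   # conditioning applied after norm2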
11 changes: 10 additions & 1 deletion src/diffusers/models/unet_2d.py
@@ -55,6 +55,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
up_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
@@ -66,6 +68,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
"""

@register_to_config
@@ -88,6 +92,8 @@ def __init__(
attention_head_dim: int = 8,
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
resnet_time_scale_shift: str = "default",
add_attention: bool = True,
):
super().__init__()

@@ -130,6 +136,7 @@ def __init__(
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)

@@ -140,9 +147,10 @@ def __init__(
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
resnet_time_scale_shift=resnet_time_scale_shift,
attn_num_head_channels=attention_head_dim,
resnet_groups=norm_num_groups,
add_attention=add_attention,
)

# up
@@ -167,6 +175,7 @@ def __init__(
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
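
Both new options thread through to the blocks: `resnet_time_scale_shift` selects the conditioning style shown in resnet.py, and `add_attention` controls whether the mid block applies self-attention (used, presumably, by the unCLIP super-resolution UNets). A small instantiation sketch; the block types and sizes are illustrative, not the actual unCLIP checkpoint config:

import torch
from diffusers import UNet2DModel

unet = UNet2DModel(
    sample_size=64,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    resnet_time_scale_shift="scale_shift",  # new in this PR
    add_attention=False,                    # new in this PR: no attention in the mid block
)

sample = torch.randn(1, 3, 64, 64)
out = unet(sample, timestep=10).sample
print(out.shape)  # torch.Size([1, 3, 64, 64])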