Commit 2dcf64b

kakaobrain unCLIP (#1428)
* [wip] attention block updates
* [wip] unCLIP unet decoder and super res
* [wip] unCLIP prior transformer
* [wip] scheduler changes
* [wip] text proj utility class
* [wip] UnCLIPPipeline
* [wip] kakaobrain unCLIP convert script
* [unCLIP pipeline] fixes re: @patrickvonplaten (remove callbacks; move denoising loops into call function)
* UNCLIPScheduler re: @patrickvonplaten (revert changes to DDPMScheduler; make UNCLIPScheduler a modified DDPM scheduler with changes to support karlo)
* mask -> attention_mask re: @patrickvonplaten
* [DDPMScheduler] remove leftover change
* [docs] PriorTransformer
* [docs] UNet2DConditionModel and UNet2DModel
* [nit] UNCLIPScheduler -> UnCLIPScheduler, matches existing unclip naming better
* [docs] SchedulingUnCLIP
* [docs] UnCLIPTextProjModel
* refactor
* finish licenses
* rename all to attention_mask and prep in models
* more renaming
* don't expose unused configs
* final renaming fixes
* remove x attn mask when not necessary
* configure kakao script to use new class embedding config
* fix copies
* [tests] UnCLIPScheduler
* finish x attn
* finish
* remove more
* rename condition blocks
* clean more
* Apply suggestions from code review
* up
* fix
* [tests] UnCLIPPipelineFastTests
* remove unused imports
* [tests] UnCLIPPipelineIntegrationTests
* correct
* make style

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 402b956 commit 2dcf64b

21 files changed: +3594 -118 lines changed

scripts/convert_kakao_brain_unclip_to_diffusers.py

Lines changed: 1159 additions & 0 deletions
Large diffs are not rendered by default.

src/diffusers/__init__.py

Lines changed: 11 additions & 1 deletion
@@ -25,7 +25,15 @@
     from .utils.dummy_pt_objects import *  # noqa F403
 else:
     from .modeling_utils import ModelMixin
-    from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
+    from .models import (
+        AutoencoderKL,
+        PriorTransformer,
+        Transformer2DModel,
+        UNet1DModel,
+        UNet2DConditionModel,
+        UNet2DModel,
+        VQModel,
+    )
     from .optimization import (
         get_constant_schedule,
         get_constant_schedule_with_warmup,
@@ -63,6 +71,7 @@
         RePaintScheduler,
         SchedulerMixin,
         ScoreSdeVeScheduler,
+        UnCLIPScheduler,
         VQDiffusionScheduler,
     )
     from .training_utils import EMAModel
@@ -96,6 +105,7 @@
         StableDiffusionPipeline,
         StableDiffusionPipelineSafe,
         StableDiffusionUpscalePipeline,
+        UnCLIPPipeline,
         VersatileDiffusionDualGuidedPipeline,
         VersatileDiffusionImageVariationPipeline,
         VersatileDiffusionPipeline,
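
For reference (not part of the diff), these hunks expose three new public names at the package root. A minimal sketch, assuming a diffusers build that contains this commit:

    # Illustrative only: confirm the new top-level exports resolve.
    from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler

    print(PriorTransformer.__name__, UnCLIPPipeline.__name__, UnCLIPScheduler.__name__)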

src/diffusers/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 
 if is_torch_available():
     from .attention import Transformer2DModel
+    from .prior_transformer import PriorTransformer
     from .unet_1d import UNet1DModel
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel

src/diffusers/models/attention.py

Lines changed: 94 additions & 30 deletions
Large diffs are not rendered by default.
src/diffusers/models/prior_transformer.py

Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
+from .attention import BasicTransformerBlock
+from .embeddings import TimestepEmbedding, Timesteps
+
+
+@dataclass
+class PriorTransformerOutput(BaseOutput):
+    """
+    Args:
+        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+            The predicted CLIP image embedding conditioned on the CLIP text embedding input.
+    """
+
+    predicted_image_embedding: torch.FloatTensor
+
+
+class PriorTransformer(ModelMixin, ConfigMixin):
+    """
+    The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the
+    transformer predicts the image embeddings through a denoising diffusion process.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+    implements for all the models (such as downloading or saving, etc.)
+
+    For more details, see the original paper: https://arxiv.org/abs/2204.06125
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+        num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
+        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP
+            image embeddings and text embeddings are both the same dimension.
+        num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the
+            length of the prompt after it has been tokenized.
+        additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
+            projected hidden_states. The actual length of the used hidden_states is `num_embeddings +
+            additional_embeddings`.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 32,
+        attention_head_dim: int = 64,
+        num_layers: int = 20,
+        embedding_dim: int = 768,
+        num_embeddings=77,
+        additional_embeddings=4,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        inner_dim = num_attention_heads * attention_head_dim
+        self.additional_embeddings = additional_embeddings
+
+        self.time_proj = Timesteps(inner_dim, True, 0)
+        self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)
+
+        self.proj_in = nn.Linear(embedding_dim, inner_dim)
+
+        self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
+        self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
+
+        self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
+
+        self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    dropout=dropout,
+                    activation_fn="gelu",
+                    attention_bias=True,
+                )
+                for d in range(num_layers)
+            ]
+        )
+
+        self.norm_out = nn.LayerNorm(inner_dim)
+        self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
+
+        causal_attention_mask = torch.full(
+            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf")
+        )
+        causal_attention_mask.triu_(1)
+        causal_attention_mask = causal_attention_mask[None, ...]
+        self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
+
+        self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
+        self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))
+
+    def forward(
+        self,
+        hidden_states,
+        timestep: Union[torch.Tensor, float, int],
+        proj_embedding: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        return_dict: bool = True,
+    ):
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+                x_t, the currently predicted image embeddings.
+            timestep (`torch.long`):
+                Current denoising step.
+            proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+                Projected embedding vector the denoising process is conditioned on.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
+                Hidden states of the text embeddings the denoising process is conditioned on.
+            attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
+                Text mask for the text embeddings.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
+            [`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+        """
+        batch_size = hidden_states.shape[0]
+
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
+        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(hidden_states.device)
+
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
+
+        timesteps_projected = self.time_proj(timesteps)
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might be fp16, so we need to cast here.
+        timesteps_projected = timesteps_projected.to(dtype=self.dtype)
+        time_embeddings = self.time_embedding(timesteps_projected)
+
+        proj_embeddings = self.embedding_proj(proj_embedding)
+        encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
+        hidden_states = self.proj_in(hidden_states)
+        prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
+        positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
+
+        hidden_states = torch.cat(
+            [
+                encoder_hidden_states,
+                proj_embeddings[:, None, :],
+                time_embeddings[:, None, :],
+                hidden_states[:, None, :],
+                prd_embedding,
+            ],
+            dim=1,
+        )
+
+        hidden_states = hidden_states + positional_embeddings
+
+        if attention_mask is not None:
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
+            attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
+            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
+
+        for block in self.transformer_blocks:
+            hidden_states = block(hidden_states, attention_mask=attention_mask)
+
+        hidden_states = self.norm_out(hidden_states)
+        hidden_states = hidden_states[:, -1]
+        predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
+
+        if not return_dict:
+            return (predicted_image_embedding,)
+
+        return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
+
+    def post_process_latents(self, prior_latents):
+        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
+        return prior_latents
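
For orientation (not part of the commit), a minimal smoke test of the new module. The tiny hyperparameters below are arbitrary and chosen only to make the input and output shapes concrete; the real unCLIP prior uses the defaults from the docstring (embedding_dim=768, 32 heads, 20 layers):

    import torch
    from diffusers import PriorTransformer

    # Arbitrary small config for illustration.
    prior = PriorTransformer(
        num_attention_heads=2,
        attention_head_dim=4,
        num_layers=2,
        embedding_dim=8,
        num_embeddings=7,
        additional_embeddings=4,
    )

    batch, seq, dim = 3, 7, 8
    out = prior(
        hidden_states=torch.randn(batch, dim),               # x_t: noisy CLIP image embedding
        timestep=torch.tensor(5),                            # current denoising step
        proj_embedding=torch.randn(batch, dim),              # projected text embedding the prior is conditioned on
        encoder_hidden_states=torch.randn(batch, seq, dim),  # per-token text embeddings
        attention_mask=torch.ones(batch, seq, dtype=torch.bool),
    )
    print(out.predicted_image_embedding.shape)               # torch.Size([3, 8])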

src/diffusers/models/resnet.py

Lines changed: 15 additions & 1 deletion
@@ -405,7 +405,14 @@ def __init__(
         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
         if temb_channels is not None:
-            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+            if self.time_embedding_norm == "default":
+                time_emb_proj_out_channels = out_channels
+            elif self.time_embedding_norm == "scale_shift":
+                time_emb_proj_out_channels = out_channels * 2
+            else:
+                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
+
+            self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
         else:
             self.time_emb_proj = None
 
@@ -465,9 +472,16 @@ def forward(self, input_tensor, temb):
 
         if temb is not None:
             temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+
+        if temb is not None and self.time_embedding_norm == "default":
             hidden_states = hidden_states + temb
 
         hidden_states = self.norm2(hidden_states)
+
+        if temb is not None and self.time_embedding_norm == "scale_shift":
+            scale, shift = torch.chunk(temb, 2, dim=1)
+            hidden_states = hidden_states * (1 + scale) + shift
+
         hidden_states = self.nonlinearity(hidden_states)
 
         hidden_states = self.dropout(hidden_states)
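
A standalone sketch of the "scale_shift" path added above (not the full ResnetBlock2D forward): when time_embedding_norm="scale_shift", time_emb_proj outputs 2 * out_channels, which is chunked into (scale, shift) and applied after the second group norm. The shapes below are arbitrary:

    import torch

    batch, out_channels, height, width = 2, 8, 16, 16
    temb = torch.randn(batch, 2 * out_channels, 1, 1)                # time_emb_proj output, broadcast over H, W
    hidden_states = torch.randn(batch, out_channels, height, width)  # output of norm2

    scale, shift = torch.chunk(temb, 2, dim=1)
    hidden_states = hidden_states * (1 + scale) + shift
    print(hidden_states.shape)  # torch.Size([2, 8, 16, 16])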

src/diffusers/models/unet_2d.py

Lines changed: 10 additions & 1 deletion
@@ -55,6 +55,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         down_block_types (`Tuple[str]`, *optional*, defaults to :
             obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
             types.
+        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
+            The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
         up_block_types (`Tuple[str]`, *optional*, defaults to :
             obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
         block_out_channels (`Tuple[int]`, *optional*, defaults to :
@@ -66,6 +68,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
         norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
         norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
+        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+            for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
     """
 
     @register_to_config
@@ -88,6 +92,8 @@ def __init__(
         attention_head_dim: int = 8,
         norm_num_groups: int = 32,
         norm_eps: float = 1e-5,
+        resnet_time_scale_shift: str = "default",
+        add_attention: bool = True,
     ):
         super().__init__()
 
@@ -130,6 +136,7 @@ def __init__(
                 resnet_groups=norm_num_groups,
                 attn_num_head_channels=attention_head_dim,
                 downsample_padding=downsample_padding,
+                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.down_blocks.append(down_block)
 
@@ -140,9 +147,10 @@ def __init__(
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
-            resnet_time_scale_shift="default",
+            resnet_time_scale_shift=resnet_time_scale_shift,
            attn_num_head_channels=attention_head_dim,
            resnet_groups=norm_num_groups,
+            add_attention=add_attention,
        )
 
        # up
@@ -167,6 +175,7 @@ def __init__(
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attn_num_head_channels=attention_head_dim,
+                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel
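
Purely as illustration (not from the commit), a small configuration exercising the two new UNet2DModel arguments threaded through in this diff. The sizes are arbitrary and assume a diffusers build containing this change:

    import torch
    from diffusers import UNet2DModel

    unet = UNet2DModel(
        sample_size=32,
        in_channels=3,
        out_channels=3,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        resnet_time_scale_shift="scale_shift",  # FiLM-style scale/shift time conditioning in the resnets
        add_attention=True,                     # keep the self-attention layer in the mid block
    )

    sample = torch.randn(1, 3, 32, 32)
    out = unet(sample, timestep=10).sample
    print(out.shape)  # torch.Size([1, 3, 32, 32])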
