Skip to content

[refactor embeddings] gligen + ip-adapter #6244

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn

from ..models.embeddings import ImageProjection, MLPProjection, Resampler
from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
USE_PEFT_BACKEND,
Expand Down Expand Up @@ -712,7 +712,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
cross_attention_dim = state_dict["proj.3.weight"].shape[0]

image_projection = MLPProjection(
image_projection = IPAdapterFullImageProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
)

Expand All @@ -730,7 +730,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
hidden_dims = state_dict["latents"].shape[2]
heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64

image_projection = Resampler(
image_projection = IPAdapterPlusImageProjection(
embed_dims=embed_dims,
output_dims=output_dims,
hidden_dims=hidden_dims,
Expand Down Expand Up @@ -780,7 +780,7 @@ def _load_ip_adapter_weights(self, state_dict):
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]

# Set encoder_hid_proj after loading ip_adapter weights,
# because `Resampler` also has `attn_processors`.
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
self.encoder_hid_proj = None

# set ip-adapter cross-attention processors & load state_dict
Expand Down
37 changes: 21 additions & 16 deletions src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def forward(self, image_embeds: torch.FloatTensor):
return image_embeds


class MLPProjection(nn.Module):
class IPAdapterFullImageProjection(nn.Module):
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
super().__init__()
from .attention import FeedForward
Expand Down Expand Up @@ -621,29 +621,34 @@ def shape(x):
return a[:, 0, :] # cls_token


class FourierEmbedder(nn.Module):
def __init__(self, num_freqs=64, temperature=100):
super().__init__()
def get_fourier_embeds_from_boundingbox(embed_dim, box):
"""
Args:
embed_dim: int
box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
Returns:
[B x N x embed_dim] tensor of positional embeddings
"""

batch_size, num_boxes = box.shape[:2]

self.num_freqs = num_freqs
self.temperature = temperature
emb = 100 ** (torch.arange(embed_dim) / embed_dim)
emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
emb = emb * box.unsqueeze(-1)

freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
freq_bands = freq_bands[None, None, None]
self.register_buffer("freq_bands", freq_bands, persistent=False)
emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)

def __call__(self, x):
x = self.freq_bands * x.unsqueeze(-1)
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
return emb


class PositionNet(nn.Module):
class GLIGENTextBoundingboxProjection(nn.Module):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@patrickvonplaten
Any rules around when we call it "Projection" vs "Embedding"?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I follow this rule of thumb.

When you're representing some feature (time, bbox coordinates, tokens, etc.) in the latent space for the first time, it's better to call those representations embeddings. The subsequent operations performed on those representations are projections (more aligned with the literature in linear algebra).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like GLIGENTextBoundingboxProjection - it's nicely specific

def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
super().__init__()
self.positive_len = positive_len
self.out_dim = out_dim

self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
self.fourier_embedder_dim = fourier_freqs
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy

if isinstance(out_dim, tuple):
Expand Down Expand Up @@ -692,7 +697,7 @@ def forward(
masks = masks.unsqueeze(-1)

# embedding position (it may includes padding as placeholder)
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C

# learnable null embedding
xyxy_null = self.null_position_feature.view(1, 1, -1)
Expand Down Expand Up @@ -787,7 +792,7 @@ def forward(self, caption):
return hidden_states


class Resampler(nn.Module):
class IPAdapterPlusImageProjection(nn.Module):
"""Resampler of IP-Adapter Plus.

Args:
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
)
from .embeddings import (
GaussianFourierProjection,
GLIGENTextBoundingboxProjection,
ImageHintTimeEmbedding,
ImageProjection,
ImageTimeEmbedding,
PositionNet,
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
Expand Down Expand Up @@ -615,7 +615,7 @@ def __init__(
positive_len = cross_attention_dim[0]

feature_type = "text-only" if attention_type == "gated" else "text-image"
self.position_net = PositionNet(
self.position_net = GLIGENTextBoundingboxProjection(
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def __call__(self, x):
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)


class PositionNet(nn.Module):
class GLIGENTextBoundingboxProjection(nn.Module):
def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8):
super().__init__()
self.positive_len = positive_len
Expand Down Expand Up @@ -820,7 +820,7 @@ def __init__(
positive_len = cross_attention_dim[0]

feature_type = "text-only" if attention_type == "gated" else "text-image"
self.position_net = PositionNet(
self.position_net = GLIGENTextBoundingboxProjection(
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ def __call__(
)
gligen_phrases = gligen_phrases[:max_objs]
gligen_boxes = gligen_boxes[:max_objs]
# prepare batched input to the PositionNet (boxes, phrases, mask)
# prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask)
# Get tokens for phrases from pre-trained CLIPTokenizer
tokenizer_inputs = self.tokenizer(gligen_phrases, padding=True, return_tensors="pt").to(device)
# For the token, we use the same pre-trained text encoder
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_models_unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
from diffusers.models.embeddings import ImageProjection, Resampler
from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
Expand Down Expand Up @@ -133,7 +133,7 @@ def create_ip_adapter_plus_state_dict(model):

# "image_proj" (ImageProjection layer weights)
cross_attention_dim = model.config["cross_attention_dim"]
image_projection = Resampler(
image_projection = IPAdapterPlusImageProjection(
embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4
)

Expand Down