Fix some documentation in ./src/diffusers/models/embeddings.py for demo #9579

Merged: 14 commits, Dec 3, 2024
Changes from all commits
110 changes: 105 additions & 5 deletions src/diffusers/models/embeddings.py
@@ -86,12 +86,25 @@ def get_3d_sincos_pos_embed(
temporal_interpolation_scale: float = 1.0,
) -> np.ndarray:
r"""
Creates 3D sinusoidal positional embeddings.

Args:
embed_dim (`int`):
The embedding dimension of inputs. It must be divisible by 16.
spatial_size (`int` or `Tuple[int, int]`):
The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
spatial dimensions (height and width).
temporal_size (`int`):
The temporal dimension of positional embeddings (number of frames).
spatial_interpolation_scale (`float`, defaults to 1.0):
Scale factor for spatial grid interpolation.
temporal_interpolation_scale (`float`, defaults to 1.0):
Scale factor for temporal grid interpolation.

Returns:
`np.ndarray`:
The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
embed_dim]`.
"""
if embed_dim % 4 != 0:
raise ValueError("`embed_dim` must be divisible by 4")
@@ -129,8 +142,24 @@ def get_2d_sincos_pos_embed(
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
"""
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
Creates 2D sinusoidal positional embeddings.

Args:
embed_dim (`int`):
The embedding dimension.
grid_size (`int`):
The size of the grid height and width.
cls_token (`bool`, defaults to `False`):
Whether or not to add a classification token.
extra_tokens (`int`, defaults to `0`):
The number of extra tokens to add.
interpolation_scale (`float`, defaults to `1.0`):
The scale of the interpolation.

Returns:
pos_embed (`np.ndarray`):
Shape is either `[grid_size * grid_size, embed_dim]` if not using `cls_token`, or `[1 + grid_size * grid_size,
embed_dim]` if using `cls_token`.
"""
if isinstance(grid_size, int):
grid_size = (grid_size, grid_size)
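Likewise for the 2D variant, a hedged sketch of both return shapes named in the docstring:

```python
from diffusers.models.embeddings import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=16)
assert pos_embed.shape == (16 * 16, 64)

# With a classification token, extra_tokens rows are prepended.
pos_embed_cls = get_2d_sincos_pos_embed(embed_dim=64, grid_size=16, cls_token=True, extra_tokens=1)
assert pos_embed_cls.shape == (1 + 16 * 16, 64)
```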
@@ -148,6 +177,16 @@ def get_2d_sincos_pos_embed(


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
r"""
This function generates 2D sinusoidal positional embeddings from a grid.

Args:
embed_dim (`int`): The embedding dimension.
grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.

Returns:
`np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`.
"""
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")

@@ -161,7 +200,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
This function generates 1D sinusoidal positional embeddings from a grid of positions.

Args:
embed_dim (`int`): The embedding dimension `D`.
pos (`np.ndarray`): 1D array of positions to be encoded, with shape `(M,)`.

Returns:
`np.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
"""
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
@@ -181,7 +227,22 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):


class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding with support for SD3 cropping."""
"""
2D Image to Patch Embedding with support for SD3 cropping.

Args:
height (`int`, defaults to `224`): The height of the image.
width (`int`, defaults to `224`): The width of the image.
patch_size (`int`, defaults to `16`): The size of the patches.
in_channels (`int`, defaults to `3`): The number of input channels.
embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
bias (`bool`, defaults to `True`): Whether or not to use bias.
interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
"""

def __init__(
self,
@@ -289,7 +350,15 @@ def forward(self, latent):
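To make the documented defaults concrete, a hypothetical `PatchEmbed` usage (assumes a diffusers checkout matching this diff; the token count follows from `flatten=True`):

```python
import torch

from diffusers.models.embeddings import PatchEmbed

patchify = PatchEmbed(height=224, width=224, patch_size=16, in_channels=3, embed_dim=768)
latent = torch.randn(1, 3, 224, 224)  # (batch, channels, height, width)
tokens = patchify(latent)             # positional embeddings are added internally
assert tokens.shape == (1, (224 // 16) ** 2, 768)
```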


class LuminaPatchEmbed(nn.Module):
"""2D Image to Patch Embedding with support for Lumina-T2X"""
"""
2D Image to Patch Embedding with support for Lumina-T2X

Args:
patch_size (`int`, defaults to `2`): The size of the patches.
in_channels (`int`, defaults to `4`): The number of input channels.
embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
bias (`bool`, defaults to `True`): Whether or not to use bias.
"""

def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
super().__init__()
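A minimal construction sketch with the documented defaults (the forward pass is model-specific in Lumina-T2X, so only instantiation is shown):

```python
from diffusers.models.embeddings import LuminaPatchEmbed

embed = LuminaPatchEmbed(patch_size=2, in_channels=4, embed_dim=768, bias=True)
```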
@@ -675,6 +744,20 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):


def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
"""
Get 2D RoPE from grid.

Args:
embed_dim (`int`):
The embedding dimension size, corresponding to hidden_size_head.
grid (`np.ndarray`):
The grid of the positional embedding.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.

Returns:
`torch.Tensor`: The positional embedding with shape `(grid_size * grid_size, embed_dim / 2)`.
"""
assert embed_dim % 4 == 0

# use half of dimensions to encode grid_h
@@ -695,6 +778,23 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
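The visible comment hints at the axis split; a sketch of the presumed recipe, mirroring the sincos case above (`rotary_2d_from_grid` is an illustrative name; `get_1d_rotary_pos_embed` is defined in the same file):

```python
import torch

from diffusers.models.embeddings import get_1d_rotary_pos_embed

def rotary_2d_from_grid(embed_dim: int, grid, use_real: bool = False):
    # Half of embed_dim encodes grid rows, half encodes grid columns.
    emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)
    emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)
    if use_real:
        # (cos, sin) tensors, concatenated per axis.
        return torch.cat([emb_h[0], emb_w[0]], dim=1), torch.cat([emb_h[1], emb_w[1]], dim=1)
    return torch.cat([emb_h, emb_w], dim=1)  # complex tensor, shape (H*W, embed_dim/2)
```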


def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
"""
Get 2D RoPE from grid.

Args:
embed_dim: (`int`):
The embedding dimension size, corresponding to hidden_size_head.
grid (`np.ndarray`):
The grid of the positional embedding.
linear_factor (`float`):
The linear factor of the positional embedding, which is used to scale the positional embedding in the linear
layer.
ntk_factor (`float`):
The ntk factor of the positional embedding, which is used to scale the positional embedding in the ntk layer.

Returns:
`torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
"""
assert embed_dim % 4 == 0

emb_h = get_1d_rotary_pos_embed(
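The hunk is truncated here; for context, a hypothetical call matching the corrected argument list above:

```python
from diffusers.models.embeddings import get_2d_rotary_pos_embed_lumina

freqs_cis = get_2d_rotary_pos_embed_lumina(embed_dim=64, len_h=16, len_w=16)
# Expected: a complex tensor of shape (16 * 16, 64 / 2), per the Returns section.
```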