@@ -86,12 +86,25 @@ def get_3d_sincos_pos_embed(
     temporal_interpolation_scale: float = 1.0,
 ) -> np.ndarray:
     r"""
+    Creates 3D sinusoidal positional embeddings.
+
     Args:
         embed_dim (`int`):
+            The embedding dimension of inputs. It must be divisible by 16.
         spatial_size (`int` or `Tuple[int, int]`):
+            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
+            spatial dimensions (height and width).
         temporal_size (`int`):
+            The temporal dimension of positional embeddings (number of frames).
         spatial_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for spatial grid interpolation.
         temporal_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for temporal grid interpolation.
+
+    Returns:
+        `np.ndarray`:
+            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+            embed_dim]`.
     """
     if embed_dim % 4 != 0:
         raise ValueError("`embed_dim` must be divisible by 4")
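As a reviewer aid (not part of this diff), here is a minimal sketch of how the newly documented function might be called; the `diffusers.models.embeddings` import path and the concrete sizes are assumptions:

```python
from diffusers.models.embeddings import get_3d_sincos_pos_embed

pos_embed = get_3d_sincos_pos_embed(
    embed_dim=64,            # divisible by 16, as the docstring requires
    spatial_size=(32, 32),   # latent height and width
    temporal_size=13,        # number of frames
)
print(pos_embed.shape)  # (13, 32 * 32, 64), per the Returns section above
```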
@@ -129,8 +142,24 @@ def get_2d_sincos_pos_embed(
     embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
 ):
     """
-    grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
-    [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    Creates 2D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension.
+        grid_size (`int`):
+            The size of the grid height and width.
+        cls_token (`bool`, defaults to `False`):
+            Whether or not to add a classification token.
+        extra_tokens (`int`, defaults to `0`):
+            The number of extra tokens to add.
+        interpolation_scale (`float`, defaults to `1.0`):
+            The scale of the interpolation.
+
+    Returns:
+        pos_embed (`np.ndarray`):
+            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size *
+            grid_size, embed_dim]` if using cls_token.
     """
     if isinstance(grid_size, int):
         grid_size = (grid_size, grid_size)
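A quick usage sketch for this one too (again an assumption on the import path, not part of the diff):

```python
from diffusers.models.embeddings import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=16)
print(pos_embed.shape)  # (16 * 16, 64) = (256, 64)

# With a slot reserved for a classification token:
pos_embed_cls = get_2d_sincos_pos_embed(
    embed_dim=64, grid_size=16, cls_token=True, extra_tokens=1
)
print(pos_embed_cls.shape)  # (1 + 256, 64)
```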
@@ -148,6 +177,16 @@ def get_2d_sincos_pos_embed(
 
 
 def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    r"""
+    This function generates 2D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension.
+        grid (`np.ndarray`): Grid of height and width positions, stacked along the first axis (shape `(2, 1, H, W)`).
+
+    Returns:
+        `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`.
+    """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
 
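For reference, a sketch of how such a grid is typically built before calling the helper (mirroring what `get_2d_sincos_pos_embed` does internally; the construction here is an assumption based on this file):

```python
import numpy as np
from diffusers.models.embeddings import get_2d_sincos_pos_embed_from_grid

h = w = 8
grid = np.meshgrid(
    np.arange(w, dtype=np.float32),  # width varies fastest
    np.arange(h, dtype=np.float32),
)
grid = np.stack(grid, axis=0).reshape([2, 1, h, w])

pos_embed = get_2d_sincos_pos_embed_from_grid(64, grid)
print(pos_embed.shape)  # (8 * 8, 64)
```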
@@ -161,7 +200,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
 
 def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     """
-    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
+    This function generates 1D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension `D`.
+        pos (`numpy.ndarray`): 1D array of positions with shape `(M,)`.
+
+    Returns:
+        `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
     """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
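Minimal sketch (not part of the diff): encode 8 scalar positions into a 16-dimensional embedding.

```python
import numpy as np
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid

pos = np.arange(8, dtype=np.float32)              # M = 8 positions
emb = get_1d_sincos_pos_embed_from_grid(16, pos)  # D = 16
print(emb.shape)  # (8, 16): sin features first, then cos features
```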
@@ -181,7 +227,22 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 
 
 class PatchEmbed(nn.Module):
-    """2D Image to Patch Embedding with support for SD3 cropping."""
+    """
+    2D Image to Patch Embedding with support for SD3 cropping.
+
+    Args:
+        height (`int`, defaults to `224`): The height of the image.
+        width (`int`, defaults to `224`): The width of the image.
+        patch_size (`int`, defaults to `16`): The size of the patches.
+        in_channels (`int`, defaults to `3`): The number of input channels.
+        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+        layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
+        flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
+        bias (`bool`, defaults to `True`): Whether or not to use bias.
+        interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
+        pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
+        pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
+    """
 
     def __init__(
         self,
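A usage sketch for the documented class (not part of this diff; the sizes are illustrative assumptions):

```python
import torch
from diffusers.models.embeddings import PatchEmbed

patch_embed = PatchEmbed(
    height=32, width=32, patch_size=2, in_channels=4, embed_dim=768
)
latent = torch.randn(1, 4, 32, 32)  # (batch, channels, height, width)
tokens = patch_embed(latent)
print(tokens.shape)  # (1, (32 // 2) * (32 // 2), 768) = (1, 256, 768)
```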
@@ -289,7 +350,15 @@ def forward(self, latent):
 
 
 class LuminaPatchEmbed(nn.Module):
-    """2D Image to Patch Embedding with support for Lumina-T2X"""
+    """
+    2D Image to Patch Embedding with support for Lumina-T2X.
+
+    Args:
+        patch_size (`int`, defaults to `2`): The size of the patches.
+        in_channels (`int`, defaults to `4`): The number of input channels.
+        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+        bias (`bool`, defaults to `True`): Whether or not to use bias.
+    """
 
     def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
         super().__init__()
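An instantiation-only sketch (not part of the diff); to my understanding the Lumina-T2X forward pass also consumes precomputed rotary frequencies, so a full forward call is omitted here:

```python
from diffusers.models.embeddings import LuminaPatchEmbed

# Instantiation only: unlike PatchEmbed, the forward pass of this module
# also takes precomputed rotary frequencies (freqs_cis) as input.
patch_embed = LuminaPatchEmbed(patch_size=2, in_channels=4, embed_dim=768)
print(sum(p.numel() for p in patch_embed.parameters()))
```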
@@ -675,6 +744,20 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
 
 
 def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+    """
+    Get 2D RoPE from grid.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension size, corresponding to hidden_size_head.
+        grid (`np.ndarray`):
+            The grid of the positional embedding.
+        use_real (`bool`):
+            If `True`, return the real and imaginary parts separately. Otherwise, return complex numbers.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `(grid_size * grid_size, embed_dim / 2)`.
+    """
     assert embed_dim % 4 == 0
 
     # use half of dimensions to encode grid_h
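A sketch of calling the helper directly (not part of the diff); the grid construction mirrors what `get_2d_rotary_pos_embed` appears to do internally and is an assumption based on this file:

```python
import numpy as np
from diffusers.models.embeddings import get_2d_rotary_pos_embed_from_grid

h = w = 8
grid = np.meshgrid(
    np.arange(w, dtype=np.float32), np.arange(h, dtype=np.float32)
)
grid = np.stack(grid, axis=0).reshape([2, 1, h, w])

freqs = get_2d_rotary_pos_embed_from_grid(64, grid)  # complex tensor
print(freqs.shape)  # (8 * 8, 64 / 2) = (64, 32)

cos, sin = get_2d_rotary_pos_embed_from_grid(64, grid, use_real=True)
print(cos.shape, sin.shape)  # real and imaginary parts, (64, 64) each
```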
@@ -695,6 +778,25 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
 
 
 def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
+    """
+    Get 2D RoPE from grid (Lumina-T2X variant).
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension size, corresponding to hidden_size_head.
+        len_h (`int`):
+            The height of the positional embedding grid.
+        len_w (`int`):
+            The width of the positional embedding grid.
+        linear_factor (`float`):
+            The linear factor of the positional embedding, which is used to scale the positional embedding in the
+            linear layer.
+        ntk_factor (`float`):
+            The NTK factor of the positional embedding, which is used for NTK-aware scaling of the frequency base.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `(len_h * len_w, embed_dim / 2)`.
+    """
     assert embed_dim % 4 == 0
 
     emb_h = get_1d_rotary_pos_embed(
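Finally, a hedged usage sketch for the Lumina variant (not part of the diff; the import path and the expected output shape are assumptions based on the docstring above):

```python
from diffusers.models.embeddings import get_2d_rotary_pos_embed_lumina

# Complex-valued frequencies for an 8x8 token grid, default scaling factors.
freqs_cis = get_2d_rotary_pos_embed_lumina(64, len_h=8, len_w=8)
# Per the docstring, this should cover 8 * 8 positions with
# embed_dim / 2 = 32 frequencies per position.
print(freqs_cis.shape, freqs_cis.dtype)
```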