Skip to content

Commit 4f495b0

Browse files
authored
rotary embedding refactor 2: update comments, fix dtype for use_real=False (#9312)
fix notes and dtype
1 parent 40c13fe commit 4f495b0

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/diffusers/models/embeddings.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -514,7 +514,7 @@ def get_1d_rotary_pos_embed(
514514
linear_factor=1.0,
515515
ntk_factor=1.0,
516516
repeat_interleave_real=True,
517-
freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux)
517+
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
518518
):
519519
"""
520520
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -551,15 +551,18 @@ def get_1d_rotary_pos_embed(
551551
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
552552
freqs = torch.outer(t, freqs) # type: ignore # [S, D/2]
553553
if use_real and repeat_interleave_real:
554+
# flux, hunyuan-dit, cogvideox
554555
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
555556
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
556557
return freqs_cos, freqs_sin
557558
elif use_real:
559+
# stable audio
558560
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
559561
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
560562
return freqs_cos, freqs_sin
561563
else:
562-
freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2]
564+
# lumina
565+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
563566
return freqs_cis
564567

565568

@@ -590,11 +593,11 @@ def apply_rotary_emb(
590593
cos, sin = cos.to(x.device), sin.to(x.device)
591594

592595
if use_real_unbind_dim == -1:
593-
# Use for example in Lumina
596+
# Used for flux, cogvideox, hunyuan-dit
594597
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
595598
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
596599
elif use_real_unbind_dim == -2:
597-
# Use for example in Stable Audio
600+
# Used for Stable Audio
598601
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
599602
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
600603
else:
@@ -604,6 +607,7 @@ def apply_rotary_emb(
604607

605608
return out
606609
else:
610+
# used for lumina
607611
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
608612
freqs_cis = freqs_cis.unsqueeze(2)
609613
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

0 commit comments

Comments
 (0)