@@ -514,7 +514,7 @@ def get_1d_rotary_pos_embed(
514
514
linear_factor = 1.0 ,
515
515
ntk_factor = 1.0 ,
516
516
repeat_interleave_real = True ,
517
- freqs_dtype = torch .float32 , # torch.float32 (hunyuan, stable audio) , torch.float64 (flux)
517
+ freqs_dtype = torch .float32 , # torch.float32, torch.float64 (flux)
518
518
):
519
519
"""
520
520
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -551,15 +551,18 @@ def get_1d_rotary_pos_embed(
551
551
t = torch .from_numpy (pos ).to (freqs .device ) # type: ignore # [S]
552
552
freqs = torch .outer (t , freqs ) # type: ignore # [S, D/2]
553
553
if use_real and repeat_interleave_real :
554
+ # flux, hunyuan-dit, cogvideox
554
555
freqs_cos = freqs .cos ().repeat_interleave (2 , dim = 1 ).float () # [S, D]
555
556
freqs_sin = freqs .sin ().repeat_interleave (2 , dim = 1 ).float () # [S, D]
556
557
return freqs_cos , freqs_sin
557
558
elif use_real :
559
+ # stable audio
558
560
freqs_cos = torch .cat ([freqs .cos (), freqs .cos ()], dim = - 1 ).float () # [S, D]
559
561
freqs_sin = torch .cat ([freqs .sin (), freqs .sin ()], dim = - 1 ).float () # [S, D]
560
562
return freqs_cos , freqs_sin
561
563
else :
562
- freqs_cis = torch .polar (torch .ones_like (freqs ), freqs ).float () # complex64 # [S, D/2]
564
+ # lumina
565
+ freqs_cis = torch .polar (torch .ones_like (freqs ), freqs ) # complex64 # [S, D/2]
563
566
return freqs_cis
564
567
565
568
@@ -590,11 +593,11 @@ def apply_rotary_emb(
590
593
cos , sin = cos .to (x .device ), sin .to (x .device )
591
594
592
595
if use_real_unbind_dim == - 1 :
593
- # Use for example in Lumina
596
+ # Used for flux, cogvideox, hunyuan-dit
594
597
x_real , x_imag = x .reshape (* x .shape [:- 1 ], - 1 , 2 ).unbind (- 1 ) # [B, S, H, D//2]
595
598
x_rotated = torch .stack ([- x_imag , x_real ], dim = - 1 ).flatten (3 )
596
599
elif use_real_unbind_dim == - 2 :
597
- # Use for example in Stable Audio
600
+ # Used for Stable Audio
598
601
x_real , x_imag = x .reshape (* x .shape [:- 1 ], 2 , - 1 ).unbind (- 2 ) # [B, S, H, D//2]
599
602
x_rotated = torch .cat ([- x_imag , x_real ], dim = - 1 )
600
603
else :
@@ -604,6 +607,7 @@ def apply_rotary_emb(
604
607
605
608
return out
606
609
else :
610
+ # Used for lumina
607
611
x_rotated = torch .view_as_complex (x .float ().reshape (* x .shape [:- 1 ], - 1 , 2 ))
608
612
freqs_cis = freqs_cis .unsqueeze (2 )
609
613
x_out = torch .view_as_real (x_rotated * freqs_cis ).flatten (3 )
0 commit comments