import torch
import torch.nn as nn
from transformers import PretrainedConfig
+ from transformers.utils import is_flash_attn_2_available

from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform

+ if is_flash_attn_2_available():
+     from flash_attn.layers.rotary import apply_rotary_emb
+

def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]
@@ -46,20 +50,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    return x.flatten(-2)


- def _apply_rotary_emb(
+ def _apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
-     """
-     Args:
-         x: [num_tokens, num_heads, head_size]
-         cos: [num_tokens, head_size // 2]
-         sin: [num_tokens, head_size // 2]
-         is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
-             positional embeddings.
-     """
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
@@ -75,6 +71,28 @@ def _apply_rotary_emb(
    return torch.stack((o1, o2), dim=-1).flatten(-2)


+ def _apply_rotary_emb(
+     x: torch.Tensor,
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+     is_neox_style: bool,
+     use_flash_attn: bool = False,
+ ) -> torch.Tensor:
+     """
+     Args:
+         x: [num_tokens, num_heads, head_size]
+         cos: [num_tokens, head_size // 2]
+         sin: [num_tokens, head_size // 2]
+         is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
+             positional embeddings.
+         use_flash_attn: Whether to use flash-attn's fused rotary kernel
+             instead of the pure-PyTorch fallback.
+     """
+     if use_flash_attn:
+         return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0)
+     else:
+         return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
+
+
@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
    """Original rotary positional embedding."""
@@ -100,6 +118,10 @@ def __init__(
        cache = cache.to(dtype)
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)
+         if is_flash_attn_2_available():
+             self._use_flash_attn = True
+         else:
+             self._use_flash_attn = False

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
@@ -141,14 +163,16 @@ def forward_native(
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
-         query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+         query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style,
+                                       self._use_flash_attn)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
-         key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+         key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style,
+                                     self._use_flash_attn)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

@@ -309,9 +333,11 @@ def _apply_rotary_emb_neuron(
        key = key.view(num_tokens, -1, self.head_size)

        if self.rotary_dim == self.head_size:
-             query = _apply_rotary_emb(query, cos, sin, self.is_neox_style)
+             query = _apply_rotary_emb(query, cos, sin, self.is_neox_style,
+                                       self._use_flash_attn)
            query = query.reshape(query_shape)
-             key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
+             key = _apply_rotary_emb(key, cos, sin, self.is_neox_style,
+                                     self._use_flash_attn)
            key = key.reshape(key_shape)
        else:
            head_size = query.shape[-1]
@@ -938,6 +964,10 @@ def __init__(
        self.mrope_section = mrope_section
        if self.mrope_section:
            assert sum(self.mrope_section) == rotary_dim // 2
+         if is_flash_attn_2_available():
+             self._use_flash_attn = True
+         else:
+             self._use_flash_attn = False

    def forward(
        self,
@@ -977,14 +1007,16 @@ def forward(
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
-         query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+         query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style,
+                                       self._use_flash_attn)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
-         key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+         key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style,
+                                     self._use_flash_attn)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key
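
A minimal sketch (not part of the diff above) of how the new use_flash_attn flag can be exercised, using the torch fallback as the reference. It assumes the functions from this change are in scope and that inputs are half precision on a CUDA device, since the flash-attn rotary kernel runs on GPU.

import torch

# Shapes follow the _apply_rotary_emb docstring:
# x: [num_tokens, num_heads, head_size], cos/sin: [num_tokens, head_size // 2].
num_tokens, num_heads, head_size = 4, 8, 64
angles = torch.rand(num_tokens, head_size // 2, device="cuda")
cos, sin = angles.cos().half(), angles.sin().half()
x = torch.randn(num_tokens, num_heads, head_size,
                dtype=torch.float16, device="cuda")

# Reference: the pure-PyTorch path (_apply_rotary_emb_torch under the hood).
ref = _apply_rotary_emb(x, cos, sin, is_neox_style=True, use_flash_attn=False)

# Fused path, only taken when flash-attn 2 is installed.
if is_flash_attn_2_available():
    out = _apply_rotary_emb(x, cos, sin, is_neox_style=True, use_flash_attn=True)
    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)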