|
28 | 28 | import torch
|
29 | 29 | import torch.nn as nn
|
30 | 30 | from transformers import PretrainedConfig
|
| 31 | +from transformers.utils import is_flash_attn_2_available |
31 | 32 |
|
32 | 33 | from vllm.model_executor.custom_op import CustomOp
|
33 | 34 | from vllm.platforms import current_platform
|
34 | 35 |
|
| 36 | +if is_flash_attn_2_available(): |
| 37 | + from flash_attn.ops.triton.rotary import apply_rotary |
| 38 | + |
35 | 39 |
|
36 | 40 | def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
37 | 41 | x1 = x[..., :x.shape[-1] // 2]
|
@@ -100,6 +104,10 @@ def __init__(
|
100 | 104 | cache = cache.to(dtype)
|
101 | 105 | self.cos_sin_cache: torch.Tensor
|
102 | 106 | self.register_buffer("cos_sin_cache", cache, persistent=False)
|
| 107 | + if is_flash_attn_2_available(): |
| 108 | + self._use_flash_attn = True |
| 109 | + else: |
| 110 | + self._use_flash_attn = False |
103 | 111 |
|
104 | 112 | def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
105 | 113 | """Compute the inverse frequency."""
|
@@ -141,14 +149,23 @@ def forward_native(
|
141 | 149 | query = query.view(num_tokens, -1, self.head_size)
|
142 | 150 | query_rot = query[..., :self.rotary_dim]
|
143 | 151 | query_pass = query[..., self.rotary_dim:]
|
144 |
| - query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) |
| 152 | + if self._use_flash_attn: |
| 153 | + query_rot = apply_rotary(query_rot.unsqueeze(0), cos, sin, |
| 154 | + 0).squeeze(0) |
| 155 | + else: |
| 156 | + query_rot = _apply_rotary_emb(query_rot, cos, sin, |
| 157 | + self.is_neox_style) |
145 | 158 | query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
146 | 159 |
|
147 | 160 | key_shape = key.shape
|
148 | 161 | key = key.view(num_tokens, -1, self.head_size)
|
149 | 162 | key_rot = key[..., :self.rotary_dim]
|
150 | 163 | key_pass = key[..., self.rotary_dim:]
|
151 |
| - key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) |
| 164 | + if self._use_flash_attn: |
| 165 | + key_rot = apply_rotary(key_rot.unsqueeze(0), cos, sin, |
| 166 | + 0).squeeze(0) |
| 167 | + else: |
| 168 | + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) |
152 | 169 | key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
153 | 170 | return query, key
|
154 | 171 |
|
@@ -938,6 +955,10 @@ def __init__(
|
938 | 955 | self.mrope_section = mrope_section
|
939 | 956 | if self.mrope_section:
|
940 | 957 | assert sum(self.mrope_section) == rotary_dim // 2
|
| 958 | + if is_flash_attn_2_available(): |
| 959 | + self._use_flash_attn = True |
| 960 | + else: |
| 961 | + self._use_flash_attn = False |
941 | 962 |
|
942 | 963 | def forward(
|
943 | 964 | self,
|
@@ -977,14 +998,23 @@ def forward(
|
977 | 998 | query = query.view(num_tokens, -1, self.head_size)
|
978 | 999 | query_rot = query[..., :self.rotary_dim]
|
979 | 1000 | query_pass = query[..., self.rotary_dim:]
|
980 |
| - query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) |
| 1001 | + if self._use_flash_attn: |
| 1002 | + query_rot = apply_rotary(query_rot.unsqueeze(0), cos, sin, |
| 1003 | + 0).squeeze(0) |
| 1004 | + else: |
| 1005 | + query_rot = _apply_rotary_emb(query_rot, cos, sin, |
| 1006 | + self.is_neox_style) |
981 | 1007 | query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
982 | 1008 |
|
983 | 1009 | key_shape = key.shape
|
984 | 1010 | key = key.view(num_tokens, -1, self.head_size)
|
985 | 1011 | key_rot = key[..., :self.rotary_dim]
|
986 | 1012 | key_pass = key[..., self.rotary_dim:]
|
987 |
| - key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) |
| 1013 | + if self._use_flash_attn: |
| 1014 | + key_rot = apply_rotary(key_rot.unsqueeze(0), cos, sin, |
| 1015 | + 0).squeeze(0) |
| 1016 | + else: |
| 1017 | + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) |
988 | 1018 | key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
989 | 1019 | return query, key
|
990 | 1020 |
|
|
0 commit comments