Skip to content

Commit e630289

Browse files
[Perf]: Optimize rotary_emb implementation to use Triton operator for improved inference performance
Signed-off-by: cynthieye <[email protected]>
Co-authored-by: MagnetoWang <[email protected]>
1 parent 99ef59c commit e630289

File tree

1 file changed

+34
-4
lines changed

1 file changed

+34
-4
lines changed

vllm/model_executor/layers/rotary_embedding.py

+34-4
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,14 @@
2828
import torch
2929
import torch.nn as nn
3030
from transformers import PretrainedConfig
31+
from transformers.utils import is_flash_attn_2_available
3132

3233
from vllm.model_executor.custom_op import CustomOp
3334
from vllm.platforms import current_platform
3435

36+
if is_flash_attn_2_available():
37+
from flash_attn.ops.triton.rotary import apply_rotary
38+
3539

3640
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
3741
x1 = x[..., :x.shape[-1] // 2]
@@ -100,6 +104,10 @@ def __init__(
100104
cache = cache.to(dtype)
101105
self.cos_sin_cache: torch.Tensor
102106
self.register_buffer("cos_sin_cache", cache, persistent=False)
107+
if is_flash_attn_2_available():
108+
self._use_flash_attn = True
109+
else:
110+
self._use_flash_attn = False
103111

104112
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
105113
"""Compute the inverse frequency."""
@@ -141,14 +149,23 @@ def forward_native(
141149
query = query.view(num_tokens, -1, self.head_size)
142150
query_rot = query[..., :self.rotary_dim]
143151
query_pass = query[..., self.rotary_dim:]
144-
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
152+
if self._use_flash_attn:
153+
query_rot = apply_rotary(query_rot.unsqueeze(0), cos, sin,
154+
0).squeeze(0)
155+
else:
156+
query_rot = _apply_rotary_emb(query_rot, cos, sin,
157+
self.is_neox_style)
145158
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
146159

147160
key_shape = key.shape
148161
key = key.view(num_tokens, -1, self.head_size)
149162
key_rot = key[..., :self.rotary_dim]
150163
key_pass = key[..., self.rotary_dim:]
151-
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
164+
if self._use_flash_attn:
165+
key_rot = apply_rotary(key_rot.unsqueeze(0), cos, sin,
166+
0).squeeze(0)
167+
else:
168+
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
152169
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
153170
return query, key
154171

@@ -938,6 +955,10 @@ def __init__(
938955
self.mrope_section = mrope_section
939956
if self.mrope_section:
940957
assert sum(self.mrope_section) == rotary_dim // 2
958+
if is_flash_attn_2_available():
959+
self._use_flash_attn = True
960+
else:
961+
self._use_flash_attn = False
941962

942963
def forward(
943964
self,
@@ -977,14 +998,23 @@ def forward(
977998
query = query.view(num_tokens, -1, self.head_size)
978999
query_rot = query[..., :self.rotary_dim]
9791000
query_pass = query[..., self.rotary_dim:]
980-
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
1001+
if self._use_flash_attn:
1002+
query_rot = apply_rotary(query_rot.unsqueeze(0), cos, sin,
1003+
0).squeeze(0)
1004+
else:
1005+
query_rot = _apply_rotary_emb(query_rot, cos, sin,
1006+
self.is_neox_style)
9811007
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
9821008

9831009
key_shape = key.shape
9841010
key = key.view(num_tokens, -1, self.head_size)
9851011
key_rot = key[..., :self.rotary_dim]
9861012
key_pass = key[..., self.rotary_dim:]
987-
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
1013+
if self._use_flash_attn:
1014+
key_rot = apply_rotary(key_rot.unsqueeze(0), cos, sin,
1015+
0).squeeze(0)
1016+
else:
1017+
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
9881018
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
9891019
return query, key
9901020

0 commit comments

Comments (0)