Commit c435afc

[Perf] Optimize rotary_emb implementation to use a Triton operator for improved inference performance
Signed-off-by: cynthieye <[email protected]>
Co-authored-by: MagnetoWang <[email protected]>
1 parent 99ef59c commit c435afc
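
The change keeps the pure-PyTorch rotary implementation (renamed to _apply_rotary_emb_torch) as the fallback and, when FlashAttention 2 is available, routes _apply_rotary_emb through flash_attn.layers.rotary.apply_rotary_emb, whose Triton kernel performs the rotation in one kernel instead of a chain of eager elementwise ops. Below is a minimal parity check of the two paths; it is not part of the commit and assumes a CUDA device with flash-attn installed (shapes are arbitrary):

import torch
from transformers.utils import is_flash_attn_2_available

from vllm.model_executor.layers.rotary_embedding import _apply_rotary_emb_torch

if is_flash_attn_2_available():
    from flash_attn.layers.rotary import apply_rotary_emb

    num_tokens, num_heads, head_size = 16, 8, 128
    x = torch.randn(num_tokens, num_heads, head_size,
                    dtype=torch.float16, device="cuda")
    cos = torch.randn(num_tokens, head_size // 2,
                      dtype=torch.float16, device="cuda")
    sin = torch.randn(num_tokens, head_size // 2,
                      dtype=torch.float16, device="cuda")

    # Eager PyTorch reference (NeoX-style rotation over the full head).
    ref = _apply_rotary_emb_torch(x, cos, sin, is_neox_style=True)
    # flash-attn expects a leading batch dimension and an `interleaved`
    # flag (GPT-J style), so NeoX style maps to interleaved=False.
    out = apply_rotary_emb(x.unsqueeze(0), cos, sin, False).squeeze(0)
    torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)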

File tree

1 file changed: vllm/model_executor/layers/rotary_embedding.py (+47, -15)
@@ -28,10 +28,14 @@
 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
+from transformers.utils import is_flash_attn_2_available
 
 from vllm.model_executor.custom_op import CustomOp
 from vllm.platforms import current_platform
 
+if is_flash_attn_2_available():
+    from flash_attn.layers.rotary import apply_rotary_emb
+
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., :x.shape[-1] // 2]
@@ -46,20 +50,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)
 
 
-def _apply_rotary_emb(
+def _apply_rotary_emb_torch(
     x: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
     is_neox_style: bool,
 ) -> torch.Tensor:
-    """
-    Args:
-        x: [num_tokens, num_heads, head_size]
-        cos: [num_tokens, head_size // 2]
-        sin: [num_tokens, head_size // 2]
-        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
-            positional embeddings.
-    """
     cos = cos.unsqueeze(-2).to(x.dtype)
     sin = sin.unsqueeze(-2).to(x.dtype)
     if is_neox_style:
@@ -75,6 +71,28 @@ def _apply_rotary_emb(
         return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
+def _apply_rotary_emb(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    is_neox_style: bool,
+    use_flash_attn=False
+) -> torch.Tensor:
+    """
+    Args:
+        x: [num_tokens, num_heads, head_size]
+        cos: [num_tokens, head_size // 2]
+        sin: [num_tokens, head_size // 2]
+        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
+            positional embeddings.
+        use_flash_attn: Whether to enable Flash Attention optimizations.
+    """
+    if use_flash_attn:
+        return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0)
+    else:
+        return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
+
+
 @CustomOp.register("rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
@@ -100,6 +118,10 @@ def __init__(
         cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
+        if is_flash_attn_2_available():
+            self._use_flash_attn = True
+        else:
+            self._use_flash_attn = False
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
@@ -141,14 +163,16 @@ def forward_native(
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style,
+                                      self._use_flash_attn)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style,
+                                    self._use_flash_attn)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
@@ -309,9 +333,11 @@ def _apply_rotary_emb_neuron(
         key = key.view(num_tokens, -1, self.head_size)
 
         if self.rotary_dim == self.head_size:
-            query = _apply_rotary_emb(query, cos, sin, self.is_neox_style)
+            query = _apply_rotary_emb(query, cos, sin, self.is_neox_style,
+                                      self._use_flash_attn)
             query = query.reshape(query_shape)
-            key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
+            key = _apply_rotary_emb(key, cos, sin, self.is_neox_style,
+                                    self._use_flash_attn)
             key = key.reshape(key_shape)
         else:
             head_size = query.shape[-1]
@@ -938,6 +964,10 @@ def __init__(
         self.mrope_section = mrope_section
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
+        if is_flash_attn_2_available():
+            self._use_flash_attn = True
+        else:
+            self._use_flash_attn = False
 
     def forward(
         self,
@@ -977,14 +1007,16 @@ def forward(
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style,
+                                      self._use_flash_attn)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style,
+                                    self._use_flash_attn)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
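For the performance claim in the commit title, the two paths of the new _apply_rotary_emb can be timed directly. A rough micro-benchmark sketch, not part of the commit, assuming flash-attn is installed and a CUDA device is present (shape values are illustrative):

import time

import torch

from vllm.model_executor.layers.rotary_embedding import _apply_rotary_emb


def bench(use_flash_attn: bool, iters: int = 100) -> float:
    num_tokens, num_heads, head_size = 4096, 32, 128
    x = torch.randn(num_tokens, num_heads, head_size,
                    dtype=torch.float16, device="cuda")
    cos = torch.randn(num_tokens, head_size // 2,
                      dtype=torch.float16, device="cuda")
    sin = torch.randn(num_tokens, head_size // 2,
                      dtype=torch.float16, device="cuda")
    # Warm-up (also triggers Triton compilation on the flash-attn path).
    for _ in range(10):
        _apply_rotary_emb(x, cos, sin, True, use_flash_attn)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        _apply_rotary_emb(x, cos, sin, True, use_flash_attn)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


print(f"torch path : {bench(False) * 1e3:.3f} ms/iter")
print(f"triton path: {bench(True) * 1e3:.3f} ms/iter")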