
Commit 3e6f13a

[Perf] Optimize rotary_emb implementation to use Triton operator for improved inference performance
Signed-off-by: cynthieye <[email protected]>
Co-authored-by: MagnetoWang <[email protected]>
1 parent 99ef59c commit 3e6f13a

1 file changed: +23 -11 lines changed

vllm/model_executor/layers/rotary_embedding.py (+23 -11)
@@ -46,20 +46,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)
 
 
-def _apply_rotary_emb(
+def _apply_rotary_emb_torch(
     x: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
     is_neox_style: bool,
 ) -> torch.Tensor:
-    """
-    Args:
-        x: [num_tokens, num_heads, head_size]
-        cos: [num_tokens, head_size // 2]
-        sin: [num_tokens, head_size // 2]
-        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
-            positional embeddings.
-    """
     cos = cos.unsqueeze(-2).to(x.dtype)
     sin = sin.unsqueeze(-2).to(x.dtype)
     if is_neox_style:
@@ -75,6 +67,24 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
+def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
+                      is_neox_style: bool) -> torch.Tensor:
+    """
+    Args:
+        x: [num_tokens, num_heads, head_size]
+        cos: [num_tokens, head_size // 2]
+        sin: [num_tokens, head_size // 2]
+        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
+            positional embeddings.
+    """
+    if current_platform.is_cuda_alike():
+        from vllm_flash_attn.layers.rotary import apply_rotary_emb
+        return apply_rotary_emb(x.unsqueeze(0), cos, sin,
+                                not is_neox_style).squeeze(0)
+    else:
+        return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
+
+
 @CustomOp.register("rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
@@ -141,14 +151,16 @@ def forward_native(
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query_rot = _apply_rotary_emb_torch(query_rot, cos, sin,
+                                            self.is_neox_style)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
+                                          self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
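For context, a minimal sketch of how the rewired helper behaves after this commit, assuming a vLLM checkout where `_apply_rotary_emb` is importable from vllm/model_executor/layers/rotary_embedding.py; the shape values and the frequency setup below are illustrative, not part of the commit. On a CUDA-alike platform (with tensors on the GPU) the call dispatches to the Triton-backed `vllm_flash_attn.layers.rotary.apply_rotary_emb`; everywhere else it falls back to the pure-PyTorch `_apply_rotary_emb_torch`:

    import torch
    from vllm.model_executor.layers.rotary_embedding import _apply_rotary_emb

    # Illustrative shapes, matching the docstring in the diff:
    # x: [num_tokens, num_heads, head_size], cos/sin: [num_tokens, head_size // 2]
    num_tokens, num_heads, head_size = 4, 8, 64
    x = torch.randn(num_tokens, num_heads, head_size)

    # Standard RoPE frequency table (base 10000); any cos/sin of the right
    # shape would do for this shape check.
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_size, 2).float() / head_size))
    freqs = torch.outer(torch.arange(num_tokens).float(), inv_freq)
    cos, sin = freqs.cos(), freqs.sin()

    # Triton kernel on CUDA-alike platforms, _apply_rotary_emb_torch elsewhere;
    # either path preserves the input shape.
    out = _apply_rotary_emb(x, cos, sin, is_neox_style=True)
    assert out.shape == x.shape

Note the inversion in the diff: the flash-attn rotary entry point takes an `interleaved` flag rather than `is_neox_style`, so `not is_neox_style` is passed because GPT-J-style rotary interleaves even/odd channels while Neox-style rotates the two halves.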
