Skip to content

[Minor][Models] Pass partial_rotary_factor parameter to rope #17266

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def __init__(self,
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
# Phi models introduced a partial_rotary_factor parameter in the config
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
1)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
Expand Down Expand Up @@ -163,11 +163,12 @@ def __init__(self,

self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor,
)

if hasattr(config, "interleaved_sliding_window"):
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/models/persimmon.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ def __init__(self,

self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=int(self.partial_rotary_factor * self.head_dim),
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
partial_rotary_factor=self.partial_rotary_factor,
)
self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
Expand Down
8 changes: 4 additions & 4 deletions vllm/model_executor/models/stablelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,8 @@ def __init__(self,
1, self.total_num_key_value_heads // tp_size)
self.head_dim = self.hidden_size // self.total_num_heads
self.max_position_embeddings = config.max_position_embeddings
rope_pct = getattr(config, "rope_pct",
getattr(config, "partial_rotary_factor", 1))
self.rotary_ndims = int(self.head_dim * rope_pct)
self.partial_rotary_factor = getattr(
config, "rope_pct", getattr(config, "partial_rotary_factor", 1))
self.scaling = self.head_dim**-0.5
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_key_value_heads * self.head_dim
Expand All @@ -130,9 +129,10 @@ def __init__(self,
prefix=f"{prefix}.o_proj")
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_ndims,
rotary_dim=self.head_dim,
max_position=self.config.max_position_embeddings,
base=self.config.rope_theta,
partial_rotary_factor=self.partial_rotary_factor,
)
self.attn = Attention(self.num_heads,
self.head_dim,
Expand Down