Skip to content

Commit 5d20cdd

Browse files
committed
Apply torch.compile & cudagraph to EAGLE3
1 parent 81ecf42 commit 5d20cdd

File tree

2 files changed

+31
-23
lines changed

2 files changed

+31
-23
lines changed

vllm/model_executor/models/llama_eagle3.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch.nn as nn
77
from transformers import LlamaConfig
88

9+
from vllm.compilation.decorators import support_torch_compile
910
from vllm.config import ModelConfig, VllmConfig
1011
from vllm.logger import init_logger
1112
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -75,18 +76,19 @@ def forward(
7576

7677
return hidden_states, residual
7778

78-
79+
@support_torch_compile
7980
class LlamaModel(nn.Module):
8081

8182
def __init__(
8283
self,
8384
*,
84-
model_config: ModelConfig,
85+
vllm_config: VllmConfig,
8586
start_layer_id: int = 0,
8687
prefix: str = "",
8788
) -> None:
8889
super().__init__()
89-
self.config = model_config.hf_config
90+
self.config = vllm_config. \
91+
speculative_config.draft_model_config.hf_config
9092
self.vocab_size = self.config.vocab_size
9193
self.embed_tokens = VocabParallelEmbedding(
9294
self.config.vocab_size,
@@ -119,8 +121,7 @@ def forward(
119121
hidden_states: torch.Tensor,
120122
) -> tuple[torch.Tensor, torch.Tensor]:
121123
input_embeds = self.embed_tokens(input_ids)
122-
if (hidden_states.shape[-1] != input_embeds.shape[-1]):
123-
hidden_states = self.fc(hidden_states)
124+
assert hidden_states.shape[-1] == input_embeds.shape[-1]
124125

125126
residual = None
126127
hidden_states, residual = self.layers[0](
@@ -169,9 +170,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
169170

170171
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
171172
nn.Module.__init__(self)
172-
model_config = vllm_config.speculative_config.draft_model_config
173-
self.config = model_config.hf_config
174-
self.model = LlamaModel(model_config=model_config,
173+
self.config = vllm_config. \
174+
speculative_config.draft_model_config.hf_config
175+
self.model = LlamaModel(vllm_config=vllm_config,
175176
start_layer_id=start_layer_id,
176177
prefix="model")
177178

@@ -214,6 +215,13 @@ def compute_logits(
214215
logits_new[:, targets] = logits
215216
return logits_new
216217

218+
def combine_hidden_states(
219+
self,
220+
hidden_states: torch.Tensor,
221+
) -> torch.Tensor:
222+
# combine multiple auxiliary hidden states returned by eagle3
223+
return self.model.fc(hidden_states)
224+
217225
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
218226
loader = AutoWeightsLoader(
219227
self,

vllm/v1/spec_decode/eagle.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from vllm.logger import init_logger
1010
from vllm.model_executor.model_loader.loader import get_model_loader
1111
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
12+
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
1213
from vllm.model_executor.models import ModelRegistry
1314
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
1415
from vllm.v1.sample.metadata import SamplingMetadata
@@ -39,11 +40,10 @@ def __init__(
3940

4041
self.hidden_size = vllm_config.model_config.get_hidden_size()
4142

42-
# TODO: make eagle3 compatible with cudagraph
43-
self.use_cuda_graph = self.method != 'eagle3' and \
44-
(self.vllm_config.compilation_config.level
45-
== CompilationLevel.PIECEWISE and
46-
not self.vllm_config.model_config.enforce_eager)
43+
self.use_cuda_graph = (
44+
self.vllm_config.compilation_config.level
45+
== CompilationLevel.PIECEWISE and
46+
not self.vllm_config.model_config.enforce_eager)
4747

4848
self.cudagraph_batch_sizes = list(
4949
reversed(
@@ -90,6 +90,12 @@ def propose(
9090
batch_size = next_token_ids.shape[0]
9191
last_token_indices = cu_num_tokens[1:] - 1
9292

93+
if self.method == "eagle3":
94+
assert isinstance(self.model, Eagle3LlamaForCausalLM)
95+
target_hidden_states = self.model.combine_hidden_states(
96+
target_hidden_states)
97+
assert target_hidden_states.shape[-1] == self.hidden_size
98+
9399
# Shift the input ids by one token.
94100
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
95101
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
@@ -126,12 +132,8 @@ def propose(
126132
# copy inputs to buffer for cudagraph
127133
self.positions[:num_tokens] = target_positions
128134

129-
if self.method == 'eagle':
130-
self.hidden_states[:num_tokens] = target_hidden_states
131-
hidden_states = self.hidden_states
132-
else:
133-
# TODO: make eagle3 compatible with cuda graph
134-
hidden_states = target_hidden_states
135+
self.hidden_states[:num_tokens] = target_hidden_states
136+
hidden_states = self.hidden_states
135137

136138
with set_forward_context(attn_metadata,
137139
self.vllm_config,
@@ -209,10 +211,8 @@ def propose(
209211
self.input_ids[:batch_size] = input_ids
210212
self.positions[:batch_size] = clamped_positions
211213

212-
if self.method == 'eagle':
213-
# TODO: make eagle3 compatible with cudagraph.
214-
self.hidden_states[:batch_size] = hidden_states
215-
hidden_states = self.hidden_states
214+
self.hidden_states[:batch_size] = hidden_states
215+
hidden_states = self.hidden_states
216216

217217
# Run the model.
218218
with set_forward_context(attn_metadata,

0 commit comments

Comments
 (0)