|
9 | 9 | from vllm.logger import init_logger
|
10 | 10 | from vllm.model_executor.model_loader.loader import get_model_loader
|
11 | 11 | from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
| 12 | +from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM |
12 | 13 | from vllm.model_executor.models import ModelRegistry
|
13 | 14 | from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
14 | 15 | from vllm.v1.sample.metadata import SamplingMetadata
|
@@ -39,11 +40,10 @@ def __init__(
|
39 | 40 |
|
40 | 41 | self.hidden_size = vllm_config.model_config.get_hidden_size()
|
41 | 42 |
|
42 |
| - # TODO: make eagle3 compatible with cudagraph |
43 |
| - self.use_cuda_graph = self.method != 'eagle3' and \ |
44 |
| - (self.vllm_config.compilation_config.level |
45 |
| - == CompilationLevel.PIECEWISE and |
46 |
| - not self.vllm_config.model_config.enforce_eager) |
| 43 | + self.use_cuda_graph = ( |
| 44 | + self.vllm_config.compilation_config.level |
| 45 | + == CompilationLevel.PIECEWISE and |
| 46 | + not self.vllm_config.model_config.enforce_eager) |
47 | 47 |
|
48 | 48 | self.cudagraph_batch_sizes = list(
|
49 | 49 | reversed(
|
@@ -90,6 +90,12 @@ def propose(
|
90 | 90 | batch_size = next_token_ids.shape[0]
|
91 | 91 | last_token_indices = cu_num_tokens[1:] - 1
|
92 | 92 |
|
| 93 | + if self.method == "eagle3": |
| 94 | + assert isinstance(self.model, Eagle3LlamaForCausalLM) |
| 95 | + target_hidden_states = self.model.combine_hidden_states( |
| 96 | + target_hidden_states) |
| 97 | + assert target_hidden_states.shape[-1] == self.hidden_size |
| 98 | + |
93 | 99 | # Shift the input ids by one token.
|
94 | 100 | # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
95 | 101 | self.input_ids[:num_tokens - 1] = target_token_ids[1:]
|
@@ -126,12 +132,8 @@ def propose(
|
126 | 132 | # copy inputs to buffer for cudagraph
|
127 | 133 | self.positions[:num_tokens] = target_positions
|
128 | 134 |
|
129 |
| - if self.method == 'eagle': |
130 |
| - self.hidden_states[:num_tokens] = target_hidden_states |
131 |
| - hidden_states = self.hidden_states |
132 |
| - else: |
133 |
| - # TODO: make eagle3 compatible with cuda graph |
134 |
| - hidden_states = target_hidden_states |
| 135 | + self.hidden_states[:num_tokens] = target_hidden_states |
| 136 | + hidden_states = self.hidden_states |
135 | 137 |
|
136 | 138 | with set_forward_context(attn_metadata,
|
137 | 139 | self.vllm_config,
|
@@ -209,10 +211,8 @@ def propose(
|
209 | 211 | self.input_ids[:batch_size] = input_ids
|
210 | 212 | self.positions[:batch_size] = clamped_positions
|
211 | 213 |
|
212 |
| - if self.method == 'eagle': |
213 |
| - # TODO: make eagle3 compatible with cudagraph. |
214 |
| - self.hidden_states[:batch_size] = hidden_states |
215 |
| - hidden_states = self.hidden_states |
| 214 | + self.hidden_states[:batch_size] = hidden_states |
| 215 | + hidden_states = self.hidden_states |
216 | 216 |
|
217 | 217 | # Run the model.
|
218 | 218 | with set_forward_context(attn_metadata,
|
|
0 commit comments