Skip to content

Commit 61a4924

Browse files
alec-flowers, markmc
authored and
jimpang
committed
[V1][Metrics] add support for kv event publishing (vllm-project#16750)
Signed-off-by: alec-flowers <[email protected]> Signed-off-by: Mark McLoughlin <[email protected]> Co-authored-by: Mark McLoughlin <[email protected]>
1 parent 9006a68 commit 61a4924

File tree

15 files changed

+1183
-51
lines changed

15 files changed

+1183
-51
lines changed

examples/online_serving/kv_events.sh

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#!/bin/bash
# This file demonstrates KV cache event publishing.
# We will launch a vLLM instance configured to publish KV cache
# events and launch a simple subscriber to log those events.

# -x traces each command as it runs; -e aborts on the first failing command.
set -xe

echo "🚧🚧 Warning: The usage of KV cache events is experimental and subject to change 🚧🚧"
sleep 1

# The model can be overridden through the HF_MODEL_NAME environment variable.
MODEL_NAME=${HF_MODEL_NAME:-meta-llama/Meta-Llama-3.1-8B-Instruct}

# Trap the SIGINT signal (triggered by Ctrl+C). The handler name is resolved
# when the signal fires, so it may be defined further down in the script.
trap 'cleanup' INT
15+
16+
# Cleanup function: SIGINT handler that force-kills the Python processes
# (the vLLM server and the subscriber) and exits with status 0.
cleanup() {
    echo "Caught Ctrl+C, cleaning up..."
    # Only signal processes owned by the invoking user instead of every
    # python process on the machine. pkill exits 1 when nothing matches;
    # `|| true` keeps `set -e` from aborting the handler before it finishes.
    # NOTE: this still matches *all* of the user's processes whose command
    # line contains "python", not just the ones started by this script.
    pkill -9 -u "${USER:-$(id -un)}" -f python || true
    echo "Cleanup complete. Exiting."
    exit 0
}
25+
26+
# Use the first address reported by `hostname -I` as the vLLM host IP.
# Assign first and export afterwards so that `export` does not replace the
# exit status of the command substitution (ShellCheck SC2155).
VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
export VLLM_HOST_IP
27+
28+
# Wait for the vLLM server to start.
# Arguments:
#   $1 - port the server listens on (on localhost)
# Returns:
#   0 as soon as `curl` exits successfully against /v1/completions,
#   1 if that has not happened within 1200 seconds.
wait_for_server() {
    local port=$1
    # ${port} is expanded by this shell before the inner bash starts; the
    # inner loop retries once per second until curl succeeds or `timeout`
    # terminates it.
    timeout 1200 bash -c "
        until curl -s localhost:${port}/v1/completions > /dev/null; do
            sleep 1
        done" && return 0 || return 1
}
36+
37+
# Launch the vLLM server in the background with KV cache events published
# over ZMQ on the "kv-events" topic.
vllm serve "$MODEL_NAME" \
    --port 8100 \
    --max-model-len 100 \
    --enforce-eager \
    --gpu-memory-utilization 0.8 \
    --trust-remote-code \
    --kv-events-config \
    '{"enable_kv_cache_events": true, "publisher": "zmq", "topic": "kv-events"}' &

wait_for_server 8100

# Directory containing this script, so the subscriber is found regardless of
# the caller's working directory.
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

# Start the event subscriber in the background and give it a moment to connect.
python3 "$SCRIPT_DIR/kv_events_subscriber.py" &
sleep 1

# serve two example requests
output1=$(curl -X POST -s http://localhost:8100/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
"model": "'"$MODEL_NAME"'",
"prompt": "Explain quantum computing in simple terms a 5-year-old could understand.",
"max_tokens": 80,
"temperature": 0
}')

output2=$(curl -X POST -s http://localhost:8100/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
"model": "'"$MODEL_NAME"'",
"prompt": "Explain quantum computing in simple terms a 50-year-old could understand.",
"max_tokens": 80,
"temperature": 0
}')

# Cleanup commands
# pkill exits 1 when no process matches (e.g. the first call already killed
# everything the second would match); `|| true` keeps `set -e` from aborting
# the script before the request outputs are printed.
pkill -9 -u "${USER:-$(id -un)}" -f python || true
pkill -9 -u "${USER:-$(id -un)}" -f vllm || true

sleep 1

echo "Cleaned up"

# Print the outputs of the curl requests
echo ""
echo "Output of first request: $output1"
echo "Output of second request: $output2"

echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""
examples/online_serving/kv_events_subscriber.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from typing import Any, Optional, Union
3+
4+
import msgspec
5+
import zmq
6+
from msgspec.msgpack import Decoder
7+
8+
9+
#
10+
# Types copied from vllm.distributed.kv_events
11+
#
12+
class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True,
                 gc=False):
    """A batch of events together with the batch's ``ts`` value.

    Declared as an array-like msgspec struct, so it is encoded positionally
    (``[ts, events]``) rather than as a mapping.
    """
    # Value printed as the batch's receive time by ``process_event``
    # (presumably a Unix timestamp -- confirm against the publisher).
    ts: float
    # Untyped here; subclasses narrow the element type.
    events: list[Any]
16+
17+
18+
class KVCacheEvent(msgspec.Struct,
                   array_like=True,
                   omit_defaults=True,
                   gc=False,
                   tag=True):
    """Base class for all KV cache-related events.

    ``tag=True`` makes msgspec embed a type tag in each encoded event so that
    a union of the subclasses can be decoded back to the right class.
    """
24+
25+
26+
class BlockStored(KVCacheEvent):
    """Event carrying the details of newly stored KV cache blocks."""
    # Hashes identifying the stored blocks.
    block_hashes: list[int]
    # Hash of the parent block, or None (presumably when the first block has
    # no parent -- confirm against the publisher).
    parent_block_hash: Optional[int]
    # Token ids associated with the stored blocks.
    token_ids: list[int]
    block_size: int
    # LoRA adapter id, or None.
    lora_id: Optional[int]
32+
33+
34+
class BlockRemoved(KVCacheEvent):
    """Event listing the hashes of removed KV cache blocks."""
    block_hashes: list[int]
36+
37+
38+
class AllBlocksCleared(KVCacheEvent):
    """Payload-free event; only its type tag carries information."""
40+
41+
42+
class KVEventBatch(EventBatch):
    """EventBatch whose events are restricted to the KV cache event types."""
    # Narrowing the element type lets the msgspec decoder build concrete
    # BlockStored / BlockRemoved / AllBlocksCleared instances.
    events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
44+
45+
46+
def process_event(event_batch):
    """Print a received event batch to stdout.

    Args:
        event_batch: Object exposing ``ts`` and an iterable ``events``
            attribute (e.g. a decoded ``KVEventBatch``).
    """
    print(f"Received event batch at {event_batch.ts}:")
    for event in event_batch.events:
        print(f" - {event}")
50+
51+
52+
def main(sub_endpoint="tcp://localhost:5557",
         replay_endpoint="tcp://localhost:5558",
         topic="kv-events"):
    """Subscribe to KV cache events and print them until interrupted.

    Live messages arrive on a SUB socket as ``(topic, seq, payload)`` frames.
    When the sequence number of a live message shows that earlier messages
    were missed, a replay starting at the first missing sequence number is
    requested over a REQ socket and the replayed batches are processed before
    the live one.

    Args:
        sub_endpoint: ZMQ endpoint of the event publisher.
        replay_endpoint: ZMQ endpoint serving replay requests.
        topic: Topic prefix to subscribe to.
    """
    decoder = Decoder(type=KVEventBatch)
    # Sequence number of the most recently processed batch; -1 = none yet.
    last_seq = -1

    context = zmq.Context()

    # Set up the main subscription socket
    sub = context.socket(zmq.SUB)
    sub.connect(sub_endpoint)
    sub.setsockopt_string(zmq.SUBSCRIBE, topic)

    # Initialize replay socket
    replay = context.socket(zmq.REQ)
    replay.connect(replay_endpoint)
    poller = zmq.Poller()
    poller.register(replay, zmq.POLLIN)

    print("Listening for KV cache events on topic:", topic)

    try:
        while True:
            try:
                if sub.poll(50):
                    _, seq_bytes, payload = sub.recv_multipart()
                    seq = int.from_bytes(seq_bytes, "big")

                    if last_seq >= 0 and seq > last_seq + 1:
                        missed = seq - last_seq - 1
                        print(f"Missed {missed} messages"
                              f" (last: {last_seq}, current: {seq})")

                        # NOTE(review): a REQ socket alternates strictly
                        # between send and recv; confirm that the publisher's
                        # multi-message replay and the early `break` below
                        # leave the socket usable for a later request.
                        replay.send((last_seq + 1).to_bytes(8, "big"))

                        while poller.poll(timeout=200):
                            replay_seq_bytes, replay_payload = (
                                replay.recv_multipart())
                            if not replay_payload:
                                # End of replay marker is sent as an empty
                                # frame for the payload
                                break

                            replay_seq = int.from_bytes(
                                replay_seq_bytes, "big")

                            if replay_seq > last_seq:
                                event_batch = decoder.decode(replay_payload)
                                process_event(event_batch)
                                last_seq = replay_seq
                                if replay_seq >= seq - 1:
                                    # Caught up to just before the live
                                    # message.
                                    break

                    event_batch = decoder.decode(payload)
                    process_event(event_batch)
                    # Remember what was processed; without this the gap
                    # check above could never fire.
                    last_seq = seq

                # ... do other periodic work or check for shutdown ...

            except KeyboardInterrupt:
                print("Interrupted")
                break
            except Exception as e:
                print("Error decoding message:", e)
    finally:
        # linger=0 drops unsent messages so term() cannot block on exit.
        sub.close(linger=0)
        replay.close(linger=0)
        context.term()
111+
112+
113+
# Run the subscriber only when executed as a script, not on import.
if __name__ == "__main__":
    main()

tests/distributed/conftest.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import random
3+
from typing import Optional, Union
4+
5+
import msgspec
6+
import msgspec.msgpack
7+
import pytest
8+
import zmq
9+
10+
from vllm.config import KVEventsConfig
11+
from vllm.distributed.kv_events import EventPublisherFactory
12+
13+
from .test_events import SampleBatch
14+
15+
16+
@pytest.fixture
def random_port():
    """Generate a random port number for testing.

    The port is drawn uniformly from 10000-60000 (inclusive); nothing checks
    that it is actually free.
    """
    return random.randint(10000, 60000)
20+
21+
22+
@pytest.fixture
def publisher_config(random_port, request):
    """Create a ZMQ publisher config for the requested transport.

    The transport is taken from ``request.param`` when the fixture is
    parametrized and defaults to ``"inproc"``. Any other value selects TCP,
    with the replay endpoint on the port following ``random_port``.
    """
    how = getattr(request, "param", "inproc")

    if how == "inproc":
        endpoint = f"inproc://test-{random_port}"
        replay_endpoint = endpoint + "-replay"
    else:
        endpoint = f"tcp://*:{random_port}"
        replay_endpoint = f"tcp://*:{random_port + 1}"

    return KVEventsConfig(enable_kv_cache_events=True,
                          publisher="zmq",
                          endpoint=endpoint,
                          replay_endpoint=replay_endpoint,
                          buffer_steps=100,
                          hwm=1000,
                          topic="test")
41+
42+
43+
@pytest.fixture
def publisher(publisher_config):
    """Create a publisher from ``publisher_config`` and shut it down after
    the test."""
    pub = EventPublisherFactory.create(publisher_config)
    yield pub
    # Teardown: runs once the test using the fixture has finished.
    pub.shutdown()
49+
50+
51+
@pytest.fixture
def subscriber(publisher_config):
    """Create a MockSubscriber connected to the configured publisher and
    close it after the test."""
    endpoint = publisher_config.endpoint
    replay_endpoint = publisher_config.replay_endpoint

    # The publisher binds to the wildcard address; a client has to connect to
    # a concrete one, so substitute loopback.
    if endpoint.startswith("tcp://*"):
        endpoint = endpoint.replace("*", "127.0.0.1")
    if replay_endpoint and replay_endpoint.startswith("tcp://*"):
        replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")

    sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic)
    yield sub
    sub.close()
65+
66+
67+
class MockSubscriber:
    """Helper class to receive and verify published events"""

    def __init__(self,
                 pub_endpoint: str,
                 replay_endpoint: Optional[str] = None,
                 topic: str = "",
                 decode_type=SampleBatch):
        """Connect the subscriber sockets.

        Args:
            pub_endpoint: Endpoint of the publisher's event stream.
            replay_endpoint: Endpoint for replay requests; when falsy, no
                replay socket is created.
            topic: Topic prefix to subscribe to ("" subscribes to everything).
            decode_type: msgspec type the payloads are decoded into.
        """
        self.ctx = zmq.Context.instance()

        self.topic = topic
        self.topic_bytes = topic.encode('utf-8')

        # Set up subscriber socket
        self.sub = self.ctx.socket(zmq.SUB)
        self.sub.setsockopt(zmq.SUBSCRIBE, self.topic_bytes)
        self.sub.connect(pub_endpoint)

        # Set up replay socket if provided
        self.replay = None
        if replay_endpoint:
            self.replay = self.ctx.socket(zmq.REQ)
            self.replay.connect(replay_endpoint)

        # Every (seq, batch) pair returned by receive_one, in arrival order.
        self.received_msgs: list[tuple[int, SampleBatch]] = []
        # Sequence number of the last live message; -1 until one arrives.
        self.last_seq = -1
        self.decoder = msgspec.msgpack.Decoder(type=decode_type)

    def receive_one(self,
                    timeout=1000) -> Optional[tuple[int, SampleBatch]]:
        """Receive a single message with timeout.

        Args:
            timeout: Value passed to the SUB socket's ``poll`` (milliseconds
                in pyzmq).

        Returns:
            ``(seq, decoded_batch)``, or None if nothing arrived in time.
        """
        if not self.sub.poll(timeout):
            return None

        topic_bytes, seq_bytes, payload = self.sub.recv_multipart()
        assert topic_bytes == self.topic_bytes

        seq = int.from_bytes(seq_bytes, "big")
        data = self.decoder.decode(payload)
        self.last_seq = seq
        self.received_msgs.append((seq, data))
        return seq, data

    def request_replay(self, start_seq: int) -> None:
        """Request replay of messages starting from start_seq.

        Raises:
            ValueError: If no replay socket was configured.
        """
        if not self.replay:
            raise ValueError("Replay socket not initialized")

        # The request is the start sequence as an 8-byte big-endian integer.
        self.replay.send(start_seq.to_bytes(8, "big"))

    def receive_replay(self) -> list[tuple[int, SampleBatch]]:
        """Receive replayed messages.

        Collects messages until the end-of-replay marker (an empty last
        frame) arrives, a 1000 ms poll finds nothing, or a ZMQ error occurs.

        Raises:
            ValueError: If no replay socket was configured.
        """
        if not self.replay:
            raise ValueError("Replay socket not initialized")

        replayed: list[tuple[int, SampleBatch]] = []
        while True:
            try:
                if not self.replay.poll(1000):
                    break

                frames = self.replay.recv_multipart()
                if not frames or not frames[-1]:
                    # End of replay marker
                    break

                seq_bytes, payload = frames
                seq = int.from_bytes(seq_bytes, "big")
                data = self.decoder.decode(payload)
                replayed.append((seq, data))
            except zmq.ZMQError:
                # Best effort: return whatever was collected so far.
                break

        return replayed

    def close(self):
        """Clean up resources"""
        # linger=0 discards unsent messages (e.g. an unanswered replay
        # request) so the sockets cannot delay shutdown of the shared
        # context.
        self.sub.close(linger=0)
        if self.replay:
            self.replay.close(linger=0)

0 commit comments

Comments
 (0)