Commit 5117617

fix: leaked thread from test_engine_core_client_asyncio
Signed-off-by: alec-flowers <[email protected]>
1 parent 7ec5feb commit 5117617
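
The change below applies the standard try/finally cleanup pattern: the body of test_engine_core_client_asyncio is wrapped in a try block so that client.shutdown() always runs, even when an assertion fails, and the client's background thread is torn down instead of being reported as leaked. A minimal, self-contained sketch of that pattern follows; FakeClient and test_sketch are illustrative stand-ins, not vLLM code.

import asyncio
import threading


class FakeClient:
    """Illustrative stand-in for the test's client: it owns a background thread."""

    def __init__(self) -> None:
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._stop.wait, daemon=True)
        self._thread.start()

    def shutdown(self) -> None:
        # Signal the worker and join it so no thread outlives the test.
        self._stop.set()
        self._thread.join()


async def test_sketch() -> None:
    client = FakeClient()
    try:
        # Everything that can raise (asserts included) stays inside the try
        # block, mirroring how the real test body is re-indented under try below.
        await asyncio.sleep(0.01)
        assert True  # the real test asserts on engine outputs here
    finally:
        # Runs on success and failure alike, so the background thread is
        # always cleaned up rather than leaked past the end of the test.
        client.shutdown()


if __name__ == "__main__":
    asyncio.run(test_sketch())
    print("threads still running:", threading.active_count())

The diff also drops the "TODO hack" in test_kv_cache_events that terminated the process-global ZeroMQ context, which presumably is no longer needed once the client shuts down reliably.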

File tree: 1 file changed (+42 / -44 lines)

tests/v1/engine/test_engine_core_client.py

Lines changed: 42 additions & 44 deletions
@@ -8,7 +8,6 @@
 
 import psutil
 import pytest
-import zmq
 from transformers import AutoTokenizer
 
 from vllm import SamplingParams
@@ -201,54 +200,57 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
         log_stats=True,
     )
 
-    MAX_TOKENS = 20
-    params = SamplingParams(max_tokens=MAX_TOKENS)
-    """Normal Request Cycle."""
-
-    requests = [make_request(params) for _ in range(10)]
-    request_ids = [req.request_id for req in requests]
-
-    # Add requests to the engine.
-    for request in requests:
-        await client.add_request_async(request)
-        await asyncio.sleep(0.01)
-
-    outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
-    await loop_until_done_async(client, outputs)
+    try:
+        MAX_TOKENS = 20
+        params = SamplingParams(max_tokens=MAX_TOKENS)
+        """Normal Request Cycle."""
 
-    for req_id in request_ids:
-        assert len(outputs[req_id]) == MAX_TOKENS, (
-            f"{outputs[req_id]=}, {MAX_TOKENS=}")
-    """Abort Request Cycle."""
+        requests = [make_request(params) for _ in range(10)]
+        request_ids = [req.request_id for req in requests]
 
-    # Add requests to the engine.
-    for idx, request in enumerate(requests):
-        await client.add_request_async(request)
-        await asyncio.sleep(0.01)
-        if idx % 2 == 0:
-            await client.abort_requests_async([request.request_id])
+        # Add requests to the engine.
+        for request in requests:
+            await client.add_request_async(request)
+            await asyncio.sleep(0.01)
 
-    outputs = {req_id: [] for req_id in request_ids}
-    await loop_until_done_async(client, outputs)
+        outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
+        await loop_until_done_async(client, outputs)
 
-    for idx, req_id in enumerate(request_ids):
-        if idx % 2 == 0:
-            assert len(outputs[req_id]) < MAX_TOKENS, (
-                f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
-        else:
+        for req_id in request_ids:
             assert len(outputs[req_id]) == MAX_TOKENS, (
-                f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
-    """Utility method invocation"""
+                f"{outputs[req_id]=}, {MAX_TOKENS=}")
+        """Abort Request Cycle."""
+
+        # Add requests to the engine.
+        for idx, request in enumerate(requests):
+            await client.add_request_async(request)
+            await asyncio.sleep(0.01)
+            if idx % 2 == 0:
+                await client.abort_requests_async([request.request_id])
+
+        outputs = {req_id: [] for req_id in request_ids}
+        await loop_until_done_async(client, outputs)
+
+        for idx, req_id in enumerate(request_ids):
+            if idx % 2 == 0:
+                assert len(outputs[req_id]) < MAX_TOKENS, (
+                    f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
+            else:
+                assert len(outputs[req_id]) == MAX_TOKENS, (
+                    f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
+        """Utility method invocation"""
 
-    core_client: AsyncMPClient = client
+        core_client: AsyncMPClient = client
 
-    result = await core_client.call_utility_async("echo", "testarg")
-    assert result == "testarg"
+        result = await core_client.call_utility_async("echo", "testarg")
+        assert result == "testarg"
 
-    with pytest.raises(Exception) as e_info:
-        await core_client.call_utility_async("echo", None, "help!")
+        with pytest.raises(Exception) as e_info:
+            await core_client.call_utility_async("echo", None, "help!")
 
-    assert str(e_info.value) == "Call to echo method failed: help!"
+        assert str(e_info.value) == "Call to echo method failed: help!"
+    finally:
+        client.shutdown()
 
 
 @pytest.mark.parametrize(
@@ -333,10 +335,6 @@ def test_kv_cache_events(
             "Token ids should be the same as the custom tokens")
     finally:
        client.shutdown()
-        subscriber.close()
-        # TODO hack to try and fix CI hang
-        ctx = zmq.Context.instance()
-        ctx.term()
     return