|
8 | 8 |
|
9 | 9 | import psutil
|
10 | 10 | import pytest
|
11 |
| -import zmq |
12 | 11 | from transformers import AutoTokenizer
|
13 | 12 |
|
14 | 13 | from vllm import SamplingParams
|
@@ -201,54 +200,57 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
|
201 | 200 | log_stats=True,
|
202 | 201 | )
|
203 | 202 |
|
204 |
| - MAX_TOKENS = 20 |
205 |
| - params = SamplingParams(max_tokens=MAX_TOKENS) |
206 |
| - """Normal Request Cycle.""" |
207 |
| - |
208 |
| - requests = [make_request(params) for _ in range(10)] |
209 |
| - request_ids = [req.request_id for req in requests] |
210 |
| - |
211 |
| - # Add requests to the engine. |
212 |
| - for request in requests: |
213 |
| - await client.add_request_async(request) |
214 |
| - await asyncio.sleep(0.01) |
215 |
| - |
216 |
| - outputs: dict[str, list] = {req_id: [] for req_id in request_ids} |
217 |
| - await loop_until_done_async(client, outputs) |
| 203 | + try: |
| 204 | + MAX_TOKENS = 20 |
| 205 | + params = SamplingParams(max_tokens=MAX_TOKENS) |
| 206 | + """Normal Request Cycle.""" |
218 | 207 |
|
219 |
| - for req_id in request_ids: |
220 |
| - assert len(outputs[req_id]) == MAX_TOKENS, ( |
221 |
| - f"{outputs[req_id]=}, {MAX_TOKENS=}") |
222 |
| - """Abort Request Cycle.""" |
| 208 | + requests = [make_request(params) for _ in range(10)] |
| 209 | + request_ids = [req.request_id for req in requests] |
223 | 210 |
|
224 |
| - # Add requests to the engine. |
225 |
| - for idx, request in enumerate(requests): |
226 |
| - await client.add_request_async(request) |
227 |
| - await asyncio.sleep(0.01) |
228 |
| - if idx % 2 == 0: |
229 |
| - await client.abort_requests_async([request.request_id]) |
| 211 | + # Add requests to the engine. |
| 212 | + for request in requests: |
| 213 | + await client.add_request_async(request) |
| 214 | + await asyncio.sleep(0.01) |
230 | 215 |
|
231 |
| - outputs = {req_id: [] for req_id in request_ids} |
232 |
| - await loop_until_done_async(client, outputs) |
| 216 | + outputs: dict[str, list] = {req_id: [] for req_id in request_ids} |
| 217 | + await loop_until_done_async(client, outputs) |
233 | 218 |
|
234 |
| - for idx, req_id in enumerate(request_ids): |
235 |
| - if idx % 2 == 0: |
236 |
| - assert len(outputs[req_id]) < MAX_TOKENS, ( |
237 |
| - f"{len(outputs[req_id])=}, {MAX_TOKENS=}") |
238 |
| - else: |
| 219 | + for req_id in request_ids: |
239 | 220 | assert len(outputs[req_id]) == MAX_TOKENS, (
|
240 |
| - f"{len(outputs[req_id])=}, {MAX_TOKENS=}") |
241 |
| - """Utility method invocation""" |
| 221 | + f"{outputs[req_id]=}, {MAX_TOKENS=}") |
| 222 | + """Abort Request Cycle.""" |
| 223 | + |
| 224 | + # Add requests to the engine. |
| 225 | + for idx, request in enumerate(requests): |
| 226 | + await client.add_request_async(request) |
| 227 | + await asyncio.sleep(0.01) |
| 228 | + if idx % 2 == 0: |
| 229 | + await client.abort_requests_async([request.request_id]) |
| 230 | + |
| 231 | + outputs = {req_id: [] for req_id in request_ids} |
| 232 | + await loop_until_done_async(client, outputs) |
| 233 | + |
| 234 | + for idx, req_id in enumerate(request_ids): |
| 235 | + if idx % 2 == 0: |
| 236 | + assert len(outputs[req_id]) < MAX_TOKENS, ( |
| 237 | + f"{len(outputs[req_id])=}, {MAX_TOKENS=}") |
| 238 | + else: |
| 239 | + assert len(outputs[req_id]) == MAX_TOKENS, ( |
| 240 | + f"{len(outputs[req_id])=}, {MAX_TOKENS=}") |
| 241 | + """Utility method invocation""" |
242 | 242 |
|
243 |
| - core_client: AsyncMPClient = client |
| 243 | + core_client: AsyncMPClient = client |
244 | 244 |
|
245 |
| - result = await core_client.call_utility_async("echo", "testarg") |
246 |
| - assert result == "testarg" |
| 245 | + result = await core_client.call_utility_async("echo", "testarg") |
| 246 | + assert result == "testarg" |
247 | 247 |
|
248 |
| - with pytest.raises(Exception) as e_info: |
249 |
| - await core_client.call_utility_async("echo", None, "help!") |
| 248 | + with pytest.raises(Exception) as e_info: |
| 249 | + await core_client.call_utility_async("echo", None, "help!") |
250 | 250 |
|
251 |
| - assert str(e_info.value) == "Call to echo method failed: help!" |
| 251 | + assert str(e_info.value) == "Call to echo method failed: help!" |
| 252 | + finally: |
| 253 | + client.shutdown() |
252 | 254 |
|
253 | 255 |
|
254 | 256 | @pytest.mark.parametrize(
|
@@ -333,10 +335,6 @@ def test_kv_cache_events(
|
333 | 335 | "Token ids should be the same as the custom tokens")
|
334 | 336 | finally:
|
335 | 337 | client.shutdown()
|
336 |
| - subscriber.close() |
337 |
| - # TODO hack to try and fix CI hang |
338 |
| - ctx = zmq.Context.instance() |
339 |
| - ctx.term() |
340 | 338 | return
|
341 | 339 |
|
342 | 340 |
|
|
0 commit comments