-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Properly clean up response streams in BaseSession #515
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
46603d1
7a2e646
48890f1
db6a17f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -187,7 +187,6 @@ def __init__( | |
self._receive_notification_type = receive_notification_type | ||
self._session_read_timeout_seconds = read_timeout_seconds | ||
self._in_flight = {} | ||
|
||
self._exit_stack = AsyncExitStack() | ||
|
||
async def __aenter__(self) -> Self: | ||
|
@@ -232,45 +231,48 @@ async def send_request( | |
](1) | ||
self._response_streams[request_id] = response_stream | ||
|
||
self._exit_stack.push_async_callback(lambda: response_stream.aclose()) | ||
self._exit_stack.push_async_callback(lambda: response_stream_reader.aclose()) | ||
|
||
jsonrpc_request = JSONRPCRequest( | ||
jsonrpc="2.0", | ||
id=request_id, | ||
**request.model_dump(by_alias=True, mode="json", exclude_none=True), | ||
) | ||
|
||
# TODO: Support progress callbacks | ||
|
||
await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) | ||
|
||
# request read timeout takes precedence over session read timeout | ||
timeout = None | ||
if request_read_timeout_seconds is not None: | ||
timeout = request_read_timeout_seconds.total_seconds() | ||
elif self._session_read_timeout_seconds is not None: | ||
timeout = self._session_read_timeout_seconds.total_seconds() | ||
|
||
try: | ||
with anyio.fail_after(timeout): | ||
response_or_error = await response_stream_reader.receive() | ||
except TimeoutError: | ||
raise McpError( | ||
ErrorData( | ||
code=httpx.codes.REQUEST_TIMEOUT, | ||
message=( | ||
f"Timed out while waiting for response to " | ||
f"{request.__class__.__name__}. Waited " | ||
f"{timeout} seconds." | ||
), | ||
) | ||
jsonrpc_request = JSONRPCRequest( | ||
jsonrpc="2.0", | ||
id=request_id, | ||
**request.model_dump(by_alias=True, mode="json", exclude_none=True), | ||
) | ||
|
||
if isinstance(response_or_error, JSONRPCError): | ||
raise McpError(response_or_error.error) | ||
else: | ||
return result_type.model_validate(response_or_error.result) | ||
# TODO: Support progress callbacks | ||
|
||
await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) | ||
|
||
# request read timeout takes precedence over session read timeout | ||
timeout = None | ||
if request_read_timeout_seconds is not None: | ||
timeout = request_read_timeout_seconds.total_seconds() | ||
elif self._session_read_timeout_seconds is not None: | ||
timeout = self._session_read_timeout_seconds.total_seconds() | ||
|
||
try: | ||
with anyio.fail_after(timeout): | ||
response_or_error = await response_stream_reader.receive() | ||
except TimeoutError: | ||
raise McpError( | ||
ErrorData( | ||
code=httpx.codes.REQUEST_TIMEOUT, | ||
message=( | ||
f"Timed out while waiting for response to " | ||
f"{request.__class__.__name__}. Waited " | ||
f"{timeout} seconds." | ||
), | ||
) | ||
) | ||
|
||
if isinstance(response_or_error, JSONRPCError): | ||
raise McpError(response_or_error.error) | ||
else: | ||
return result_type.model_validate(response_or_error.result) | ||
|
||
finally: | ||
self._response_streams.pop(request_id, None) | ||
await response_stream.aclose() | ||
await response_stream_reader.aclose() | ||
Comment on lines
+274
to
+275
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this correct tho? Do you want to close the stream on every request? 🤔 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Well, we're creating them on every request, so I think the old code was just piling up unused streams. Please check my logic though - I'm new to this code. E.g. is there a reason we need to leave streams open after getting the response? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, but the connection stays open for some time after you send back the response. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah ok, sounds like this would need a more sophisticated solution. I'll move this to draft and come back to it. Thanks! There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. By default, it stays alive for 5 secs (https://www.uvicorn.org/settings/#timeouts) fyi There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @Kludex just to clarify - these are in-memory streams, not network streams. ( [edit: aha wrote this before seeing your fyi, so it sounds like you were worried about closing HTTPS streams, right? undrafting based on this assumption] |
||
|
||
async def send_notification(self, notification: SendNotificationT) -> None: | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from unittest.mock import patch | ||
|
||
import anyio | ||
import pytest | ||
|
||
from mcp.shared.session import BaseSession | ||
from mcp.types import ( | ||
ClientRequest, | ||
EmptyResult, | ||
PingRequest, | ||
) | ||
|
||
|
||
@pytest.mark.anyio
async def test_send_request_stream_cleanup():
    """
    Test that send_request properly cleans up streams when an exception occurs.

    This test mocks out most of the session functionality to focus on stream cleanup.
    """

    # Minimal concrete session: only _send_response is needed for this code path.
    class TestSession(BaseSession):
        async def _send_response(self, request_id, response):
            pass

    # In-memory stream pairs standing in for the session's transport.
    write_stream_send, write_stream_receive = anyio.create_memory_object_stream(1)
    read_stream_send, read_stream_receive = anyio.create_memory_object_stream(1)

    session = TestSession(
        read_stream_receive,
        write_stream_send,
        object,  # Request type doesn't matter for this test
        object,  # Notification type doesn't matter for this test
    )

    request = ClientRequest(
        PingRequest(
            method="ping",
        )
    )

    # Force send_request to fail after it has registered its response stream.
    async def mock_send(*args, **kwargs):
        raise RuntimeError("Simulated network error")

    # Record the response streams before the test.
    initial_stream_count = len(session._response_streams)

    try:
        # Run the test with the patched method.
        with patch.object(session._write_stream, "send", mock_send):
            with pytest.raises(RuntimeError):
                await session.send_request(request, EmptyResult)

        # Verify that no response streams were leaked.
        assert len(session._response_streams) == initial_stream_count, (
            f"Expected {initial_stream_count} response streams after request, "
            f"but found {len(session._response_streams)}"
        )
    finally:
        # Always close the memory streams — even when the leak assertion
        # above fails — so the test itself doesn't leak resources.
        await write_stream_send.aclose()
        await write_stream_receive.aclose()
        await read_stream_send.aclose()
        await read_stream_receive.aclose()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit - is this intentional?