Skip to content

Properly clean up response streams in BaseSession #515

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 39 additions & 37 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ def __init__(
self._receive_notification_type = receive_notification_type
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - is this intentional?

self._exit_stack = AsyncExitStack()

async def __aenter__(self) -> Self:
Expand Down Expand Up @@ -232,45 +231,48 @@ async def send_request(
](1)
self._response_streams[request_id] = response_stream

self._exit_stack.push_async_callback(lambda: response_stream.aclose())
self._exit_stack.push_async_callback(lambda: response_stream_reader.aclose())

jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
)

# TODO: Support progress callbacks

await self._write_stream.send(JSONRPCMessage(jsonrpc_request))

# request read timeout takes precedence over session read timeout
timeout = None
if request_read_timeout_seconds is not None:
timeout = request_read_timeout_seconds.total_seconds()
elif self._session_read_timeout_seconds is not None:
timeout = self._session_read_timeout_seconds.total_seconds()

try:
with anyio.fail_after(timeout):
response_or_error = await response_stream_reader.receive()
except TimeoutError:
raise McpError(
ErrorData(
code=httpx.codes.REQUEST_TIMEOUT,
message=(
f"Timed out while waiting for response to "
f"{request.__class__.__name__}. Waited "
f"{timeout} seconds."
),
)
jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
)

if isinstance(response_or_error, JSONRPCError):
raise McpError(response_or_error.error)
else:
return result_type.model_validate(response_or_error.result)
# TODO: Support progress callbacks

await self._write_stream.send(JSONRPCMessage(jsonrpc_request))

# request read timeout takes precedence over session read timeout
timeout = None
if request_read_timeout_seconds is not None:
timeout = request_read_timeout_seconds.total_seconds()
elif self._session_read_timeout_seconds is not None:
timeout = self._session_read_timeout_seconds.total_seconds()

try:
with anyio.fail_after(timeout):
response_or_error = await response_stream_reader.receive()
except TimeoutError:
raise McpError(
ErrorData(
code=httpx.codes.REQUEST_TIMEOUT,
message=(
f"Timed out while waiting for response to "
f"{request.__class__.__name__}. Waited "
f"{timeout} seconds."
),
)
)

if isinstance(response_or_error, JSONRPCError):
raise McpError(response_or_error.error)
else:
return result_type.model_validate(response_or_error.result)

finally:
self._response_streams.pop(request_id, None)
await response_stream.aclose()
await response_stream_reader.aclose()
Comment on lines +274 to +275
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct though? Do you want to close the stream on every request? 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, we're creating them on every request, so I think the old code was just piling
up open streams until the end of the session (or getting unclosed resource errors if they got __del__d before the session's __aexit__, or if the latter never ran).

Please check my logic though - I'm new to this code. E.g. is there a reason we need to leave streams open after getting the response?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but the connection stays open for some time after you send back the response.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok, sounds like this would need a more sophisticated solution. I'll move this to draft and come back to it. Thanks!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By default, it stays alive for 5 secs (https://www.uvicorn.org/settings/#timeouts) fyi

Copy link
Contributor Author

@bhosmer-ant bhosmer-ant Apr 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Kludex just to clarify - these are in-memory streams, not network streams. (_receive_loop gets the network response and writes it to response_stream). There may still be a problem I'm not aware of with closing the streams immediately, but want to make sure you weren't worried about closing while network connections were still open.

[edit: aha wrote this before seeing your fyi, so it sounds like you were worried about closing HTTPS streams, right? undrafting based on this assumption]


async def send_notification(self, notification: SendNotificationT) -> None:
"""
Expand Down
68 changes: 68 additions & 0 deletions tests/client/test_resource_cleanup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from unittest.mock import patch

import anyio
import pytest

from mcp.shared.session import BaseSession
from mcp.types import (
ClientRequest,
EmptyResult,
PingRequest,
)


@pytest.mark.anyio
async def test_send_request_stream_cleanup():
    """
    Verify that send_request unregisters its per-request response stream even
    when sending the request raises.

    Most of the session machinery is stubbed out; only the stream bookkeeping
    in _response_streams matters here.
    """

    class _StubSession(BaseSession):
        # Minimal concrete session: responses are simply discarded.
        async def _send_response(self, request_id, response):
            pass

    # In-memory transport endpoints for the session under test.
    writer, writer_rx = anyio.create_memory_object_stream(1)
    reader_tx, reader = anyio.create_memory_object_stream(1)

    session = _StubSession(
        reader,
        writer,
        object,  # request type is irrelevant for this test
        object,  # notification type is irrelevant for this test
    )

    # A minimal ping request to drive send_request.
    ping = ClientRequest(
        PingRequest(
            method="ping",
        )
    )

    async def _failing_send(*args, **kwargs):
        # Simulate a transport failure while sending the request.
        raise RuntimeError("Simulated network error")

    # Snapshot the registered response streams before the call.
    initial_stream_count = len(session._response_streams)

    # The patched send must raise, and send_request must propagate it.
    with patch.object(session._write_stream, "send", _failing_send):
        with pytest.raises(RuntimeError):
            await session.send_request(ping, EmptyResult)

    # The failed request must not leave its response stream registered.
    assert len(session._response_streams) == initial_stream_count, (
        f"Expected {initial_stream_count} response streams after request, "
        f"but found {len(session._response_streams)}"
    )

    # Tidy up the in-memory streams.
    await writer.aclose()
    await writer_rx.aclose()
    await reader_tx.aclose()
    await reader.aclose()
Loading