Skip to content

Commit af80e3a

Browse files
authored
Prevent MCP ClientSession hang (#580)
Per https://modelcontextprotocol.io/specification/draft/basic/lifecycle#timeouts "Implementations SHOULD establish timeouts for all sent requests, to prevent hung connections and resource exhaustion. When the request has not received a success or error response within the timeout period, the sender SHOULD issue a cancellation notification for that request and stop waiting for a response. SDKs and other middleware SHOULD allow these timeouts to be configured on a per-request basis." I picked 5 seconds since that's the default for SSE
1 parent 3755ea8 commit af80e3a

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

src/agents/mcp/server.py

+22-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import abc
44
import asyncio
55
from contextlib import AbstractAsyncContextManager, AsyncExitStack
6+
from datetime import timedelta
67
from pathlib import Path
78
from typing import Any, Literal
89

@@ -54,7 +55,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C
5455
class _MCPServerWithClientSession(MCPServer, abc.ABC):
5556
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
5657

57-
def __init__(self, cache_tools_list: bool):
58+
def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None):
5859
"""
5960
Args:
6061
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
@@ -63,12 +64,16 @@ def __init__(self, cache_tools_list: bool):
6364
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
6465
server will not change its tools list, because it can drastically improve latency
6566
(by avoiding a round-trip to the server every time).
67+
68+
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
6669
"""
6770
self.session: ClientSession | None = None
6871
self.exit_stack: AsyncExitStack = AsyncExitStack()
6972
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
7073
self.cache_tools_list = cache_tools_list
7174

75+
self.client_session_timeout_seconds = client_session_timeout_seconds
76+
7277
# The cache is always dirty at startup, so that we fetch tools at least once
7378
self._cache_dirty = True
7479
self._tools_list: list[MCPTool] | None = None
@@ -101,7 +106,15 @@ async def connect(self):
101106
try:
102107
transport = await self.exit_stack.enter_async_context(self.create_streams())
103108
read, write = transport
104-
session = await self.exit_stack.enter_async_context(ClientSession(read, write))
109+
session = await self.exit_stack.enter_async_context(
110+
ClientSession(
111+
read,
112+
write,
113+
timedelta(seconds=self.client_session_timeout_seconds)
114+
if self.client_session_timeout_seconds
115+
else None,
116+
)
117+
)
105118
await session.initialize()
106119
self.session = session
107120
except Exception as e:
@@ -183,6 +196,7 @@ def __init__(
183196
params: MCPServerStdioParams,
184197
cache_tools_list: bool = False,
185198
name: str | None = None,
199+
client_session_timeout_seconds: float | None = 5,
186200
):
187201
"""Create a new MCP server based on the stdio transport.
188202
@@ -199,8 +213,9 @@ def __init__(
199213
improve latency (by avoiding a round-trip to the server every time).
200214
name: A readable name for the server. If not provided, we'll create one from the
201215
command.
216+
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
202217
"""
203-
super().__init__(cache_tools_list)
218+
super().__init__(cache_tools_list, client_session_timeout_seconds)
204219

205220
self.params = StdioServerParameters(
206221
command=params["command"],
@@ -257,6 +272,7 @@ def __init__(
257272
params: MCPServerSseParams,
258273
cache_tools_list: bool = False,
259274
name: str | None = None,
275+
client_session_timeout_seconds: float | None = 5,
260276
):
261277
"""Create a new MCP server based on the HTTP with SSE transport.
262278
@@ -274,8 +290,10 @@ def __init__(
274290
275291
name: A readable name for the server. If not provided, we'll create one from the
276292
URL.
293+
294+
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
277295
"""
278-
super().__init__(cache_tools_list)
296+
super().__init__(cache_tools_list, client_session_timeout_seconds)
279297

280298
self.params = params
281299
self._name = name or f"sse: {self.params['url']}"

tests/mcp/test_server_errors.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
class CrashingClientSessionServer(_MCPServerWithClientSession):
88
def __init__(self):
9-
super().__init__(cache_tools_list=False)
9+
super().__init__(cache_tools_list=False, client_session_timeout_seconds=5)
1010
self.cleanup_called = False
1111

1212
def create_streams(self):

0 commit comments

Comments
 (0)