3
3
import abc
4
4
import asyncio
5
5
from contextlib import AbstractAsyncContextManager , AsyncExitStack
6
+ from datetime import timedelta
6
7
from pathlib import Path
7
8
from typing import Any , Literal
8
9
@@ -54,7 +55,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C
54
55
class _MCPServerWithClientSession (MCPServer , abc .ABC ):
55
56
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
56
57
57
- def __init__ (self , cache_tools_list : bool ):
58
+ def __init__ (self , cache_tools_list : bool , client_session_timeout_seconds : float | None ):
58
59
"""
59
60
Args:
60
61
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
@@ -63,12 +64,16 @@ def __init__(self, cache_tools_list: bool):
63
64
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
64
65
server will not change its tools list, because it can drastically improve latency
65
66
(by avoiding a round-trip to the server every time).
67
+
68
+ client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
66
69
"""
67
70
self .session : ClientSession | None = None
68
71
self .exit_stack : AsyncExitStack = AsyncExitStack ()
69
72
self ._cleanup_lock : asyncio .Lock = asyncio .Lock ()
70
73
self .cache_tools_list = cache_tools_list
71
74
75
+ self .client_session_timeout_seconds = client_session_timeout_seconds
76
+
72
77
# The cache is always dirty at startup, so that we fetch tools at least once
73
78
self ._cache_dirty = True
74
79
self ._tools_list : list [MCPTool ] | None = None
@@ -101,7 +106,15 @@ async def connect(self):
101
106
try :
102
107
transport = await self .exit_stack .enter_async_context (self .create_streams ())
103
108
read , write = transport
104
- session = await self .exit_stack .enter_async_context (ClientSession (read , write ))
109
+ session = await self .exit_stack .enter_async_context (
110
+ ClientSession (
111
+ read ,
112
+ write ,
113
+ timedelta (seconds = self .client_session_timeout_seconds )
114
+ if self .client_session_timeout_seconds
115
+ else None ,
116
+ )
117
+ )
105
118
await session .initialize ()
106
119
self .session = session
107
120
except Exception as e :
@@ -183,6 +196,7 @@ def __init__(
183
196
params : MCPServerStdioParams ,
184
197
cache_tools_list : bool = False ,
185
198
name : str | None = None ,
199
+ client_session_timeout_seconds : float | None = 5 ,
186
200
):
187
201
"""Create a new MCP server based on the stdio transport.
188
202
@@ -199,8 +213,9 @@ def __init__(
199
213
improve latency (by avoiding a round-trip to the server every time).
200
214
name: A readable name for the server. If not provided, we'll create one from the
201
215
command.
216
+ client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
202
217
"""
203
- super ().__init__ (cache_tools_list )
218
+ super ().__init__ (cache_tools_list , client_session_timeout_seconds )
204
219
205
220
self .params = StdioServerParameters (
206
221
command = params ["command" ],
@@ -257,6 +272,7 @@ def __init__(
257
272
params : MCPServerSseParams ,
258
273
cache_tools_list : bool = False ,
259
274
name : str | None = None ,
275
+ client_session_timeout_seconds : float | None = 5 ,
260
276
):
261
277
"""Create a new MCP server based on the HTTP with SSE transport.
262
278
@@ -274,8 +290,10 @@ def __init__(
274
290
275
291
name: A readable name for the server. If not provided, we'll create one from the
276
292
URL.
293
+
294
+ client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
277
295
"""
278
- super ().__init__ (cache_tools_list )
296
+ super ().__init__ (cache_tools_list , client_session_timeout_seconds )
279
297
280
298
self .params = params
281
299
self ._name = name or f"sse: { self .params ['url' ]} "
0 commit comments