Skip to content

Fix async callable object tools #568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/mcp/server/fastmcp/tools/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations as _annotations

import functools
import inspect
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, get_origin
Expand Down Expand Up @@ -48,7 +49,7 @@ def from_function(
raise ValueError("You must provide a name for lambda functions")

func_doc = description or fn.__doc__ or ""
is_async = inspect.iscoroutinefunction(fn)
is_async = _is_async_callable(fn)

if context_kwarg is None:
sig = inspect.signature(fn)
Expand Down Expand Up @@ -92,3 +93,12 @@ async def run(
)
except Exception as e:
raise ToolError(f"Error executing tool {self.name}: {e}") from e


def _is_async_callable(obj: Any) -> bool:
while isinstance(obj, functools.partial):
obj = obj.func

return inspect.iscoroutinefunction(obj) or (
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
)
61 changes: 61 additions & 0 deletions tests/server/fastmcp/test_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,39 @@ def create_user(user: UserInput, flag: bool) -> dict:
assert "age" in tool.parameters["$defs"]["UserInput"]["properties"]
assert "flag" in tool.parameters["properties"]

def test_add_callable_object(self):
"""Test registering a callable object."""

class MyTool:
def __init__(self):
self.__name__ = "MyTool"

def __call__(self, x: int) -> int:
return x * 2

manager = ToolManager()
tool = manager.add_tool(MyTool())
assert tool.name == "MyTool"
assert tool.is_async is False
assert tool.parameters["properties"]["x"]["type"] == "integer"

@pytest.mark.anyio
async def test_add_async_callable_object(self):
"""Test registering an async callable object."""

class MyAsyncTool:
def __init__(self):
self.__name__ = "MyAsyncTool"

async def __call__(self, x: int) -> int:
return x * 2

manager = ToolManager()
tool = manager.add_tool(MyAsyncTool())
assert tool.name == "MyAsyncTool"
assert tool.is_async is True
assert tool.parameters["properties"]["x"]["type"] == "integer"

def test_add_invalid_tool(self):
manager = ToolManager()
with pytest.raises(AttributeError):
Expand Down Expand Up @@ -137,6 +170,34 @@ async def double(n: int) -> int:
result = await manager.call_tool("double", {"n": 5})
assert result == 10

@pytest.mark.anyio
async def test_call_object_tool(self):
class MyTool:
def __init__(self):
self.__name__ = "MyTool"

def __call__(self, x: int) -> int:
return x * 2

manager = ToolManager()
tool = manager.add_tool(MyTool())
result = await tool.run({"x": 5})
assert result == 10

@pytest.mark.anyio
async def test_call_async_object_tool(self):
class MyAsyncTool:
def __init__(self):
self.__name__ = "MyAsyncTool"

async def __call__(self, x: int) -> int:
return x * 2

manager = ToolManager()
tool = manager.add_tool(MyAsyncTool())
result = await tool.run({"x": 5})
assert result == 10

@pytest.mark.anyio
async def test_call_tool_with_default_args(self):
def add(a: int, b: int = 1) -> int:
Expand Down
Loading