Skip to content

Commit 01f5e86

Browse files
authored
Convert MCP schemas to strict where possible (#414)
## Summary: Towards #404. I made this configurable because it's not clear this is always a good thing to do. I also made it default to False because I'm not sure if this could cause errors. If it works out well, we can switch the default in the future as a small breaking changes ## Test Plan: Unit tests
1 parent 45c25f8 commit 01f5e86

File tree

5 files changed

+202
-24
lines changed

5 files changed

+202
-24
lines changed

src/agents/agent.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass, field
77
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
88

9-
from typing_extensions import TypeAlias, TypedDict
9+
from typing_extensions import NotRequired, TypeAlias, TypedDict
1010

1111
from .guardrail import InputGuardrail, OutputGuardrail
1212
from .handoffs import Handoff
@@ -53,6 +53,15 @@ class StopAtTools(TypedDict):
5353
"""A list of tool names, any of which will stop the agent from running further."""
5454

5555

56+
class MCPConfig(TypedDict):
57+
"""Configuration for MCP servers."""
58+
59+
convert_schemas_to_strict: NotRequired[bool]
60+
"""If True, we will attempt to convert the MCP schemas to strict-mode schemas. This is a
61+
best-effort conversion, so some schemas may not be convertible. Defaults to False.
62+
"""
63+
64+
5665
@dataclass
5766
class Agent(Generic[TContext]):
5867
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
@@ -119,6 +128,9 @@ class Agent(Generic[TContext]):
119128
longer needed.
120129
"""
121130

131+
mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
132+
"""Configuration for MCP servers."""
133+
122134
input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
123135
"""A list of checks that run in parallel to the agent's execution, before generating a
124136
response. Runs only if the agent is the first agent in the chain.
@@ -224,7 +236,8 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s
224236

225237
async def get_mcp_tools(self) -> list[Tool]:
226238
"""Fetches the available tools from the MCP servers."""
227-
return await MCPUtil.get_all_function_tools(self.mcp_servers)
239+
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
240+
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
228241

229242
async def get_all_tools(self) -> list[Tool]:
230243
"""All agent tools, including MCP tools and function tools."""

src/agents/mcp/util.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import json
33
from typing import TYPE_CHECKING, Any
44

5+
from agents.strict_schema import ensure_strict_json_schema
6+
57
from .. import _debug
68
from ..exceptions import AgentsException, ModelBehaviorError, UserError
79
from ..logger import logger
@@ -19,12 +21,14 @@ class MCPUtil:
1921
"""Set of utilities for interop between MCP and Agents SDK tools."""
2022

2123
@classmethod
22-
async def get_all_function_tools(cls, servers: list["MCPServer"]) -> list[Tool]:
24+
async def get_all_function_tools(
25+
cls, servers: list["MCPServer"], convert_schemas_to_strict: bool
26+
) -> list[Tool]:
2327
"""Get all function tools from a list of MCP servers."""
2428
tools = []
2529
tool_names: set[str] = set()
2630
for server in servers:
27-
server_tools = await cls.get_function_tools(server)
31+
server_tools = await cls.get_function_tools(server, convert_schemas_to_strict)
2832
server_tool_names = {tool.name for tool in server_tools}
2933
if len(server_tool_names & tool_names) > 0:
3034
raise UserError(
@@ -37,25 +41,37 @@ async def get_all_function_tools(cls, servers: list["MCPServer"]) -> list[Tool]:
3741
return tools
3842

3943
@classmethod
40-
async def get_function_tools(cls, server: "MCPServer") -> list[Tool]:
44+
async def get_function_tools(
45+
cls, server: "MCPServer", convert_schemas_to_strict: bool
46+
) -> list[Tool]:
4147
"""Get all function tools from a single MCP server."""
4248

4349
with mcp_tools_span(server=server.name) as span:
4450
tools = await server.list_tools()
4551
span.span_data.result = [tool.name for tool in tools]
4652

47-
return [cls.to_function_tool(tool, server) for tool in tools]
53+
return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools]
4854

4955
@classmethod
50-
def to_function_tool(cls, tool: "MCPTool", server: "MCPServer") -> FunctionTool:
56+
def to_function_tool(
57+
cls, tool: "MCPTool", server: "MCPServer", convert_schemas_to_strict: bool
58+
) -> FunctionTool:
5159
"""Convert an MCP tool to an Agents SDK function tool."""
5260
invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool)
61+
schema, is_strict = tool.inputSchema, False
62+
if convert_schemas_to_strict:
63+
try:
64+
schema = ensure_strict_json_schema(schema)
65+
is_strict = True
66+
except Exception as e:
67+
logger.info(f"Error converting MCP schema to strict mode: {e}")
68+
5369
return FunctionTool(
5470
name=tool.name,
5571
description=tool.description or "",
56-
params_json_schema=tool.inputSchema,
72+
params_json_schema=schema,
5773
on_invoke_tool=invoke_func,
58-
strict_json_schema=False,
74+
strict_json_schema=is_strict,
5975
)
6076

6177
@classmethod

tests/mcp/test_mcp_util.py

+157-4
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
from typing import Any
33

44
import pytest
5+
from inline_snapshot import snapshot
56
from mcp.types import Tool as MCPTool
6-
from pydantic import BaseModel
7+
from pydantic import BaseModel, TypeAdapter
78

8-
from agents import FunctionTool, RunContextWrapper
9+
from agents import Agent, FunctionTool, RunContextWrapper
910
from agents.exceptions import AgentsException, ModelBehaviorError
1011
from agents.mcp import MCPServer, MCPUtil
1112

@@ -18,7 +19,16 @@ class Foo(BaseModel):
1819

1920

2021
class Bar(BaseModel):
21-
qux: str
22+
qux: dict[str, str]
23+
24+
25+
Baz = TypeAdapter(dict[str, str])
26+
27+
28+
def _convertible_schema() -> dict[str, Any]:
29+
schema = Foo.model_json_schema()
30+
schema["additionalProperties"] = False
31+
return schema
2232

2333

2434
@pytest.mark.asyncio
@@ -47,7 +57,7 @@ async def test_get_all_function_tools():
4757
server3.add_tool(names[4], schemas[4])
4858

4959
servers: list[MCPServer] = [server1, server2, server3]
50-
tools = await MCPUtil.get_all_function_tools(servers)
60+
tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=False)
5161
assert len(tools) == 5
5262
assert all(tool.name in names for tool in tools)
5363

@@ -56,6 +66,11 @@ async def test_get_all_function_tools():
5666
assert tool.params_json_schema == schemas[idx]
5767
assert tool.name == names[idx]
5868

69+
# Also make sure it works with strict schemas
70+
tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=True)
71+
assert len(tools) == 5
72+
assert all(tool.name in names for tool in tools)
73+
5974

6075
@pytest.mark.asyncio
6176
async def test_invoke_mcp_tool():
@@ -107,3 +122,141 @@ async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixtur
107122
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "")
108123

109124
assert "Error invoking MCP tool test_tool_1" in caplog.text
125+
126+
127+
@pytest.mark.asyncio
128+
async def test_agent_convert_schemas_true():
129+
"""Test that setting convert_schemas_to_strict to True converts non-strict schemas to strict.
130+
- 'foo' tool is already strict and remains strict.
131+
- 'bar' tool is non-strict and becomes strict (additionalProperties set to False, etc).
132+
"""
133+
strict_schema = Foo.model_json_schema()
134+
non_strict_schema = Baz.json_schema()
135+
possible_to_convert_schema = _convertible_schema()
136+
137+
server = FakeMCPServer()
138+
server.add_tool("foo", strict_schema)
139+
server.add_tool("bar", non_strict_schema)
140+
server.add_tool("baz", possible_to_convert_schema)
141+
agent = Agent(
142+
name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": True}
143+
)
144+
tools = await agent.get_mcp_tools()
145+
146+
foo_tool = next(tool for tool in tools if tool.name == "foo")
147+
assert isinstance(foo_tool, FunctionTool)
148+
bar_tool = next(tool for tool in tools if tool.name == "bar")
149+
assert isinstance(bar_tool, FunctionTool)
150+
baz_tool = next(tool for tool in tools if tool.name == "baz")
151+
assert isinstance(baz_tool, FunctionTool)
152+
153+
# Checks that additionalProperties is set to False
154+
assert foo_tool.params_json_schema == snapshot(
155+
{
156+
"properties": {
157+
"bar": {"title": "Bar", "type": "string"},
158+
"baz": {"title": "Baz", "type": "integer"},
159+
},
160+
"required": ["bar", "baz"],
161+
"title": "Foo",
162+
"type": "object",
163+
"additionalProperties": False,
164+
}
165+
)
166+
assert foo_tool.strict_json_schema is True, "foo_tool should be strict"
167+
168+
# Checks that additionalProperties is set to False
169+
assert bar_tool.params_json_schema == snapshot(
170+
{
171+
"type": "object",
172+
"additionalProperties": {"type": "string"},
173+
}
174+
)
175+
assert bar_tool.strict_json_schema is False, "bar_tool should not be strict"
176+
177+
# Checks that additionalProperties is set to False
178+
assert baz_tool.params_json_schema == snapshot(
179+
{
180+
"properties": {
181+
"bar": {"title": "Bar", "type": "string"},
182+
"baz": {"title": "Baz", "type": "integer"},
183+
},
184+
"required": ["bar", "baz"],
185+
"title": "Foo",
186+
"type": "object",
187+
"additionalProperties": False,
188+
}
189+
)
190+
assert baz_tool.strict_json_schema is True, "baz_tool should be strict"
191+
192+
193+
@pytest.mark.asyncio
194+
async def test_agent_convert_schemas_false():
195+
"""Test that setting convert_schemas_to_strict to False leaves tool schemas as non-strict.
196+
- 'foo' tool remains strict.
197+
- 'bar' tool remains non-strict (additionalProperties remains True).
198+
"""
199+
strict_schema = Foo.model_json_schema()
200+
non_strict_schema = Baz.json_schema()
201+
possible_to_convert_schema = _convertible_schema()
202+
203+
server = FakeMCPServer()
204+
server.add_tool("foo", strict_schema)
205+
server.add_tool("bar", non_strict_schema)
206+
server.add_tool("baz", possible_to_convert_schema)
207+
208+
agent = Agent(
209+
name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": False}
210+
)
211+
tools = await agent.get_mcp_tools()
212+
213+
foo_tool = next(tool for tool in tools if tool.name == "foo")
214+
assert isinstance(foo_tool, FunctionTool)
215+
bar_tool = next(tool for tool in tools if tool.name == "bar")
216+
assert isinstance(bar_tool, FunctionTool)
217+
baz_tool = next(tool for tool in tools if tool.name == "baz")
218+
assert isinstance(baz_tool, FunctionTool)
219+
220+
assert foo_tool.params_json_schema == strict_schema
221+
assert foo_tool.strict_json_schema is False, "Shouldn't be converted unless specified"
222+
223+
assert bar_tool.params_json_schema == non_strict_schema
224+
assert bar_tool.strict_json_schema is False
225+
226+
assert baz_tool.params_json_schema == possible_to_convert_schema
227+
assert baz_tool.strict_json_schema is False, "Shouldn't be converted unless specified"
228+
229+
230+
@pytest.mark.asyncio
231+
async def test_agent_convert_schemas_unset():
232+
"""Test that leaving convert_schemas_to_strict unset (defaulting to False) leaves tool schemas
233+
as non-strict.
234+
- 'foo' tool remains strict.
235+
- 'bar' tool remains non-strict.
236+
"""
237+
strict_schema = Foo.model_json_schema()
238+
non_strict_schema = Baz.json_schema()
239+
possible_to_convert_schema = _convertible_schema()
240+
241+
server = FakeMCPServer()
242+
server.add_tool("foo", strict_schema)
243+
server.add_tool("bar", non_strict_schema)
244+
server.add_tool("baz", possible_to_convert_schema)
245+
agent = Agent(name="test_agent", mcp_servers=[server])
246+
tools = await agent.get_mcp_tools()
247+
248+
foo_tool = next(tool for tool in tools if tool.name == "foo")
249+
assert isinstance(foo_tool, FunctionTool)
250+
bar_tool = next(tool for tool in tools if tool.name == "bar")
251+
assert isinstance(bar_tool, FunctionTool)
252+
baz_tool = next(tool for tool in tools if tool.name == "baz")
253+
assert isinstance(baz_tool, FunctionTool)
254+
255+
assert foo_tool.params_json_schema == strict_schema
256+
assert foo_tool.strict_json_schema is False, "Shouldn't be converted unless specified"
257+
258+
assert bar_tool.params_json_schema == non_strict_schema
259+
assert bar_tool.strict_json_schema is False
260+
261+
assert baz_tool.params_json_schema == possible_to_convert_schema
262+
assert baz_tool.strict_json_schema is False, "Shouldn't be converted unless specified"

tests/test_agent_runner.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -642,9 +642,7 @@ async def test_tool_use_behavior_custom_function():
642642
async def test_model_settings_override():
643643
model = FakeModel()
644644
agent = Agent(
645-
name="test",
646-
model=model,
647-
model_settings=ModelSettings(temperature=1.0, max_tokens=1000)
645+
name="test", model=model, model_settings=ModelSettings(temperature=1.0, max_tokens=1000)
648646
)
649647

650648
model.add_multiple_turn_outputs(

tests/test_tracing_errors.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,10 @@ async def test_multiple_handoff_doesnt_error():
244244
},
245245
},
246246
{"type": "generation"},
247-
{"type": "handoff",
248-
"data": {"from_agent": "test", "to_agent": "test"},
249-
"error": {
247+
{
248+
"type": "handoff",
249+
"data": {"from_agent": "test", "to_agent": "test"},
250+
"error": {
250251
"data": {
251252
"requested_agents": [
252253
"test",
@@ -255,7 +256,7 @@ async def test_multiple_handoff_doesnt_error():
255256
},
256257
"message": "Multiple handoffs requested",
257258
},
258-
},
259+
},
259260
],
260261
},
261262
{
@@ -383,10 +384,7 @@ async def test_handoffs_lead_to_correct_agent_spans():
383384
{"type": "generation"},
384385
{
385386
"type": "handoff",
386-
"data": {
387-
"from_agent": "test_agent_3",
388-
"to_agent": "test_agent_1"
389-
},
387+
"data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"},
390388
"error": {
391389
"data": {
392390
"requested_agents": [

0 commit comments

Comments
 (0)