Skip to content

Commit e3698f3

Browse files
authored
Enable non-strict output types (#539)
See #528: some users hit errors because their output types are not strict-compatible. The approach: 1. Create `AgentOutputSchemaBase`, which defines the base interface for an output type (the JSON schema plus validation). 2. Make the existing `AgentOutputSchema` a subclass of `AgentOutputSchemaBase`. 3. Allow users to pass any `AgentOutputSchemaBase` to `Agent(output_type=...)`.
1 parent 4b8472d commit e3698f3

18 files changed

+256
-61
lines changed
+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import asyncio
2+
import json
3+
from dataclasses import dataclass
4+
from typing import Any
5+
6+
from agents import Agent, AgentOutputSchema, AgentOutputSchemaBase, Runner
7+
8+
"""This example demonstrates how to use an output type that is not in strict mode. Strict mode
9+
allows us to guarantee valid JSON output, but some schemas are not strict-compatible.
10+
11+
In this example, we define an output type that is not strict-compatible, and then we run the
12+
agent with strict_json_schema=False.
13+
14+
We also demonstrate a custom output type.
15+
16+
To understand which schemas are strict-compatible, see:
17+
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
18+
"""
19+
20+
21+
@dataclass
class OutputType:
    """Example output type that is deliberately NOT strict-JSON-schema compatible.

    A free-form mapping with non-string (int) keys produces a JSON schema that
    strict mode rejects, which is what this example demonstrates.
    """

    # Mapping of joke number -> joke text (NOT a list; the previous attribute
    # docstring incorrectly called it one). The int keys are what make the
    # generated schema strict-incompatible.
    jokes: dict[int, str]
25+
26+
27+
class CustomOutputSchema(AgentOutputSchemaBase):
    """Hand-written output schema.

    Instead of letting the SDK derive a schema from a Python type, this class
    supplies its own JSON schema and its own parsing of the model's output.
    """

    def name(self) -> str:
        """Display name for this output type."""
        return "CustomOutputSchema"

    def is_plain_text(self) -> bool:
        """This schema describes a JSON object, not free-form text."""
        return False

    def is_strict_json_schema(self) -> bool:
        """This hand-written schema opts out of strict mode."""
        return False

    def json_schema(self) -> dict[str, Any]:
        """Return the hand-written schema for the model's JSON output."""
        joke_value = {"type": "string"}
        jokes_object = {"type": "object", "properties": {"joke": joke_value}}
        return {"type": "object", "properties": {"jokes": jokes_object}}

    def validate_json(self, json_str: str) -> Any:
        """Parse the raw JSON string.

        Just for demonstration, the jokes mapping is flattened into a list of
        its values rather than returned as-is.
        """
        parsed = json.loads(json_str)
        return list(parsed["jokes"].values())
49+
50+
51+
async def main():
    """Demonstrate strict vs. non-strict vs. fully custom output schemas."""
    agent = Agent(
        name="Assistant",
        instructions="You are a helpful assistant.",
        output_type=OutputType,
    )

    input = "Tell me 3 short jokes."

    # First, let's try with a strict output type. This should raise an
    # exception, because OutputType is not strict-compatible.
    #
    # Bug fix: the original raised AssertionError *inside* the try block, so
    # the `except Exception` handler swallowed it and printed
    # "Error (expected)" even when the run wrongly succeeded. Using
    # try/except/else lets the assertion propagate on a silent success.
    try:
        await Runner.run(agent, input)
    except Exception as e:
        print(f"Error (expected): {e}")
    else:
        raise AssertionError("Should have raised an exception")

    # Now let's try again with a non-strict output type. This should work.
    # In some cases, it will raise an error - the schema isn't strict, so the
    # model may produce an invalid JSON object.
    agent.output_type = AgentOutputSchema(OutputType, strict_json_schema=False)
    result = await Runner.run(agent, input)
    print(result.final_output)

    # Finally, let's try a custom output type.
    agent.output_type = CustomOutputSchema()
    result = await Runner.run(agent, input)
    print(result.final_output)
78+
79+
80+
# Entry point: run the async demo when executed as a script.
if __name__ == "__main__":
    asyncio.run(main())

src/agents/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from . import _config
88
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
9-
from .agent_output import AgentOutputSchema
9+
from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
1010
from .computer import AsyncComputer, Button, Computer, Environment
1111
from .exceptions import (
1212
AgentsException,
@@ -158,6 +158,7 @@ def enable_verbose_stdout_logging():
158158
"OpenAIProvider",
159159
"OpenAIResponsesModel",
160160
"AgentOutputSchema",
161+
"AgentOutputSchemaBase",
161162
"Computer",
162163
"AsyncComputer",
163164
"Environment",

src/agents/_run_impl.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
3030

3131
from .agent import Agent, ToolsToFinalOutputResult
32-
from .agent_output import AgentOutputSchema
32+
from .agent_output import AgentOutputSchemaBase
3333
from .computer import AsyncComputer, Computer
3434
from .exceptions import AgentsException, ModelBehaviorError, UserError
3535
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
@@ -195,7 +195,7 @@ async def execute_tools_and_side_effects(
195195
pre_step_items: list[RunItem],
196196
new_response: ModelResponse,
197197
processed_response: ProcessedResponse,
198-
output_schema: AgentOutputSchema | None,
198+
output_schema: AgentOutputSchemaBase | None,
199199
hooks: RunHooks[TContext],
200200
context_wrapper: RunContextWrapper[TContext],
201201
run_config: RunConfig,
@@ -335,7 +335,7 @@ def process_model_response(
335335
agent: Agent[Any],
336336
all_tools: list[Tool],
337337
response: ModelResponse,
338-
output_schema: AgentOutputSchema | None,
338+
output_schema: AgentOutputSchemaBase | None,
339339
handoffs: list[Handoff],
340340
) -> ProcessedResponse:
341341
items: list[RunItem] = []

src/agents/agent.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from typing_extensions import NotRequired, TypeAlias, TypedDict
1010

11+
from .agent_output import AgentOutputSchemaBase
1112
from .guardrail import InputGuardrail, OutputGuardrail
1213
from .handoffs import Handoff
1314
from .items import ItemHelpers
@@ -141,8 +142,14 @@ class Agent(Generic[TContext]):
141142
Runs only if the agent produces a final output.
142143
"""
143144

144-
output_type: type[Any] | None = None
145-
"""The type of the output object. If not provided, the output will be `str`."""
145+
output_type: type[Any] | AgentOutputSchemaBase | None = None
146+
"""The type of the output object. If not provided, the output will be `str`. In most cases,
147+
you should pass a regular Python type (e.g. a dataclass, Pydantic model, TypedDict, etc).
148+
You can customize this in two ways:
149+
1. If you want non-strict schemas, pass `AgentOutputSchema(MyClass, strict_json_schema=False)`.
150+
2. If you want to use a custom JSON schema (i.e. without using the SDK's automatic schema)
151+
creation, subclass and pass an `AgentOutputSchemaBase` subclass.
152+
"""
146153

147154
hooks: AgentHooks[TContext] | None = None
148155
"""A class that receives callbacks on various lifecycle events for this agent.

src/agents/agent_output.py

+58-8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import abc
12
from dataclasses import dataclass
23
from typing import Any
34

@@ -12,8 +13,46 @@
1213
_WRAPPER_DICT_KEY = "response"
1314

1415

16+
class AgentOutputSchemaBase(abc.ABC):
    """An object that captures the JSON schema of the output, as well as validating/parsing JSON
    produced by the LLM into the output type.
    """

    @abc.abstractmethod
    def is_plain_text(self) -> bool:
        """Whether the output type is plain text (versus a JSON object)."""
        pass

    @abc.abstractmethod
    def name(self) -> str:
        """The name of the output type."""
        pass

    @abc.abstractmethod
    def json_schema(self) -> dict[str, Any]:
        """Returns the JSON schema of the output. Will only be called if the output type is not
        plain text.
        """
        pass

    @abc.abstractmethod
    def is_strict_json_schema(self) -> bool:
        """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
        features, but guarantees valid JSON. See here for details:
        https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
        """
        pass

    @abc.abstractmethod
    def validate_json(self, json_str: str) -> Any:
        """Validate a JSON string against the output type. You must return the validated object,
        or raise a `ModelBehaviorError` if the JSON is invalid.
        """
        pass
52+
53+
1554
@dataclass(init=False)
16-
class AgentOutputSchema:
55+
class AgentOutputSchema(AgentOutputSchemaBase):
1756
"""An object that captures the JSON schema of the output, as well as validating/parsing JSON
1857
produced by the LLM into the output type.
1958
"""
@@ -32,7 +71,7 @@ class AgentOutputSchema:
3271
_output_schema: dict[str, Any]
3372
"""The JSON schema of the output."""
3473

35-
strict_json_schema: bool
74+
_strict_json_schema: bool
3675
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
3776
as it increases the likelihood of correct JSON input.
3877
"""
@@ -45,7 +84,7 @@ def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
4584
setting this to True, as it increases the likelihood of correct JSON input.
4685
"""
4786
self.output_type = output_type
48-
self.strict_json_schema = strict_json_schema
87+
self._strict_json_schema = strict_json_schema
4988

5089
if output_type is None or output_type is str:
5190
self._is_wrapped = False
@@ -70,24 +109,35 @@ def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
70109
self._type_adapter = TypeAdapter(output_type)
71110
self._output_schema = self._type_adapter.json_schema()
72111

73-
if self.strict_json_schema:
74-
self._output_schema = ensure_strict_json_schema(self._output_schema)
112+
if self._strict_json_schema:
113+
try:
114+
self._output_schema = ensure_strict_json_schema(self._output_schema)
115+
except UserError as e:
116+
raise UserError(
117+
"Strict JSON schema is enabled, but the output type is not valid. "
118+
"Either make the output type strict, or pass output_schema_strict=False to "
119+
"your Agent()"
120+
) from e
75121

76122
def is_plain_text(self) -> bool:
77123
"""Whether the output type is plain text (versus a JSON object)."""
78124
return self.output_type is None or self.output_type is str
79125

126+
def is_strict_json_schema(self) -> bool:
127+
"""Whether the JSON schema is in strict mode."""
128+
return self._strict_json_schema
129+
80130
def json_schema(self) -> dict[str, Any]:
81131
"""The JSON schema of the output type."""
82132
if self.is_plain_text():
83133
raise UserError("Output type is plain text, so no JSON schema is available")
84134
return self._output_schema
85135

86-
def validate_json(self, json_str: str, partial: bool = False) -> Any:
136+
def validate_json(self, json_str: str) -> Any:
87137
"""Validate a JSON string against the output type. Returns the validated object, or raises
88138
a `ModelBehaviorError` if the JSON is invalid.
89139
"""
90-
validated = _json.validate_json(json_str, self._type_adapter, partial)
140+
validated = _json.validate_json(json_str, self._type_adapter, partial=False)
91141
if self._is_wrapped:
92142
if not isinstance(validated, dict):
93143
_error_tracing.attach_error_to_current_span(
@@ -113,7 +163,7 @@ def validate_json(self, json_str: str, partial: bool = False) -> Any:
113163
return validated[_WRAPPER_DICT_KEY]
114164
return validated
115165

116-
def output_type_name(self) -> str:
166+
def name(self) -> str:
117167
"""The name of the output type."""
118168
return _type_to_str(self.output_type)
119169

src/agents/extensions/models/litellm_model.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from openai.types.responses import Response
3030

3131
from ... import _debug
32-
from ...agent_output import AgentOutputSchema
32+
from ...agent_output import AgentOutputSchemaBase
3333
from ...handoffs import Handoff
3434
from ...items import ModelResponse, TResponseInputItem, TResponseStreamEvent
3535
from ...logger import logger
@@ -68,7 +68,7 @@ async def get_response(
6868
input: str | list[TResponseInputItem],
6969
model_settings: ModelSettings,
7070
tools: list[Tool],
71-
output_schema: AgentOutputSchema | None,
71+
output_schema: AgentOutputSchemaBase | None,
7272
handoffs: list[Handoff],
7373
tracing: ModelTracing,
7474
previous_response_id: str | None,
@@ -139,7 +139,7 @@ async def stream_response(
139139
input: str | list[TResponseInputItem],
140140
model_settings: ModelSettings,
141141
tools: list[Tool],
142-
output_schema: AgentOutputSchema | None,
142+
output_schema: AgentOutputSchemaBase | None,
143143
handoffs: list[Handoff],
144144
tracing: ModelTracing,
145145
*,
@@ -186,7 +186,7 @@ async def _fetch_response(
186186
input: str | list[TResponseInputItem],
187187
model_settings: ModelSettings,
188188
tools: list[Tool],
189-
output_schema: AgentOutputSchema | None,
189+
output_schema: AgentOutputSchemaBase | None,
190190
handoffs: list[Handoff],
191191
span: Span[GenerationSpanData],
192192
tracing: ModelTracing,
@@ -200,7 +200,7 @@ async def _fetch_response(
200200
input: str | list[TResponseInputItem],
201201
model_settings: ModelSettings,
202202
tools: list[Tool],
203-
output_schema: AgentOutputSchema | None,
203+
output_schema: AgentOutputSchemaBase | None,
204204
handoffs: list[Handoff],
205205
span: Span[GenerationSpanData],
206206
tracing: ModelTracing,
@@ -213,7 +213,7 @@ async def _fetch_response(
213213
input: str | list[TResponseInputItem],
214214
model_settings: ModelSettings,
215215
tools: list[Tool],
216-
output_schema: AgentOutputSchema | None,
216+
output_schema: AgentOutputSchemaBase | None,
217217
handoffs: list[Handoff],
218218
span: Span[GenerationSpanData],
219219
tracing: ModelTracing,

src/agents/models/chatcmpl_converter.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
)
3737
from openai.types.responses.response_input_param import FunctionCallOutput, ItemReference, Message
3838

39-
from ..agent_output import AgentOutputSchema
39+
from ..agent_output import AgentOutputSchemaBase
4040
from ..exceptions import AgentsException, UserError
4141
from ..handoffs import Handoff
4242
from ..items import TResponseInputItem, TResponseOutputItem
@@ -67,7 +67,7 @@ def convert_tool_choice(
6767

6868
@classmethod
6969
def convert_response_format(
70-
cls, final_output_schema: AgentOutputSchema | None
70+
cls, final_output_schema: AgentOutputSchemaBase | None
7171
) -> ResponseFormat | NotGiven:
7272
if not final_output_schema or final_output_schema.is_plain_text():
7373
return NOT_GIVEN
@@ -76,7 +76,7 @@ def convert_response_format(
7676
"type": "json_schema",
7777
"json_schema": {
7878
"name": "final_output",
79-
"strict": final_output_schema.strict_json_schema,
79+
"strict": final_output_schema.is_strict_json_schema(),
8080
"schema": final_output_schema.json_schema(),
8181
},
8282
}

src/agents/models/interface.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import AsyncIterator
66
from typing import TYPE_CHECKING
77

8-
from ..agent_output import AgentOutputSchema
8+
from ..agent_output import AgentOutputSchemaBase
99
from ..handoffs import Handoff
1010
from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
1111
from ..tool import Tool
@@ -41,7 +41,7 @@ async def get_response(
4141
input: str | list[TResponseInputItem],
4242
model_settings: ModelSettings,
4343
tools: list[Tool],
44-
output_schema: AgentOutputSchema | None,
44+
output_schema: AgentOutputSchemaBase | None,
4545
handoffs: list[Handoff],
4646
tracing: ModelTracing,
4747
*,
@@ -72,7 +72,7 @@ def stream_response(
7272
input: str | list[TResponseInputItem],
7373
model_settings: ModelSettings,
7474
tools: list[Tool],
75-
output_schema: AgentOutputSchema | None,
75+
output_schema: AgentOutputSchemaBase | None,
7676
handoffs: list[Handoff],
7777
tracing: ModelTracing,
7878
*,

0 commit comments

Comments
 (0)