1
+ import abc
1
2
from dataclasses import dataclass
2
3
from typing import Any
3
4
12
13
_WRAPPER_DICT_KEY = "response"
13
14
14
15
16
+ class AgentOutputSchemaBase (abc .ABC ):
17
+ """An object that captures the JSON schema of the output, as well as validating/parsing JSON
18
+ produced by the LLM into the output type.
19
+ """
20
+
21
+ @abc .abstractmethod
22
+ def is_plain_text (self ) -> bool :
23
+ """Whether the output type is plain text (versus a JSON object)."""
24
+ pass
25
+
26
+ @abc .abstractmethod
27
+ def name (self ) -> str :
28
+ """The name of the output type."""
29
+ pass
30
+
31
+ @abc .abstractmethod
32
+ def json_schema (self ) -> dict [str , Any ]:
33
+ """Returns the JSON schema of the output. Will only be called if the output type is not
34
+ plain text.
35
+ """
36
+ pass
37
+
38
+ @abc .abstractmethod
39
+ def is_strict_json_schema (self ) -> bool :
40
+ """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
41
+ features, but guarantees valis JSON. See here for details:
42
+ https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
43
+ """
44
+ pass
45
+
46
+ @abc .abstractmethod
47
+ def validate_json (self , json_str : str ) -> Any :
48
+ """Validate a JSON string against the output type. You must return the validated object,
49
+ or raise a `ModelBehaviorError` if the JSON is invalid.
50
+ """
51
+ pass
52
+
53
+
15
54
@dataclass (init = False )
16
- class AgentOutputSchema :
55
+ class AgentOutputSchema ( AgentOutputSchemaBase ) :
17
56
"""An object that captures the JSON schema of the output, as well as validating/parsing JSON
18
57
produced by the LLM into the output type.
19
58
"""
@@ -32,7 +71,7 @@ class AgentOutputSchema:
32
71
_output_schema : dict [str , Any ]
33
72
"""The JSON schema of the output."""
34
73
35
- strict_json_schema : bool
74
+ _strict_json_schema : bool
36
75
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
37
76
as it increases the likelihood of correct JSON input.
38
77
"""
@@ -45,7 +84,7 @@ def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
45
84
setting this to True, as it increases the likelihood of correct JSON input.
46
85
"""
47
86
self .output_type = output_type
48
- self .strict_json_schema = strict_json_schema
87
+ self ._strict_json_schema = strict_json_schema
49
88
50
89
if output_type is None or output_type is str :
51
90
self ._is_wrapped = False
@@ -70,24 +109,35 @@ def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
70
109
self ._type_adapter = TypeAdapter (output_type )
71
110
self ._output_schema = self ._type_adapter .json_schema ()
72
111
73
- if self .strict_json_schema :
74
- self ._output_schema = ensure_strict_json_schema (self ._output_schema )
112
+ if self ._strict_json_schema :
113
+ try :
114
+ self ._output_schema = ensure_strict_json_schema (self ._output_schema )
115
+ except UserError as e :
116
+ raise UserError (
117
+ "Strict JSON schema is enabled, but the output type is not valid. "
118
+ "Either make the output type strict, or pass output_schema_strict=False to "
119
+ "your Agent()"
120
+ ) from e
75
121
76
122
def is_plain_text (self ) -> bool :
77
123
"""Whether the output type is plain text (versus a JSON object)."""
78
124
return self .output_type is None or self .output_type is str
79
125
126
+ def is_strict_json_schema (self ) -> bool :
127
+ """Whether the JSON schema is in strict mode."""
128
+ return self ._strict_json_schema
129
+
80
130
def json_schema (self ) -> dict [str , Any ]:
81
131
"""The JSON schema of the output type."""
82
132
if self .is_plain_text ():
83
133
raise UserError ("Output type is plain text, so no JSON schema is available" )
84
134
return self ._output_schema
85
135
86
- def validate_json (self , json_str : str , partial : bool = False ) -> Any :
136
+ def validate_json (self , json_str : str ) -> Any :
87
137
"""Validate a JSON string against the output type. Returns the validated object, or raises
88
138
a `ModelBehaviorError` if the JSON is invalid.
89
139
"""
90
- validated = _json .validate_json (json_str , self ._type_adapter , partial )
140
+ validated = _json .validate_json (json_str , self ._type_adapter , partial = False )
91
141
if self ._is_wrapped :
92
142
if not isinstance (validated , dict ):
93
143
_error_tracing .attach_error_to_current_span (
@@ -113,7 +163,7 @@ def validate_json(self, json_str: str, partial: bool = False) -> Any:
113
163
return validated [_WRAPPER_DICT_KEY ]
114
164
return validated
115
165
116
- def output_type_name (self ) -> str :
166
+ def name (self ) -> str :
117
167
"""The name of the output type."""
118
168
return _type_to_str (self .output_type )
119
169
0 commit comments