Skip to content

Commit dffc269

Browse files
committed
Create to_json_dict for ModelSettings
1 parent 357074c commit dffc269

File tree

5 files changed

+79
-11
lines changed

5 files changed

+79
-11
lines changed

src/agents/extensions/models/litellm_model.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import dataclasses
43
import json
54
import time
65
from collections.abc import AsyncIterator
@@ -32,7 +31,7 @@
3231
from ... import _debug
3332
from ...agent_output import AgentOutputSchemaBase
3433
from ...handoffs import Handoff
35-
from ...items import ModelResponse, ReasoningItem, TResponseInputItem, TResponseStreamEvent
34+
from ...items import ModelResponse, TResponseInputItem, TResponseStreamEvent
3635
from ...logger import logger
3736
from ...model_settings import ModelSettings
3837
from ...models.chatcmpl_converter import Converter
@@ -76,7 +75,7 @@ async def get_response(
7675
) -> ModelResponse:
7776
with generation_span(
7877
model=str(self.model),
79-
model_config=dataclasses.asdict(model_settings)
78+
model_config=model_settings.to_json_dict()
8079
| {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
8180
disabled=tracing.is_disabled(),
8281
) as span_generation:
@@ -159,7 +158,7 @@ async def stream_response(
159158
) -> AsyncIterator[TResponseStreamEvent]:
160159
with generation_span(
161160
model=str(self.model),
162-
model_config=dataclasses.asdict(model_settings)
161+
model_config=model_settings.to_json_dict()
163162
| {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
164163
disabled=tracing.is_disabled(),
165164
) as span_generation:

src/agents/model_settings.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

3+
import dataclasses
34
from dataclasses import dataclass, fields, replace
4-
from typing import Literal
5+
from typing import Any, Literal
56

67
from openai._types import Body, Query
78
from openai.types.shared import Reasoning
9+
from pydantic import BaseModel
810

911

1012
@dataclass
@@ -79,3 +81,16 @@ def resolve(self, override: ModelSettings | None) -> ModelSettings:
7981
if getattr(override, field.name) is not None
8082
}
8183
return replace(self, **changes)
84+
85+
def to_json_dict(self) -> dict[str, Any]:
    """Return this ModelSettings as a JSON-serializable dict.

    Starts from ``dataclasses.asdict(self)`` and dumps any top-level
    pydantic ``BaseModel`` values (e.g. ``reasoning``) with
    ``model_dump(mode="json")`` so the result can be passed to
    ``json.dumps`` directly. All other values are kept as-is.
    """
    return {
        name: value.model_dump(mode="json") if isinstance(value, BaseModel) else value
        for name, value in dataclasses.asdict(self).items()
    }

src/agents/models/openai_chatcompletions.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import dataclasses
43
import json
54
import time
65
from collections.abc import AsyncIterator
@@ -56,8 +55,7 @@ async def get_response(
5655
) -> ModelResponse:
5756
with generation_span(
5857
model=str(self.model),
59-
model_config=dataclasses.asdict(model_settings)
60-
| {"base_url": str(self._client.base_url)},
58+
model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},
6159
disabled=tracing.is_disabled(),
6260
) as span_generation:
6361
response = await self._fetch_response(
@@ -121,8 +119,7 @@ async def stream_response(
121119
"""
122120
with generation_span(
123121
model=str(self.model),
124-
model_config=dataclasses.asdict(model_settings)
125-
| {"base_url": str(self._client.base_url)},
122+
model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},
126123
disabled=tracing.is_disabled(),
127124
) as span_generation:
128125
response, stream = await self._fetch_response(
+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import json
2+
from dataclasses import fields
3+
4+
from openai.types.shared import Reasoning
5+
6+
from agents.model_settings import ModelSettings
7+
8+
9+
def verify_serialization(model_settings: ModelSettings) -> None:
    """Verify that ModelSettings can be serialized to a JSON string."""
    # Round the settings through to_json_dict and make sure json.dumps
    # accepts the result (it raises TypeError on non-serializable values).
    serialized = json.dumps(model_settings.to_json_dict())
    assert serialized is not None
14+
15+
16+
def test_basic_serialization() -> None:
    """Tests whether ModelSettings can be serialized to a JSON string."""
    # Build a ModelSettings with a handful of core sampling fields set,
    # then confirm the instance round-trips through json.dumps.
    settings = ModelSettings(
        temperature=0.5,
        top_p=0.9,
        max_tokens=100,
    )
    verify_serialization(settings)
28+
29+
30+
def test_all_fields_serialization() -> None:
    """Tests whether ModelSettings can be serialized to a JSON string."""
    # Populate every field of ModelSettings with a non-None value so the
    # serialization path is exercised for all of them at once.
    settings = ModelSettings(
        temperature=0.5,
        top_p=0.9,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        tool_choice="auto",
        parallel_tool_calls=True,
        truncation="auto",
        max_tokens=100,
        reasoning=Reasoning(summary="auto"),
        metadata={"foo": "bar"},
        store=False,
        include_usage=False,
        extra_query={"foo": "bar"},
        extra_body={"foo": "bar"},
    )

    # Guard against new dataclass fields being added without updating this
    # test: every single field must be set to a non-None value above.
    unset = [f.name for f in fields(settings) if getattr(settings, f.name) is None]
    assert not unset, f"You must set the {unset[0]} field"

    # Now serialize the fully-populated instance to a JSON string.
    verify_serialization(settings)

tests/voice/conftest.py

-1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,3 @@ def pytest_ignore_collect(collection_path, config):
99

1010
if str(collection_path).startswith(this_dir):
1111
return True
12-

0 commit comments

Comments (0)