
Commit 8040228: Create to_json_dict for ModelSettings

1 parent: a113fea

File tree

7 files changed: +84 -15 lines


pyproject.toml (+1, -1)
@@ -7,7 +7,7 @@ requires-python = ">=3.9"
 license = "MIT"
 authors = [{ name = "OpenAI", email = "[email protected]" }]
 dependencies = [
-    "openai>=1.66.5",
+    "openai>=1.76.0",
     "pydantic>=2.10, <3",
     "griffe>=1.5.6, <2",
     "typing-extensions>=4.12.2, <5",

src/agents/extensions/models/litellm_model.py (+2, -3)
@@ -1,6 +1,5 @@
 from __future__ import annotations

-import dataclasses
 import json
 import time
 from collections.abc import AsyncIterator
@@ -75,7 +74,7 @@ async def get_response(
     ) -> ModelResponse:
         with generation_span(
             model=str(self.model),
-            model_config=dataclasses.asdict(model_settings)
+            model_config=model_settings.to_json_dict()
             | {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
             disabled=tracing.is_disabled(),
         ) as span_generation:
@@ -147,7 +146,7 @@ async def stream_response(
     ) -> AsyncIterator[TResponseStreamEvent]:
         with generation_span(
             model=str(self.model),
-            model_config=dataclasses.asdict(model_settings)
+            model_config=model_settings.to_json_dict()
             | {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
             disabled=tracing.is_disabled(),
         ) as span_generation:
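For context on why the tracing call changed: `dataclasses.asdict()` deep-copies non-dataclass attributes, such as the pydantic `Reasoning` object, into the resulting dict unchanged, so the `model_config` passed to `generation_span` could not later be dumped as plain JSON. A minimal sketch of the failure mode (illustrative only, not part of this commit; assumes the agents package from this repo is installed):

import dataclasses
import json

from openai.types.shared import Reasoning

from agents.model_settings import ModelSettings

settings = ModelSettings(reasoning=Reasoning())

# dataclasses.asdict() keeps the pydantic Reasoning instance as-is,
# so a plain json.dumps of the result raises a TypeError.
raw = dataclasses.asdict(settings)
try:
    json.dumps(raw)
except TypeError as exc:
    print(exc)  # e.g. "Object of type Reasoning is not JSON serializable"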

src/agents/model_settings.py (+16, -1)
@@ -1,10 +1,12 @@
 from __future__ import annotations

+import dataclasses
 from dataclasses import dataclass, fields, replace
-from typing import Literal
+from typing import Any, Literal

 from openai._types import Body, Headers, Query
 from openai.types.shared import Reasoning
+from pydantic import BaseModel


 @dataclass
@@ -83,3 +85,16 @@ def resolve(self, override: ModelSettings | None) -> ModelSettings:
             if getattr(override, field.name) is not None
         }
         return replace(self, **changes)
+
+    def to_json_dict(self) -> dict[str, Any]:
+        dataclass_dict = dataclasses.asdict(self)
+
+        json_dict: dict[str, Any] = {}
+
+        for field_name, value in dataclass_dict.items():
+            if isinstance(value, BaseModel):
+                json_dict[field_name] = value.model_dump(mode="json")
+            else:
+                json_dict[field_name] = value
+
+        return json_dict
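A quick usage sketch of the new method (example values are hypothetical): `to_json_dict()` runs `dataclasses.asdict()` and then converts any pydantic `BaseModel` field, such as `reasoning`, via `model_dump(mode="json")`, so the result can be handed straight to `json.dumps`.

import json

from openai.types.shared import Reasoning

from agents.model_settings import ModelSettings

settings = ModelSettings(temperature=0.5, reasoning=Reasoning(effort="low"))

config = settings.to_json_dict()
# The reasoning field is now a plain dict rather than a Reasoning instance.
assert isinstance(config["reasoning"], dict)
print(json.dumps(config))  # serializes without error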

src/agents/models/openai_chatcompletions.py (+2, -5)
@@ -1,6 +1,5 @@
 from __future__ import annotations

-import dataclasses
 import json
 import time
 from collections.abc import AsyncIterator
@@ -56,8 +55,7 @@ async def get_response(
     ) -> ModelResponse:
         with generation_span(
             model=str(self.model),
-            model_config=dataclasses.asdict(model_settings)
-            | {"base_url": str(self._client.base_url)},
+            model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},
             disabled=tracing.is_disabled(),
         ) as span_generation:
             response = await self._fetch_response(
@@ -121,8 +119,7 @@ async def stream_response(
         """
         with generation_span(
             model=str(self.model),
-            model_config=dataclasses.asdict(model_settings)
-            | {"base_url": str(self._client.base_url)},
+            model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},
             disabled=tracing.is_disabled(),
         ) as span_generation:
             response, stream = await self._fetch_response(
New test file (+59)
@@ -0,0 +1,59 @@
+import json
+from dataclasses import fields
+
+from openai.types.shared import Reasoning
+
+from agents.model_settings import ModelSettings
+
+
+def verify_serialization(model_settings: ModelSettings) -> None:
+    """Verify that ModelSettings can be serialized to a JSON string."""
+    json_dict = model_settings.to_json_dict()
+    json_string = json.dumps(json_dict)
+    assert json_string is not None
+
+
+def test_basic_serialization() -> None:
+    """Tests whether ModelSettings can be serialized to a JSON string."""
+
+    # First, lets create a ModelSettings instance
+    model_settings = ModelSettings(
+        temperature=0.5,
+        top_p=0.9,
+        max_tokens=100,
+    )
+
+    # Now, lets serialize the ModelSettings instance to a JSON string
+    verify_serialization(model_settings)
+
+
+def test_all_fields_serialization() -> None:
+    """Tests whether ModelSettings can be serialized to a JSON string."""
+
+    # First, lets create a ModelSettings instance
+    model_settings = ModelSettings(
+        temperature=0.5,
+        top_p=0.9,
+        frequency_penalty=0.0,
+        presence_penalty=0.0,
+        tool_choice="auto",
+        parallel_tool_calls=True,
+        truncation="auto",
+        max_tokens=100,
+        reasoning=Reasoning(),
+        metadata={"foo": "bar"},
+        store=False,
+        include_usage=False,
+        extra_query={"foo": "bar"},
+        extra_body={"foo": "bar"},
+        extra_headers={"foo": "bar"},
+    )
+
+    # Verify that every single field is set to a non-None value
+    for field in fields(model_settings):
+        assert getattr(model_settings, field.name) is not None, (
+            f"You must set the {field.name} field"
+        )
+
+    # Now, lets serialize the ModelSettings instance to a JSON string
+    verify_serialization(model_settings)
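The new tests only assert that json.dumps succeeds. A slightly stricter round-trip check could also verify that pydantic fields come back as plain JSON-compatible values; the sketch below is not part of the commit and the test name is hypothetical.

import json

from openai.types.shared import Reasoning

from agents.model_settings import ModelSettings


def test_round_trip_sketch() -> None:
    settings = ModelSettings(reasoning=Reasoning(), metadata={"foo": "bar"})
    dumped = settings.to_json_dict()

    # The pydantic Reasoning field should be dumped to a plain dict ...
    assert isinstance(dumped["reasoning"], dict)

    # ... and the whole dict should survive a JSON round trip unchanged.
    assert json.loads(json.dumps(dumped)) == dumped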

tests/voice/conftest.py (-1)

@@ -9,4 +9,3 @@ def pytest_ignore_collect(collection_path, config):

     if str(collection_path).startswith(this_dir):
         return True
-

uv.lock (+4, -4)

Generated file; diff not rendered by default.
