Skip to content

Create to_json_dict for ModelSettings #582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ requires-python = ">=3.9"
license = "MIT"
authors = [{ name = "OpenAI", email = "[email protected]" }]
dependencies = [
"openai>=1.66.5",
"openai>=1.76.0",
"pydantic>=2.10, <3",
"griffe>=1.5.6, <2",
"typing-extensions>=4.12.2, <5",
Expand Down
5 changes: 2 additions & 3 deletions src/agents/extensions/models/litellm_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import dataclasses
import json
import time
from collections.abc import AsyncIterator
Expand Down Expand Up @@ -75,7 +74,7 @@ async def get_response(
) -> ModelResponse:
with generation_span(
model=str(self.model),
model_config=dataclasses.asdict(model_settings)
model_config=model_settings.to_json_dict()
| {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
disabled=tracing.is_disabled(),
) as span_generation:
Expand Down Expand Up @@ -147,7 +146,7 @@ async def stream_response(
) -> AsyncIterator[TResponseStreamEvent]:
with generation_span(
model=str(self.model),
model_config=dataclasses.asdict(model_settings)
model_config=model_settings.to_json_dict()
| {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
disabled=tracing.is_disabled(),
) as span_generation:
Expand Down
17 changes: 16 additions & 1 deletion src/agents/model_settings.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import dataclasses
from dataclasses import dataclass, fields, replace
from typing import Literal
from typing import Any, Literal

from openai._types import Body, Headers, Query
from openai.types.shared import Reasoning
from pydantic import BaseModel


@dataclass
Expand Down Expand Up @@ -83,3 +85,16 @@ def resolve(self, override: ModelSettings | None) -> ModelSettings:
if getattr(override, field.name) is not None
}
return replace(self, **changes)

def to_json_dict(self) -> dict[str, Any]:
    """Return a JSON-serializable dict representation of these settings.

    ``dataclasses.asdict`` recurses into dataclasses, lists, tuples and
    dicts but leaves pydantic models (e.g. ``Reasoning``) untouched, so
    they are converted here via ``model_dump(mode="json")``. Models nested
    inside container fields (``metadata``, ``extra_body``, ...) are
    handled as well, not just top-level values.
    """

    def _jsonify(value: Any) -> Any:
        # Pydantic models know how to render themselves as JSON data.
        if isinstance(value, BaseModel):
            return value.model_dump(mode="json")
        # Recurse into containers in case a model is nested inside one.
        if isinstance(value, dict):
            return {key: _jsonify(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [_jsonify(item) for item in value]
        return value

    dataclass_dict = dataclasses.asdict(self)
    return {field_name: _jsonify(value) for field_name, value in dataclass_dict.items()}
7 changes: 2 additions & 5 deletions src/agents/models/openai_chatcompletions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import dataclasses
import json
import time
from collections.abc import AsyncIterator
Expand Down Expand Up @@ -56,8 +55,7 @@ async def get_response(
) -> ModelResponse:
with generation_span(
model=str(self.model),
model_config=dataclasses.asdict(model_settings)
| {"base_url": str(self._client.base_url)},
model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},
disabled=tracing.is_disabled(),
) as span_generation:
response = await self._fetch_response(
Expand Down Expand Up @@ -121,8 +119,7 @@ async def stream_response(
"""
with generation_span(
model=str(self.model),
model_config=dataclasses.asdict(model_settings)
| {"base_url": str(self._client.base_url)},
model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},
disabled=tracing.is_disabled(),
) as span_generation:
response, stream = await self._fetch_response(
Expand Down
59 changes: 59 additions & 0 deletions tests/model_settings/test_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import json
from dataclasses import fields

from openai.types.shared import Reasoning

from agents.model_settings import ModelSettings


def verify_serialization(model_settings: ModelSettings) -> None:
    """Verify that ModelSettings can be serialized to a JSON string."""
    json_dict = model_settings.to_json_dict()
    # json.dumps raises TypeError on non-serializable content, so reaching
    # the assertion is the real check. Assert the result is a str rather
    # than `is not None`, which json.dumps can never violate.
    json_string = json.dumps(json_dict)
    assert isinstance(json_string, str)


def test_basic_serialization() -> None:
    """Tests whether ModelSettings can be serialized to a JSON string."""

    # A ModelSettings instance with just a few common sampling fields set.
    model_settings = ModelSettings(temperature=0.5, top_p=0.9, max_tokens=100)

    # Serializing it to a JSON string should succeed.
    verify_serialization(model_settings)


def test_all_fields_serialization() -> None:
    """Tests whether ModelSettings can be serialized to a JSON string."""

    # Collect a non-None value for every ModelSettings field, then build
    # the instance from the mapping in one go.
    all_values = dict(
        temperature=0.5,
        top_p=0.9,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        tool_choice="auto",
        parallel_tool_calls=True,
        truncation="auto",
        max_tokens=100,
        reasoning=Reasoning(),
        metadata={"foo": "bar"},
        store=False,
        include_usage=False,
        extra_query={"foo": "bar"},
        extra_body={"foo": "bar"},
        extra_headers={"foo": "bar"},
    )
    model_settings = ModelSettings(**all_values)

    # Sanity check: no field may remain None. This also trips when a
    # newly added field is missing from the mapping above.
    for field in fields(model_settings):
        assert getattr(model_settings, field.name) is not None, (
            f"You must set the {field.name} field"
        )

    # Now, lets serialize the ModelSettings instance to a JSON string
    verify_serialization(model_settings)
1 change: 0 additions & 1 deletion tests/voice/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@ def pytest_ignore_collect(collection_path, config):

if str(collection_path).startswith(this_dir):
return True

8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.