Skip to content

Add usage to context in streaming #595

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged — 1 commit merged on Apr 24, 2025.
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions src/agents/result.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,6 +15,7 @@
from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .logger import logger
from .run_context import RunContextWrapper
from .stream_events import StreamEvent
from .tracing import Trace
from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
Expand Down Expand Up @@ -50,6 +51,9 @@ class RunResultBase(abc.ABC):
output_guardrail_results: list[OutputGuardrailResult]
"""Guardrail results for the final output of the agent."""

context_wrapper: RunContextWrapper[Any]
"""The context wrapper for the agent run."""

@property
@abc.abstractmethod
def last_agent(self) -> Agent[Any]:
Expand All @@ -75,9 +79,7 @@ def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) -

def to_input_list(self) -> list[TResponseInputItem]:
"""Creates a new input list, merging the original input with all the new items generated."""
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(
self.input
)
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input)
new_items = [item.to_input_item() for item in self.new_items]

return original_items + new_items
Expand Down Expand Up @@ -206,17 +208,13 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:

def _check_errors(self):
if self.current_turn > self.max_turns:
self._stored_exception = MaxTurnsExceeded(
f"Max turns ({self.max_turns}) exceeded"
)
self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")

# Fetch all the completed guardrail results from the queue and raise if needed
while not self._input_guardrail_queue.empty():
guardrail_result = self._input_guardrail_queue.get_nowait()
if guardrail_result.output.tripwire_triggered:
self._stored_exception = InputGuardrailTripwireTriggered(
guardrail_result
)
self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)

# Check the tasks for any exceptions
if self._run_impl_task and self._run_impl_task.done():
Expand Down
3 changes: 3 additions & 0 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ async def run(
_last_agent=current_agent,
input_guardrail_results=input_guardrail_results,
output_guardrail_results=output_guardrail_results,
context_wrapper=context_wrapper,
)
elif isinstance(turn_result.next_step, NextStepHandoff):
current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
Expand Down Expand Up @@ -423,6 +424,7 @@ def run_streamed(
output_guardrail_results=[],
_current_agent_output_schema=output_schema,
trace=new_trace,
context_wrapper=context_wrapper,
)

# Kick off the actual agent loop in the background and return the streamed result object.
Expand Down Expand Up @@ -696,6 +698,7 @@ async def _run_single_turn_streamed(
usage=usage,
response_id=event.response.id,
)
context_wrapper.usage.add(usage)

streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))

Expand Down
19 changes: 16 additions & 3 deletions tests/fake_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from collections.abc import AsyncIterator
from typing import Any

from openai.types.responses import Response, ResponseCompletedEvent
from openai.types.responses import Response, ResponseCompletedEvent, ResponseUsage
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

from agents.agent_output import AgentOutputSchemaBase
from agents.handoffs import Handoff
Expand Down Expand Up @@ -33,6 +34,10 @@ def __init__(
)
self.tracing_enabled = tracing_enabled
self.last_turn_args: dict[str, Any] = {}
self.hardcoded_usage: Usage | None = None

def set_hardcoded_usage(self, usage: Usage):
self.hardcoded_usage = usage

def set_next_output(self, output: list[TResponseOutputItem] | Exception):
self.turn_outputs.append(output)
Expand Down Expand Up @@ -83,7 +88,7 @@ async def get_response(

return ModelResponse(
output=output,
usage=Usage(),
usage=self.hardcoded_usage or Usage(),
response_id=None,
)

Expand Down Expand Up @@ -123,13 +128,14 @@ async def stream_response(

yield ResponseCompletedEvent(
type="response.completed",
response=get_response_obj(output),
response=get_response_obj(output, usage=self.hardcoded_usage),
)


def get_response_obj(
output: list[TResponseOutputItem],
response_id: str | None = None,
usage: Usage | None = None,
) -> Response:
return Response(
id=response_id or "123",
Expand All @@ -141,4 +147,11 @@ def get_response_obj(
tools=[],
top_p=None,
parallel_tool_calls=False,
usage=ResponseUsage(
input_tokens=usage.input_tokens if usage else 0,
output_tokens=usage.output_tokens if usage else 0,
total_tokens=usage.total_tokens if usage else 0,
input_tokens_details=InputTokensDetails(cached_tokens=0),
output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
),
)
3 changes: 2 additions & 1 deletion tests/test_result_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from pydantic import BaseModel

from agents import Agent, RunResult
from agents import Agent, RunContextWrapper, RunResult


def create_run_result(final_output: Any) -> RunResult:
Expand All @@ -15,6 +15,7 @@ def create_run_result(final_output: Any) -> RunResult:
input_guardrail_results=[],
output_guardrail_results=[],
_last_agent=Agent(name="test"),
context_wrapper=RunContextWrapper(context=None),
)


Expand Down