Skip to content

Commit 4c8a0fc

Browse files
committed
Add usage to context in streaming
1 parent 3bbc7c4 commit 4c8a0fc

File tree

4 files changed

+28
-13
lines changed

4 files changed

+28
-13
lines changed

src/agents/result.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .guardrail import InputGuardrailResult, OutputGuardrailResult
1616
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
1717
from .logger import logger
18+
from .run_context import RunContextWrapper
1819
from .stream_events import StreamEvent
1920
from .tracing import Trace
2021
from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
@@ -50,6 +51,9 @@ class RunResultBase(abc.ABC):
5051
output_guardrail_results: list[OutputGuardrailResult]
5152
"""Guardrail results for the final output of the agent."""
5253

54+
context_wrapper: RunContextWrapper[Any]
55+
"""The context wrapper for the agent run."""
56+
5357
@property
5458
@abc.abstractmethod
5559
def last_agent(self) -> Agent[Any]:
@@ -75,9 +79,7 @@ def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) -
7579

7680
def to_input_list(self) -> list[TResponseInputItem]:
7781
"""Creates a new input list, merging the original input with all the new items generated."""
78-
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(
79-
self.input
80-
)
82+
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input)
8183
new_items = [item.to_input_item() for item in self.new_items]
8284

8385
return original_items + new_items
@@ -206,17 +208,13 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
206208

207209
def _check_errors(self):
208210
if self.current_turn > self.max_turns:
209-
self._stored_exception = MaxTurnsExceeded(
210-
f"Max turns ({self.max_turns}) exceeded"
211-
)
211+
self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
212212

213213
# Fetch all the completed guardrail results from the queue and raise if needed
214214
while not self._input_guardrail_queue.empty():
215215
guardrail_result = self._input_guardrail_queue.get_nowait()
216216
if guardrail_result.output.tripwire_triggered:
217-
self._stored_exception = InputGuardrailTripwireTriggered(
218-
guardrail_result
219-
)
217+
self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
220218

221219
# Check the tasks for any exceptions
222220
if self._run_impl_task and self._run_impl_task.done():

src/agents/run.py

+3
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ async def run(
270270
_last_agent=current_agent,
271271
input_guardrail_results=input_guardrail_results,
272272
output_guardrail_results=output_guardrail_results,
273+
context_wrapper=context_wrapper,
273274
)
274275
elif isinstance(turn_result.next_step, NextStepHandoff):
275276
current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
@@ -423,6 +424,7 @@ def run_streamed(
423424
output_guardrail_results=[],
424425
_current_agent_output_schema=output_schema,
425426
trace=new_trace,
427+
context_wrapper=context_wrapper,
426428
)
427429

428430
# Kick off the actual agent loop in the background and return the streamed result object.
@@ -696,6 +698,7 @@ async def _run_single_turn_streamed(
696698
usage=usage,
697699
response_id=event.response.id,
698700
)
701+
context_wrapper.usage.add(usage)
699702

700703
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
701704

tests/fake_model.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from collections.abc import AsyncIterator
44
from typing import Any
55

6-
from openai.types.responses import Response, ResponseCompletedEvent
6+
from openai.types.responses import Response, ResponseCompletedEvent, ResponseUsage
7+
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
78

89
from agents.agent_output import AgentOutputSchemaBase
910
from agents.handoffs import Handoff
@@ -33,6 +34,10 @@ def __init__(
3334
)
3435
self.tracing_enabled = tracing_enabled
3536
self.last_turn_args: dict[str, Any] = {}
37+
self.hardcoded_usage: Usage | None = None
38+
39+
def set_hardcoded_usage(self, usage: Usage):
40+
self.hardcoded_usage = usage
3641

3742
def set_next_output(self, output: list[TResponseOutputItem] | Exception):
3843
self.turn_outputs.append(output)
@@ -83,7 +88,7 @@ async def get_response(
8388

8489
return ModelResponse(
8590
output=output,
86-
usage=Usage(),
91+
usage=self.hardcoded_usage or Usage(),
8792
response_id=None,
8893
)
8994

@@ -123,13 +128,14 @@ async def stream_response(
123128

124129
yield ResponseCompletedEvent(
125130
type="response.completed",
126-
response=get_response_obj(output),
131+
response=get_response_obj(output, usage=self.hardcoded_usage),
127132
)
128133

129134

130135
def get_response_obj(
131136
output: list[TResponseOutputItem],
132137
response_id: str | None = None,
138+
usage: Usage | None = None,
133139
) -> Response:
134140
return Response(
135141
id=response_id or "123",
@@ -141,4 +147,11 @@ def get_response_obj(
141147
tools=[],
142148
top_p=None,
143149
parallel_tool_calls=False,
150+
usage=ResponseUsage(
151+
input_tokens=usage.input_tokens if usage else 0,
152+
output_tokens=usage.output_tokens if usage else 0,
153+
total_tokens=usage.total_tokens if usage else 0,
154+
input_tokens_details=InputTokensDetails(cached_tokens=0),
155+
output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
156+
),
144157
)

tests/test_result_cast.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
from pydantic import BaseModel
55

6-
from agents import Agent, RunResult
6+
from agents import Agent, RunContextWrapper, RunResult
77

88

99
def create_run_result(final_output: Any) -> RunResult:
@@ -15,6 +15,7 @@ def create_run_result(final_output: Any) -> RunResult:
1515
input_guardrail_results=[],
1616
output_guardrail_results=[],
1717
_last_agent=Agent(name="test"),
18+
context_wrapper=RunContextWrapper(context=None),
1819
)
1920

2021

0 commit comments

Comments (0)