Skip to content

Commit a113fea

Browse files
authored
Allow cancel out of the streaming result (#579)
Fix for #574 @rm-openai I'm not sure how to add a test within the repo but I have pasted a test script below that seems to work ```python import asyncio from openai.types.responses import ResponseTextDeltaEvent from agents import Agent, Runner async def main(): agent = Agent( name="Joker", instructions="You are a helpful assistant.", ) result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") num_visible_event = 0 async for event in result.stream_events(): if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent): print(event.data.delta, end="", flush=True) num_visible_event += 1 print(num_visible_event) if num_visible_event == 3: result.cancel() if __name__ == "__main__": asyncio.run(main()) ````
1 parent 178020e commit a113fea

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

src/agents/result.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) -
7575

7676
def to_input_list(self) -> list[TResponseInputItem]:
7777
"""Creates a new input list, merging the original input with all the new items generated."""
78-
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input)
78+
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(
79+
self.input
80+
)
7981
new_items = [item.to_input_item() for item in self.new_items]
8082

8183
return original_items + new_items
@@ -152,6 +154,18 @@ def last_agent(self) -> Agent[Any]:
152154
"""
153155
return self.current_agent
154156

157+
def cancel(self) -> None:
158+
"""Cancels the streaming run, stopping all background tasks and marking the run as
159+
complete."""
160+
self._cleanup_tasks() # Cancel all running tasks
161+
self.is_complete = True # Mark the run as complete to stop event streaming
162+
163+
# Optionally, clear the event queue to prevent processing stale events
164+
while not self._event_queue.empty():
165+
self._event_queue.get_nowait()
166+
while not self._input_guardrail_queue.empty():
167+
self._input_guardrail_queue.get_nowait()
168+
155169
async def stream_events(self) -> AsyncIterator[StreamEvent]:
156170
"""Stream deltas for new items as they are generated. We're using the types from the
157171
OpenAI Responses API, so these are semantic events: each event has a `type` field that
@@ -192,13 +206,17 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
192206

193207
def _check_errors(self):
194208
if self.current_turn > self.max_turns:
195-
self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
209+
self._stored_exception = MaxTurnsExceeded(
210+
f"Max turns ({self.max_turns}) exceeded"
211+
)
196212

197213
# Fetch all the completed guardrail results from the queue and raise if needed
198214
while not self._input_guardrail_queue.empty():
199215
guardrail_result = self._input_guardrail_queue.get_nowait()
200216
if guardrail_result.output.tripwire_triggered:
201-
self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
217+
self._stored_exception = InputGuardrailTripwireTriggered(
218+
guardrail_result
219+
)
202220

203221
# Check the tasks for any exceptions
204222
if self._run_impl_task and self._run_impl_task.done():

tests/test_cancel_streaming.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pytest
2+
3+
from agents import Agent, Runner
4+
5+
from .fake_model import FakeModel
6+
7+
8+
@pytest.mark.asyncio
9+
async def test_joker_streamed_jokes_with_cancel():
10+
model = FakeModel()
11+
agent = Agent(name="Joker", model=model)
12+
13+
result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
14+
num_events = 0
15+
stop_after = 1 # There are two that the model gives back.
16+
17+
async for _event in result.stream_events():
18+
num_events += 1
19+
if num_events == 1:
20+
result.cancel()
21+
22+
assert num_events == 1, f"Expected {stop_after} visible events, but got {num_events}"

0 commit comments

Comments
 (0)