
Commit 6fb8e82

Example for streaming guardrails
1 parent 25f97f9 commit 6fb8e82

File tree

3 files changed: +99 -4 lines changed
@@ -0,0 +1,93 @@
from __future__ import annotations

import asyncio

from openai.types.responses import ResponseTextDeltaEvent
from pydantic import BaseModel, Field

from agents import Agent, Runner

"""
This example shows how to use guardrails as the model is streaming. Output guardrails run after
the final output has been generated; this example runs the guardrail every N characters instead,
allowing for early termination if bad output is detected.

The expected output is that you'll see a bunch of tokens stream in, then the guardrail will
trigger and stop the streaming.
"""


agent = Agent(
    name="Assistant",
    instructions=(
        "You are a helpful assistant. You ALWAYS write long responses, making sure to be verbose "
        "and detailed."
    ),
)


class GuardrailOutput(BaseModel):
    reasoning: str = Field(
        description="Reasoning about whether the response could be understood by a ten year old."
    )
    is_readable_by_ten_year_old: bool = Field(
        description="Whether the response is understandable by a ten year old."
    )


guardrail_agent = Agent(
    name="Checker",
    instructions=(
        "You will be given a question and a response. Your goal is to judge whether the response "
        "is simple enough to be understood by a ten year old."
    ),
    output_type=GuardrailOutput,
    model="gpt-4o-mini",
)


async def check_guardrail(text: str) -> GuardrailOutput:
    result = await Runner.run(guardrail_agent, text)
    return result.final_output_as(GuardrailOutput)


async def main():
    question = "What is a black hole, and how does it behave?"
    result = Runner.run_streamed(agent, question)
    current_text = ""

    # We will check the guardrail every N characters
    next_guardrail_check_len = 300
    guardrail_task = None

    async for event in result.stream_events():
        if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
            print(event.data.delta, end="", flush=True)
            current_text += event.data.delta

            # Check if it's time to run the guardrail check
            # Note that we don't run the guardrail check if there's already a task running. An
            # alternate implementation is to have N guardrails running, or cancel the previous
            # one (a sketch of that variant follows this file).
            if len(current_text) >= next_guardrail_check_len and not guardrail_task:
                print("Running guardrail check")
                guardrail_task = asyncio.create_task(check_guardrail(current_text))
                next_guardrail_check_len += 300

        # Every iteration of the loop, check if the guardrail has been triggered
        if guardrail_task and guardrail_task.done():
            guardrail_result = guardrail_task.result()
            if not guardrail_result.is_readable_by_ten_year_old:
                print("\n\n================\n\n")
                print(f"Guardrail triggered. Reasoning:\n{guardrail_result.reasoning}")
                break

    # Do one final check on the final output
    guardrail_result = await check_guardrail(current_text)
    if not guardrail_result.is_readable_by_ten_year_old:
        print("\n\n================\n\n")
        print(f"Guardrail triggered. Reasoning:\n{guardrail_result.reasoning}")


if __name__ == "__main__":
    asyncio.run(main())

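The inline comment in the example mentions an alternative: instead of skipping a check while one is already in flight, cancel the stale task and start a fresh one, so the guardrail always judges the most recent text. A minimal sketch of that variant (not part of this commit; `deltas` and `check` are hypothetical stand-ins for the streaming loop and the guardrail call):

import asyncio
from collections.abc import AsyncIterator, Awaitable, Callable


async def guarded_stream(
    deltas: AsyncIterator[str],
    check: Callable[[str], Awaitable[bool]],
    check_every: int = 300,
) -> str:
    """Accumulate streamed text, restarting the guardrail check at each threshold.

    `check` is assumed to return False when the accumulated text fails the guardrail.
    """
    current_text = ""
    next_check_len = check_every
    task: asyncio.Task[bool] | None = None

    async for delta in deltas:
        current_text += delta
        if len(current_text) >= next_check_len:
            if task is not None and not task.done():
                task.cancel()  # the in-flight check is judging stale text; drop it
            task = asyncio.create_task(check(current_text))
            next_check_len += check_every
        if task is not None and task.done() and not task.cancelled():
            if not task.result():
                break  # guardrail tripped: stop consuming the stream early
    return current_text

The trade-off is extra cancelled work in exchange for fresher judgments; the committed example instead keeps at most one check alive and lets it finish.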
src/agents/models/openai_chatcompletions.py

+5 -3
@@ -572,7 +572,6 @@ def _get_client(self) -> AsyncOpenAI:
 
 
 class _Converter:
-
     @classmethod
     def is_openai(cls, client: AsyncOpenAI):
         return str(client.base_url).startswith("https://api.openai.com")

@@ -585,11 +584,14 @@ def get_store_param(cls, client: AsyncOpenAI, model_settings: ModelSettings) ->
 
     @classmethod
     def get_stream_options_param(
-            cls, client: AsyncOpenAI, model_settings: ModelSettings
+        cls, client: AsyncOpenAI, model_settings: ModelSettings
     ) -> dict[str, bool] | None:
         default_include_usage = True if cls.is_openai(client) else None
-        include_usage = model_settings.include_usage if model_settings.include_usage is not None \
+        include_usage = (
+            model_settings.include_usage
+            if model_settings.include_usage is not None
             else default_include_usage
+        )
         stream_options = {"include_usage": include_usage} if include_usage is not None else None
         return stream_options
 

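An aside, not part of the diff: the reformatted conditional resolves `include_usage` with an OpenAI-only default that an explicit `ModelSettings.include_usage` always overrides. A standalone restatement of that logic, with `is_openai` reduced to a plain bool for illustration:

def resolve_stream_options(is_openai: bool, include_usage: bool | None) -> dict[str, bool] | None:
    # Mirrors get_stream_options_param: usage reporting defaults to on only for
    # api.openai.com; an explicit setting always wins over the default.
    default_include_usage = True if is_openai else None
    resolved = include_usage if include_usage is not None else default_include_usage
    return {"include_usage": resolved} if resolved is not None else None


assert resolve_stream_options(True, None) == {"include_usage": True}
assert resolve_stream_options(False, None) is None
assert resolve_stream_options(False, False) == {"include_usage": False}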
src/agents/models/openai_responses.py

+1 -1
@@ -250,7 +250,7 @@ async def _fetch_response(
             text=response_format,
             store=self._non_null_or_not_given(model_settings.store),
             reasoning=self._non_null_or_not_given(model_settings.reasoning),
-            metadata=self._non_null_or_not_given(model_settings.metadata)
+            metadata=self._non_null_or_not_given(model_settings.metadata),
         )
 
     def _get_client(self) -> AsyncOpenAI:
