
Commit 54330ea

Add streaming support for openai and anthropic providers

1 parent 9dc9ae9 commit 54330ea

File tree

aisuite/framework/chat_completion_stream_response.py
aisuite/providers/anthropic_provider.py
aisuite/providers/openai_provider.py

3 files changed: +179 -51 lines changed
aisuite/framework/chat_completion_stream_response.py

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+# aisuite/framework/chat_completion_stream_response.py
+
+from typing import Optional
+
+class ChatCompletionStreamResponseDelta:
+    """
+    Mimics the 'delta' object returned by OpenAI streaming chunks.
+    Example usage in code:
+        chunk.choices[0].delta.content
+    """
+    def __init__(self, role: Optional[str] = None, content: Optional[str] = None):
+        self.role = role
+        self.content = content
+
+
+class ChatCompletionStreamResponseChoice:
+    """
+    Holds the 'delta' for a single chunk choice.
+    Example usage in code:
+        chunk.choices[0].delta
+    """
+    def __init__(self, delta: ChatCompletionStreamResponseDelta):
+        self.delta = delta
+
+
+class ChatCompletionStreamResponse:
+    """
+    Container for streaming response chunks.
+    Each chunk has a 'choices' list, each with a 'delta'.
+    Example usage in code:
+        chunk.choices[0].delta.content
+    """
+    def __init__(self, choices: list[ChatCompletionStreamResponseChoice]):
+        self.choices = choices
+
+
+
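For context, here is a minimal sketch (not part of the commit) of how these wrapper classes are meant to be consumed. The `print_stream` helper is hypothetical, and the import path assumes the new module above is installed as part of aisuite.

# Hypothetical consumer, assuming the module above is importable.
from aisuite.framework.chat_completion_stream_response import (
    ChatCompletionStreamResponse,
    ChatCompletionStreamResponseChoice,
    ChatCompletionStreamResponseDelta,
)

def print_stream(chunks):
    """Print the text of each streaming chunk as it arrives."""
    for chunk in chunks:
        delta = chunk.choices[0].delta
        if delta.content:
            print(delta.content, end="", flush=True)

# Chunks are built the same way the Anthropic provider below builds them:
chunks = [
    ChatCompletionStreamResponse(choices=[
        ChatCompletionStreamResponseChoice(
            delta=ChatCompletionStreamResponseDelta(role="assistant", content=text)
        )
    ])
    for text in ("Hello", ", ", "world")
]
print_stream(chunks)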

aisuite/providers/anthropic_provider.py

Lines changed: 90 additions & 34 deletions
@@ -1,13 +1,19 @@
-# Anthropic provider
-# Links:
-# Tool calling docs - https://docs.anthropic.com/en/docs/build-with-claude/tool-use
+# aisuite/providers/anthropic_provider.py
 
 import anthropic
 import json
+
 from aisuite.provider import Provider
 from aisuite.framework import ChatCompletionResponse
 from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function
 
+# Import our new streaming response classes:
+from aisuite.framework.chat_completion_stream_response import (
+    ChatCompletionStreamResponse,
+    ChatCompletionStreamResponseChoice,
+    ChatCompletionStreamResponseDelta,
+)
+
 # Define a constant for the default max_tokens value
 DEFAULT_MAX_TOKENS = 4096
 
@@ -33,7 +39,7 @@ def convert_request(self, messages):
         return system_message, converted_messages
 
     def convert_response(self, response):
-        """Normalize the response from the Anthropic API to match OpenAI's response format."""
+        """Normalize a non-streaming response from the Anthropic API to match OpenAI's response format."""
         normalized_response = ChatCompletionResponse()
         normalized_response.choices[0].finish_reason = self._get_finish_reason(response)
         normalized_response.usage = self._get_usage_stats(response)
@@ -57,7 +63,7 @@ def _convert_dict_message(self, msg):
         return {"role": msg["role"], "content": msg["content"]}
 
     def _convert_message_object(self, msg):
-        """Convert a Message object to Anthropic format."""
+        """Convert a `Message` object to Anthropic format."""
         if msg.role == self.ROLE_TOOL:
             return self._create_tool_result_message(msg.tool_call_id, msg.content)
         elif msg.role == self.ROLE_ASSISTANT and msg.tool_calls:
@@ -107,22 +113,23 @@ def _create_assistant_tool_message(self, content, tool_calls):
         return {"role": self.ROLE_ASSISTANT, "content": message_content}
 
     def _extract_system_message(self, messages):
-        """Extract system message if present, otherwise return empty list."""
-        # TODO: This is a temporary solution to extract the system message.
-        # User can pass multiple system messages, which can mingled with other messages.
-        # This needs to be fixed to handle this case.
+        """
+        Extract system message if present, otherwise return empty string.
+        If there are multiple system messages, or the system message is not the first,
+        you may need to adapt this approach.
+        """
         if messages and messages[0]["role"] == "system":
             system_message = messages[0]["content"]
             messages.pop(0)
             return system_message
-        return []
+        return ""
 
     def _get_finish_reason(self, response):
         """Get the normalized finish reason."""
         return self.FINISH_REASON_MAPPING.get(response.stop_reason, "stop")
 
     def _get_usage_stats(self, response):
-        """Get the usage statistics."""
+        """Get the usage statistics from Anthropic response."""
         return {
             "prompt_tokens": response.usage.input_tokens,
             "completion_tokens": response.usage.output_tokens,
@@ -135,9 +142,8 @@ def _get_message(self, response):
         tool_message = self.convert_response_with_tool_use(response)
         if tool_message:
             return tool_message
-
         return Message(
-            content=response.content[0].text,
+            content=response.content[0].text if response.content else "",
             role="assistant",
             tool_calls=None,
             refusal=None,
@@ -146,26 +152,22 @@ def _get_message(self, response):
     def convert_response_with_tool_use(self, response):
         """Convert Anthropic tool use response to the framework's format."""
         tool_call = next(
-            (content for content in response.content if content.type == "tool_use"),
+            (c for c in response.content if c.type == "tool_use"),
             None,
         )
-
         if tool_call:
             function = Function(
-                name=tool_call.name, arguments=json.dumps(tool_call.input)
+                name=tool_call.name,
+                arguments=json.dumps(tool_call.input),
             )
             tool_call_obj = ChatCompletionMessageToolCall(
-                id=tool_call.id, function=function, type="function"
+                id=tool_call.id,
+                function=function,
+                type="function",
             )
             text_content = next(
-                (
-                    content.text
-                    for content in response.content
-                    if content.type == "text"
-                ),
-                "",
+                (c.text for c in response.content if c.type == "text"), ""
             )
-
             return Message(
                 content=text_content or None,
                 tool_calls=[tool_call_obj] if tool_call else None,
@@ -177,11 +179,9 @@ def convert_response_with_tool_use(self, response):
     def convert_tool_spec(self, openai_tools):
         """Convert OpenAI tool specification to Anthropic format."""
         anthropic_tools = []
-
         for tool in openai_tools:
             if tool.get("type") != "function":
                 continue
-
             function = tool["function"]
             anthropic_tool = {
                 "name": function["name"],
@@ -193,7 +193,6 @@ def convert_tool_spec(self, openai_tools):
                 },
             }
             anthropic_tools.append(anthropic_tool)
-
         return anthropic_tools
 
 
@@ -204,21 +203,78 @@ def __init__(self, **config):
         self.converter = AnthropicMessageConverter()
 
     def chat_completions_create(self, model, messages, **kwargs):
-        """Create a chat completion using the Anthropic API."""
+        """
+        Create a chat completion using the Anthropic API.
+
+        If 'stream=True' is passed, return a generator that yields
+        `ChatCompletionStreamResponse` objects shaped like OpenAI's streaming chunks.
+        """
+        stream = kwargs.pop("stream", False)
+
+        if not stream:
+            # Non-streaming call
+            kwargs = self._prepare_kwargs(kwargs)
+            system_message, converted_messages = self.converter.convert_request(messages)
+            response = self.client.messages.create(
+                model=model,
+                system=system_message,
+                messages=converted_messages,
+                **kwargs
+            )
+            return self.converter.convert_response(response)
+        else:
+            # Streaming call
+            return self._streaming_chat_completions_create(model, messages, **kwargs)
+
+    def _streaming_chat_completions_create(self, model, messages, **kwargs):
+        """
+        Generator that yields chunk objects in the shape:
+            chunk.choices[0].delta.content
+        """
         kwargs = self._prepare_kwargs(kwargs)
         system_message, converted_messages = self.converter.convert_request(messages)
-
-        response = self.client.messages.create(
-            model=model, system=system_message, messages=converted_messages, **kwargs
-        )
-        return self.converter.convert_response(response)
+        first_chunk = True
+
+        with self.client.messages.stream(
+            model=model,
+            system=system_message,
+            messages=converted_messages,
+            **kwargs
+        ) as stream_resp:
+
+            for partial_text in stream_resp.text_stream:
+                # For the first token, include `role='assistant'`.
+                if first_chunk:
+                    chunk = ChatCompletionStreamResponse(choices=[
+                        ChatCompletionStreamResponseChoice(
+                            delta=ChatCompletionStreamResponseDelta(
+                                role="assistant",
+                                content=partial_text
+                            )
+                        )
+                    ])
+                    first_chunk = False
+                else:
+                    chunk = ChatCompletionStreamResponse(choices=[
+                        ChatCompletionStreamResponseChoice(
+                            delta=ChatCompletionStreamResponseDelta(
+                                content=partial_text
+                            )
+                        )
+                    ])
+
+                yield chunk
 
     def _prepare_kwargs(self, kwargs):
-        """Prepare kwargs for the API call."""
+        """Prepare kwargs for the Anthropic API call."""
         kwargs = kwargs.copy()
         kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS)
 
         if "tools" in kwargs:
             kwargs["tools"] = self.converter.convert_tool_spec(kwargs["tools"])
 
         return kwargs
+
+
+
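A minimal sketch of driving the new Anthropic streaming path directly (not part of the commit): it assumes ANTHROPIC_API_KEY is set in the environment, that AnthropicProvider forwards its config to anthropic.Anthropic, and that the model name is merely illustrative.

# Hypothetical caller; assumes ANTHROPIC_API_KEY is set in the environment
# and that the model name below is available on your account.
from aisuite.providers.anthropic_provider import AnthropicProvider

provider = AnthropicProvider()
stream = provider.chat_completions_create(
    model="claude-3-5-sonnet-20240620",  # illustrative model name
    messages=[{"role": "user", "content": "Write a haiku about rivers."}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)
print()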

aisuite/providers/openai_provider.py

Lines changed: 52 additions & 17 deletions
@@ -1,7 +1,6 @@
 import openai
 import os
 from aisuite.provider import Provider, LLMError
-from aisuite.providers.message_converter import OpenAICompliantMessageConverter
 
 
 class OpenaiProvider(Provider):
@@ -14,27 +13,63 @@ def __init__(self, **config):
         config.setdefault("api_key", os.getenv("OPENAI_API_KEY"))
         if not config["api_key"]:
             raise ValueError(
-                "OpenAI API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable."
+                "OpenAI API key is missing. Please provide it in the config "
+                "or set the OPENAI_API_KEY environment variable."
             )
 
-        # NOTE: We could choose to remove above lines for api_key since OpenAI will automatically
-        # infer certain values from the environment variables.
-        # Eg: OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID, OPENAI_BASE_URL, etc.
-
         # Pass the entire config to the OpenAI client constructor
+        # (Note: This assumes openai.OpenAI(...) is valid in your environment.
+        #  If you typically do `openai.api_key = ...`, adapt as needed.)
         self.client = openai.OpenAI(**config)
-        self.transformer = OpenAICompliantMessageConverter()
 
     def chat_completions_create(self, model, messages, **kwargs):
-        # Any exception raised by OpenAI will be returned to the caller.
-        # Maybe we should catch them and raise a custom LLMError.
-        try:
-            transformed_messages = self.transformer.convert_request(messages)
-            response = self.client.chat.completions.create(
+        """
+        Create chat completion using the OpenAI API.
+        If 'stream=True' is passed via kwargs, return a generator that yields
+        chunked responses in the OpenAI streaming format.
+        """
+        stream = kwargs.pop("stream", False)
+
+        if not stream:
+            # Non-streaming call
+            return self.client.chat.completions.create(
                 model=model,
-                messages=transformed_messages,
-                **kwargs,  # Pass any additional arguments to the OpenAI API
+                messages=messages,
+                **kwargs
            )
-            return response
-        except Exception as e:
-            raise LLMError(f"An error occurred: {e}")
+        else:
+            # Streaming call: return a generator that yields each chunk
+            return self._streaming_chat_completions_create(model, messages, **kwargs)
+
+    def _streaming_chat_completions_create(self, model, messages, **kwargs):
+        """
+        Internal helper method that yields chunked responses for streaming.
+        Each chunk is already in the OpenAI streaming format:
+
+            {
+                "id": ...,
+                "object": "chat.completion.chunk",
+                "created": ...,
+                "model": ...,
+                "choices": [
+                    {
+                        "delta": {
+                            "role": "assistant" or "content": ...
+                        }
+                    }
+                ]
+            }
+        """
+        response_gen = self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            stream=True,
+            **kwargs
+        )
+
+        # Yield chunks as they arrive
+        for chunk in response_gen:
+            yield chunk
+
+
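A matching sketch for the OpenAI side (again, not part of the commit): it assumes OPENAI_API_KEY is set and uses an illustrative model name. Because these chunks are native OpenAI objects, the consuming loop is the same as for the Anthropic provider, with one extra guard since a native chunk's choices list can be empty.

# Hypothetical caller; assumes OPENAI_API_KEY is set in the environment.
from aisuite.providers.openai_provider import OpenaiProvider

provider = OpenaiProvider()
stream = provider.chat_completions_create(
    model="gpt-4o-mini",  # illustrative model name
    messages=[{"role": "user", "content": "Name three sorting algorithms."}],
    stream=True,
)
for chunk in stream:
    if not chunk.choices:
        continue  # some native chunks carry no choices
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)
print()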
