1
- # Anthropic provider
2
- # Links:
3
- # Tool calling docs - https://docs.anthropic.com/en/docs/build-with-claude/tool-use
1
+ # aisuite/providers/anthropic_provider.py
4
2
5
3
import anthropic
6
4
import json
5
+
7
6
from aisuite .provider import Provider
8
7
from aisuite .framework import ChatCompletionResponse
9
8
from aisuite .framework .message import Message , ChatCompletionMessageToolCall , Function
10
9
10
+ # Import our new streaming response classes:
11
+ from aisuite .framework .chat_completion_stream_response import (
12
+ ChatCompletionStreamResponse ,
13
+ ChatCompletionStreamResponseChoice ,
14
+ ChatCompletionStreamResponseDelta ,
15
+ )
16
+
11
17
# Define a constant for the default max_tokens value
12
18
DEFAULT_MAX_TOKENS = 4096
13
19
@@ -33,7 +39,7 @@ def convert_request(self, messages):
33
39
return system_message , converted_messages
34
40
35
41
def convert_response (self , response ):
36
- """Normalize the response from the Anthropic API to match OpenAI's response format."""
42
+ """Normalize a non-streaming response from the Anthropic API to match OpenAI's response format."""
37
43
normalized_response = ChatCompletionResponse ()
38
44
normalized_response .choices [0 ].finish_reason = self ._get_finish_reason (response )
39
45
normalized_response .usage = self ._get_usage_stats (response )
@@ -57,7 +63,7 @@ def _convert_dict_message(self, msg):
57
63
return {"role" : msg ["role" ], "content" : msg ["content" ]}
58
64
59
65
def _convert_message_object (self , msg ):
60
- """Convert a Message object to Anthropic format."""
66
+ """Convert a ` Message` object to Anthropic format."""
61
67
if msg .role == self .ROLE_TOOL :
62
68
return self ._create_tool_result_message (msg .tool_call_id , msg .content )
63
69
elif msg .role == self .ROLE_ASSISTANT and msg .tool_calls :
@@ -107,22 +113,23 @@ def _create_assistant_tool_message(self, content, tool_calls):
107
113
return {"role" : self .ROLE_ASSISTANT , "content" : message_content }
108
114
109
115
def _extract_system_message (self , messages ):
110
- """Extract system message if present, otherwise return empty list."""
111
- # TODO: This is a temporary solution to extract the system message.
112
- # User can pass multiple system messages, which can mingled with other messages.
113
- # This needs to be fixed to handle this case.
116
+ """
117
+ Extract system message if present, otherwise return empty string.
118
+ If there are multiple system messages, or the system message is not the first,
119
+ you may need to adapt this approach.
120
+ """
114
121
if messages and messages [0 ]["role" ] == "system" :
115
122
system_message = messages [0 ]["content" ]
116
123
messages .pop (0 )
117
124
return system_message
118
- return []
125
+ return ""
119
126
120
127
def _get_finish_reason (self , response ):
121
128
"""Get the normalized finish reason."""
122
129
return self .FINISH_REASON_MAPPING .get (response .stop_reason , "stop" )
123
130
124
131
def _get_usage_stats (self , response ):
125
- """Get the usage statistics."""
132
+ """Get the usage statistics from Anthropic response ."""
126
133
return {
127
134
"prompt_tokens" : response .usage .input_tokens ,
128
135
"completion_tokens" : response .usage .output_tokens ,
@@ -135,9 +142,8 @@ def _get_message(self, response):
135
142
tool_message = self .convert_response_with_tool_use (response )
136
143
if tool_message :
137
144
return tool_message
138
-
139
145
return Message (
140
- content = response .content [0 ].text ,
146
+ content = response .content [0 ].text if response . content else "" ,
141
147
role = "assistant" ,
142
148
tool_calls = None ,
143
149
refusal = None ,
@@ -146,26 +152,22 @@ def _get_message(self, response):
146
152
def convert_response_with_tool_use (self , response ):
147
153
"""Convert Anthropic tool use response to the framework's format."""
148
154
tool_call = next (
149
- (content for content in response .content if content .type == "tool_use" ),
155
+ (c for c in response .content if c .type == "tool_use" ),
150
156
None ,
151
157
)
152
-
153
158
if tool_call :
154
159
function = Function (
155
- name = tool_call .name , arguments = json .dumps (tool_call .input )
160
+ name = tool_call .name ,
161
+ arguments = json .dumps (tool_call .input ),
156
162
)
157
163
tool_call_obj = ChatCompletionMessageToolCall (
158
- id = tool_call .id , function = function , type = "function"
164
+ id = tool_call .id ,
165
+ function = function ,
166
+ type = "function" ,
159
167
)
160
168
text_content = next (
161
- (
162
- content .text
163
- for content in response .content
164
- if content .type == "text"
165
- ),
166
- "" ,
169
+ (c .text for c in response .content if c .type == "text" ), ""
167
170
)
168
-
169
171
return Message (
170
172
content = text_content or None ,
171
173
tool_calls = [tool_call_obj ] if tool_call else None ,
@@ -177,11 +179,9 @@ def convert_response_with_tool_use(self, response):
177
179
def convert_tool_spec (self , openai_tools ):
178
180
"""Convert OpenAI tool specification to Anthropic format."""
179
181
anthropic_tools = []
180
-
181
182
for tool in openai_tools :
182
183
if tool .get ("type" ) != "function" :
183
184
continue
184
-
185
185
function = tool ["function" ]
186
186
anthropic_tool = {
187
187
"name" : function ["name" ],
@@ -193,7 +193,6 @@ def convert_tool_spec(self, openai_tools):
193
193
},
194
194
}
195
195
anthropic_tools .append (anthropic_tool )
196
-
197
196
return anthropic_tools
198
197
199
198
@@ -204,21 +203,78 @@ def __init__(self, **config):
204
203
self .converter = AnthropicMessageConverter ()
205
204
206
205
def chat_completions_create(self, model, messages, **kwargs):
    """
    Create a chat completion via the Anthropic API.

    When ``stream=True`` is supplied, return a generator that yields
    ``ChatCompletionStreamResponse`` chunks shaped like OpenAI's streaming
    protocol; otherwise return a normalized ``ChatCompletionResponse``.
    """
    # Guard clause: hand streaming requests straight to the generator helper.
    if kwargs.pop("stream", False):
        return self._streaming_chat_completions_create(model, messages, **kwargs)

    request_kwargs = self._prepare_kwargs(kwargs)
    system_message, converted_messages = self.converter.convert_request(messages)
    raw_response = self.client.messages.create(
        model=model,
        system=system_message,
        messages=converted_messages,
        **request_kwargs,
    )
    return self.converter.convert_response(raw_response)
228
+
229
def _streaming_chat_completions_create(self, model, messages, **kwargs):
    """
    Stream a chat completion from the Anthropic API.

    Yields OpenAI-style chunk objects where the generated text lives at
    ``chunk.choices[0].delta.content``.  Only the very first chunk carries
    ``role='assistant'`` in its delta, mirroring OpenAI's streaming protocol.
    """
    request_kwargs = self._prepare_kwargs(kwargs)
    system_message, converted_messages = self.converter.convert_request(messages)

    with self.client.messages.stream(
        model=model,
        system=system_message,
        messages=converted_messages,
        **request_kwargs,
    ) as stream_resp:
        role_sent = False
        for text_piece in stream_resp.text_stream:
            # The role field is announced once, on the opening delta.
            delta_fields = {"content": text_piece}
            if not role_sent:
                delta_fields["role"] = "assistant"
                role_sent = True
            yield ChatCompletionStreamResponse(
                choices=[
                    ChatCompletionStreamResponseChoice(
                        delta=ChatCompletionStreamResponseDelta(**delta_fields)
                    )
                ]
            )
215
267
216
268
def _prepare_kwargs(self, kwargs):
    """
    Return a copy of *kwargs* adjusted for the Anthropic API call.

    Ensures ``max_tokens`` is present (defaulting to ``DEFAULT_MAX_TOKENS``)
    and converts any OpenAI-style ``tools`` spec to Anthropic's format.
    The caller's dict is never mutated.
    """
    prepared = dict(kwargs)
    prepared.setdefault("max_tokens", DEFAULT_MAX_TOKENS)
    if "tools" in prepared:
        prepared["tools"] = self.converter.convert_tool_spec(prepared["tools"])
    return prepared
277
+
278
+
279
+
280
+
0 commit comments