Skip to content

Commit 32bfba4

Browse files
committed
Add Dashscope AI
1 parent fa53e2a commit 32bfba4

37 files changed

+2728
-0
lines changed

models/spring-ai-dashscope/pom.xml

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5+
<modelVersion>4.0.0</modelVersion>
6+
<parent>
7+
<groupId>org.springframework.ai</groupId>
8+
<artifactId>spring-ai</artifactId>
9+
<version>1.0.0-SNAPSHOT</version>
10+
<relativePath>../../pom.xml</relativePath>
11+
</parent>
12+
<artifactId>spring-ai-dashscope</artifactId>
13+
<packaging>jar</packaging>
14+
<name>Spring AI Dashscope</name>
15+
<description>Dashscope support</description>
16+
17+
<dependencies>
18+
19+
<!-- production dependencies -->
20+
<dependency>
21+
<groupId>org.springframework.ai</groupId>
22+
<artifactId>spring-ai-core</artifactId>
23+
<version>${project.parent.version}</version>
24+
</dependency>
25+
26+
<dependency>
27+
<groupId>org.springframework.ai</groupId>
28+
<artifactId>spring-ai-retry</artifactId>
29+
<version>${project.parent.version}</version>
30+
</dependency>
31+
32+
<!-- NOTE: Required only by the @ConstructorBinding. -->
33+
<dependency>
34+
<groupId>org.springframework.boot</groupId>
35+
<artifactId>spring-boot</artifactId>
36+
</dependency>
37+
38+
<dependency>
39+
<groupId>io.rest-assured</groupId>
40+
<artifactId>json-path</artifactId>
41+
</dependency>
42+
43+
<dependency>
44+
<groupId>com.github.victools</groupId>
45+
<artifactId>jsonschema-generator</artifactId>
46+
<version>${victools.version}</version>
47+
</dependency>
48+
49+
<dependency>
50+
<groupId>com.github.victools</groupId>
51+
<artifactId>jsonschema-module-jackson</artifactId>
52+
<version>${victools.version}</version>
53+
</dependency>
54+
55+
<dependency>
56+
<groupId>org.springframework</groupId>
57+
<artifactId>spring-context-support</artifactId>
58+
</dependency>
59+
<dependency>
60+
<groupId>org.springframework.boot</groupId>
61+
<artifactId>spring-boot-starter-logging</artifactId>
62+
</dependency>
63+
64+
<!-- test dependencies -->
65+
<dependency>
66+
<groupId>org.springframework.ai</groupId>
67+
<artifactId>spring-ai-test</artifactId>
68+
<version>${project.version}</version>
69+
<scope>test</scope>
70+
</dependency>
71+
</dependencies>
72+
</project>
Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
package org.springframework.ai.dashscope;
2+
3+
import org.slf4j.Logger;
4+
import org.slf4j.LoggerFactory;
5+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
6+
import org.springframework.ai.chat.model.ChatModel;
7+
import org.springframework.ai.chat.model.ChatResponse;
8+
import org.springframework.ai.chat.model.Generation;
9+
import org.springframework.ai.chat.model.StreamingChatModel;
10+
import org.springframework.ai.chat.prompt.ChatOptions;
11+
import org.springframework.ai.chat.prompt.Prompt;
12+
import org.springframework.ai.dashscope.api.DashscopeApi;
13+
import org.springframework.ai.dashscope.metadata.DashscopeChatResponseMetadata;
14+
import org.springframework.ai.dashscope.record.chat.ChatCompletion;
15+
import org.springframework.ai.dashscope.record.chat.ChatCompletionChoice;
16+
import org.springframework.ai.dashscope.record.chat.ChatCompletionRequestInput;
17+
import org.springframework.ai.dashscope.record.chat.ChatCompletionMessage;
18+
import org.springframework.ai.dashscope.record.chat.ChatCompletionRequest;
19+
import org.springframework.ai.dashscope.record.chat.ChatCompletionRequestParameters;
20+
import org.springframework.ai.dashscope.record.chat.ToolCall;
21+
import org.springframework.ai.model.ModelOptionsUtils;
22+
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
23+
import org.springframework.ai.model.function.FunctionCallbackContext;
24+
import org.springframework.ai.retry.RetryUtils;
25+
import org.springframework.http.ResponseEntity;
26+
import org.springframework.retry.support.RetryTemplate;
27+
import org.springframework.util.Assert;
28+
import org.springframework.util.CollectionUtils;
29+
import reactor.core.publisher.Flux;
30+
31+
import java.util.HashMap;
32+
import java.util.HashSet;
33+
import java.util.List;
34+
import java.util.Map;
35+
import java.util.Set;
36+
import java.util.concurrent.ConcurrentHashMap;
37+
import java.util.stream.Collectors;
38+
39+
/**
40+
* @author Nottyjay Ji
41+
*/
42+
public class DashscopeChatModel extends
43+
AbstractFunctionCallSupport<ChatCompletionMessage, ChatCompletionRequest, ResponseEntity<ChatCompletion>>
44+
implements ChatModel, StreamingChatModel {
45+
46+
private static final Logger logger = LoggerFactory.getLogger(DashscopeChatModel.class);
47+
48+
/** Low-level access to the Dashscope API */
49+
private final DashscopeApi dashscopeApi;
50+
51+
/** The retry template used to retry the OpenAI API calls. */
52+
public final RetryTemplate retryTemplate;
53+
54+
/** The default options used for the chat completion requests. */
55+
private DashscopeChatOptions defaultOptions;
56+
57+
public DashscopeChatModel(DashscopeApi dashscopeApi) {
58+
this(dashscopeApi,
59+
DashscopeChatOptions.builder()
60+
.withModel(DashscopeApi.DEFAULT_CHAT_MODEL)
61+
.withTemperature(0.7f)
62+
.build());
63+
}
64+
65+
public DashscopeChatModel(DashscopeApi dashscopeApi, DashscopeChatOptions options) {
66+
this(dashscopeApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
67+
}
68+
69+
public DashscopeChatModel(DashscopeApi dashscopeApi, DashscopeChatOptions options,
70+
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
71+
super(functionCallbackContext);
72+
Assert.notNull(dashscopeApi, "DashscopeApi must not be null");
73+
Assert.notNull(options, "Options must not be null");
74+
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
75+
76+
this.dashscopeApi = dashscopeApi;
77+
this.defaultOptions = options;
78+
this.retryTemplate = retryTemplate;
79+
}
80+
81+
@Override
82+
public ChatResponse call(Prompt prompt) {
83+
ChatCompletionRequest request = createRequest(prompt, false);
84+
return this.retryTemplate.execute(ctx -> {
85+
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
86+
87+
var chatCompletion = completionEntity.getBody();
88+
if (chatCompletion == null) {
89+
logger.warn("No chat completion returned for prompt: {}", prompt);
90+
return new ChatResponse(List.of());
91+
}
92+
93+
List<Generation> generations = chatCompletion.output().choices().stream().map(choice -> {
94+
return new Generation(choice.message().content(), toMap(chatCompletion.requestId(), choice))
95+
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason(), null));
96+
}).toList();
97+
98+
return new ChatResponse(generations,
99+
DashscopeChatResponseMetadata.from(chatCompletion.usage(), chatCompletion.requestId()));
100+
});
101+
}
102+
103+
@Override
104+
public ChatOptions getDefaultOptions() {
105+
return DashscopeChatOptions.fromOptions(this.defaultOptions);
106+
}
107+
108+
private Map<String, Object> toMap(String id, ChatCompletionChoice choice) {
109+
Map<String, Object> map = new HashMap<>();
110+
111+
var message = choice.message();
112+
if (message.role() != null) {
113+
map.put("role", message.role().name());
114+
}
115+
if (choice.finishReason() != null) {
116+
map.put("finishReason", choice.finishReason());
117+
}
118+
map.put("id", id);
119+
return map;
120+
}
121+
122+
@Override
123+
public Flux<ChatResponse> stream(Prompt prompt) {
124+
ChatCompletionRequest request = createRequest(prompt, true);
125+
return this.retryTemplate.execute(ctx -> {
126+
Flux<ChatCompletion> chatCompletionFlux = this.dashscopeApi.chatCompletionStream(request);
127+
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
128+
return chatCompletionFlux.map(chatCompletion -> {
129+
String id = chatCompletion.requestId();
130+
List<Generation> generations = chatCompletion.output().choices().stream().map(choice -> {
131+
if (choice.message().role() != null) {
132+
roleMap.putIfAbsent(id, choice.message().role().name());
133+
}
134+
String finish = (choice.finishReason() != null ? choice.finishReason() : "");
135+
var generation = new Generation(choice.message().content(),
136+
Map.of("requestId", id, "role", roleMap.get(id), "finishReason", finish));
137+
if (choice.finishReason() != null) {
138+
generation = generation
139+
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason(), null));
140+
}
141+
return generation;
142+
}).toList();
143+
144+
return new ChatResponse(generations,
145+
DashscopeChatResponseMetadata.from(chatCompletion.usage(), chatCompletion.requestId()));
146+
});
147+
});
148+
}
149+
150+
@Override
151+
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
152+
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {
153+
// Every tool-call item requires a separate function call and a response (TOOL)
154+
// message.
155+
for (ToolCall toolCall : responseMessage.toolCalls()) {
156+
157+
var functionName = toolCall.function().name();
158+
String functionArguments = toolCall.function().arguments();
159+
160+
if (!this.functionCallbackRegister.containsKey(functionName)) {
161+
throw new IllegalStateException("No function callback found for function name: " + functionName);
162+
}
163+
164+
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
165+
166+
// Add the function response to the conversation.
167+
conversationHistory
168+
.add(new ChatCompletionMessage(ChatCompletionMessage.Role.TOOL, functionResponse, functionName, null));
169+
}
170+
171+
// Recursively call chatCompletionWithTools until the model doesn't call a
172+
// functions anymore.
173+
ChatCompletionRequest newRequest = new ChatCompletionRequest(
174+
new ChatCompletionRequestInput(conversationHistory), false);
175+
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);
176+
177+
return newRequest;
178+
}
179+
180+
@Override
181+
protected List<ChatCompletionMessage> doGetUserMessages(ChatCompletionRequest request) {
182+
return request.chatCompletionInput().messages();
183+
}
184+
185+
@Override
186+
protected ChatCompletionMessage doGetToolResponseMessage(ResponseEntity<ChatCompletion> chatCompletion) {
187+
return chatCompletion.getBody().output().choices().iterator().next().message();
188+
}
189+
190+
@Override
191+
protected ResponseEntity<ChatCompletion> doChatCompletion(ChatCompletionRequest request) {
192+
return this.dashscopeApi.chatCompletionEntity(request);
193+
}
194+
195+
@Override
196+
protected Flux<ResponseEntity<ChatCompletion>> doChatCompletionStream(ChatCompletionRequest request) {
197+
return null;
198+
}
199+
200+
@Override
201+
protected boolean isToolFunctionCall(ResponseEntity<ChatCompletion> chatCompletion) {
202+
var body = chatCompletion.getBody();
203+
if (body == null) {
204+
return false;
205+
}
206+
207+
var choices = body.output().choices();
208+
if (CollectionUtils.isEmpty(choices)) {
209+
return false;
210+
}
211+
212+
var choice = choices.get(0);
213+
return !CollectionUtils.isEmpty(choice.message().toolCalls());
214+
}
215+
216+
private ChatCompletionRequest createRequest(Prompt prompt, boolean isStream) {
217+
Set<String> functionsForThisRequest = new HashSet<>();
218+
String model = this.defaultOptions.getModel();
219+
220+
// 构造请求中的messages
221+
List<ChatCompletionMessage> chatCompletionInputsMessages = prompt.getInstructions().stream().map(m -> {
222+
return new ChatCompletionMessage(ChatCompletionMessage.Role.valueOf(m.getMessageType().name()),
223+
m.getContent());
224+
}).collect(Collectors.toList());
225+
226+
// 构造请求中的parameters
227+
ChatCompletionRequestParameters chatCompletionRequestParameters = new ChatCompletionRequestParameters(null,
228+
null, null, isStream, null, null, null, null, null, null, null);
229+
if (prompt.getOptions() != null) {
230+
if (prompt.getOptions() instanceof DashscopeChatOptions) {
231+
model = ((DashscopeChatOptions) prompt.getOptions()).getModel();
232+
}
233+
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
234+
235+
DashscopeChatOptions updateRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
236+
ChatOptions.class, DashscopeChatOptions.class);
237+
238+
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updateRuntimeOptions,
239+
IS_RUNTIME_CALL);
240+
functionsForThisRequest.addAll(promptEnabledFunctions);
241+
242+
chatCompletionRequestParameters = ModelOptionsUtils.merge(chatCompletionRequestParameters,
243+
updateRuntimeOptions, ChatCompletionRequestParameters.class);
244+
}
245+
}
246+
247+
if (this.defaultOptions != null) {
248+
Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions,
249+
!IS_RUNTIME_CALL);
250+
251+
functionsForThisRequest.addAll(defaultEnabledFunctions);
252+
253+
chatCompletionRequestParameters = ModelOptionsUtils.merge(chatCompletionRequestParameters,
254+
this.defaultOptions, ChatCompletionRequestParameters.class);
255+
}
256+
257+
// Add the enabled functions definitions to the request's tools parameter.
258+
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
259+
260+
chatCompletionRequestParameters = ModelOptionsUtils.merge(
261+
DashscopeChatOptions.builder().withTools(this.getFunctionTools(functionsForThisRequest)).build(),
262+
chatCompletionRequestParameters, ChatCompletionRequestParameters.class);
263+
}
264+
265+
return new ChatCompletionRequest(model, isStream, new ChatCompletionRequestInput(chatCompletionInputsMessages),
266+
chatCompletionRequestParameters);
267+
}
268+
269+
private List<DashscopeApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
270+
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
271+
var function = new DashscopeApi.FunctionTool.Function(functionCallback.getDescription(),
272+
functionCallback.getName(), functionCallback.getInputTypeSchema());
273+
return new DashscopeApi.FunctionTool(function);
274+
}).toList();
275+
}
276+
277+
}

0 commit comments

Comments
 (0)