Skip to content

Commit cd3b374

Browse files
committed
Extend ChatClient API
- add support for custom StructuredOutputConverters usng entity(outputConverterInstance) - add mutate() method that returns ChatClient.Builder to create a new ChatClient whose settings are replicated from the ChatClient's default settings. - add prompt().mutate() method that returns ChatClient.Builder to create a new ChatClient whose settings are replicated from the current default and prompt settings.
1 parent c49a5c8 commit cd3b374

File tree

4 files changed

+317
-18
lines changed

4 files changed

+317
-18
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java renamed to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,11 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
package org.springframework.ai.openai.chat;
16+
package org.springframework.ai.openai.chat.client;
1717

1818
import java.io.IOException;
1919
import java.net.URL;
2020
import java.util.Arrays;
21-
import java.util.Collection;
2221
import java.util.List;
2322
import java.util.Map;
2423
import java.util.stream.Collectors;
@@ -34,6 +33,7 @@
3433
import org.springframework.ai.chat.client.ChatClient;
3534
import org.springframework.ai.chat.model.ChatResponse;
3635
import org.springframework.ai.converter.BeanOutputConverter;
36+
import org.springframework.ai.converter.ListOutputConverter;
3737
import org.springframework.ai.openai.OpenAiChatOptions;
3838
import org.springframework.ai.openai.OpenAiTestConfiguration;
3939
import org.springframework.ai.openai.api.OpenAiApi;
@@ -42,6 +42,7 @@
4242
import org.springframework.beans.factory.annotation.Value;
4343
import org.springframework.boot.test.context.SpringBootTest;
4444
import org.springframework.core.ParameterizedTypeReference;
45+
import org.springframework.core.convert.support.DefaultConversionService;
4546
import org.springframework.core.io.ClassPathResource;
4647
import org.springframework.core.io.Resource;
4748
import org.springframework.util.MimeTypeUtils;
@@ -79,20 +80,21 @@ void roleTest() {
7980
}
8081

8182
@Test
82-
void listOutputConverter() {
83+
void listOutputConverterString() {
8384
// @formatter:off
84-
Collection<String> collection = ChatClient.builder(chatModel).build().prompt()
85+
List<String> collection = ChatClient.builder(chatModel).build().prompt()
8586
.user(u -> u.text("List five {subject}")
8687
.param("subject", "ice cream flavors"))
8788
.call()
8889
.entity(new ParameterizedTypeReference<List<String>>() {});
8990
// @formatter:on
9091

92+
logger.info(collection.toString());
9193
assertThat(collection).hasSize(5);
9294
}
9395

9496
@Test
95-
void listOutputConverter2() {
97+
void listOutputConverterBean() {
9698

9799
// @formatter:off
98100
List<ActorsFilms> actorsFilms = ChatClient.builder(chatModel).build().prompt()
@@ -106,6 +108,24 @@ void listOutputConverter2() {
106108
assertThat(actorsFilms).hasSize(2);
107109
}
108110

111+
@Test
112+
void customOutputConverter() {
113+
114+
var toStringListConverter = new ListOutputConverter(new DefaultConversionService());
115+
116+
// @formatter:off
117+
List<String> flavors = ChatClient.builder(chatModel).build().prompt()
118+
.user(u -> u.text("List five {subject}")
119+
.param("subject", "ice cream flavors"))
120+
.call()
121+
.entity(toStringListConverter);
122+
// @formatter:on
123+
124+
logger.info("ice cream flavors" + flavors);
125+
assertThat(flavors).hasSize(5);
126+
assertThat(flavors).contains("Vanilla");
127+
}
128+
109129
@Test
110130
void mapOutputConverter() {
111131
// @formatter:off
@@ -196,6 +216,24 @@ void functionCallTest() {
196216
assertThat(response).containsAnyOf("15.0", "15");
197217
}
198218

219+
@Test
220+
void defaultFunctionCallTest() {
221+
222+
// @formatter:off
223+
String response = ChatClient.builder(chatModel)
224+
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
225+
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
226+
.build()
227+
.prompt().call().content();
228+
// @formatter:on
229+
230+
logger.info("Response: {}", response);
231+
232+
assertThat(response).containsAnyOf("30.0", "30");
233+
assertThat(response).containsAnyOf("10.0", "10");
234+
assertThat(response).containsAnyOf("15.0", "15");
235+
}
236+
199237
@Test
200238
void streamFunctionCallTest() {
201239

spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.springframework.ai.chat.prompt.Prompt;
4040
import org.springframework.ai.chat.prompt.PromptTemplate;
4141
import org.springframework.ai.converter.BeanOutputConverter;
42+
import org.springframework.ai.converter.StructuredOutputConverter;
4243
import org.springframework.ai.model.function.FunctionCallback;
4344
import org.springframework.ai.model.function.FunctionCallbackWrapper;
4445
import org.springframework.ai.model.function.FunctionCallingOptions;
@@ -74,6 +75,12 @@ static Builder builder(ChatModel chatModel) {
7475

7576
ChatClientPromptRequest prompt(Prompt prompt);
7677

78+
/**
79+
* Return a {@link ChatClient.Builder} to create a new {@link ChatClient} whose
80+
* settings are replicated from the default {@link ChatClientRequest} of this client.
81+
*/
82+
Builder mutate();
83+
7784
interface PromptSpec<T> {
7885

7986
T text(String text);
@@ -223,6 +230,26 @@ class ChatClientRequest {
223230

224231
private final Map<String, Object> systemParams = new HashMap<>();
225232

233+
/**
234+
* Return a {@code ChatClient.Builder} to create a new {@code ChatClient} whose
235+
* settings are replicated from this {@code ChatClientRequest}.
236+
*/
237+
public Builder mutate() {
238+
Builder builder = ChatClient.builder(chatModel)
239+
.defaultSystem(s -> s.text(this.systemText).params(this.systemParams))
240+
.defaultUser(u -> u.text(this.userText)
241+
.params(this.userParams)
242+
.media(this.media.toArray(new Media[this.media.size()])))
243+
.defaultOptions(this.chatOptions)
244+
.defaultFunctions(StringUtils.toStringArray(this.functionNames));
245+
246+
// workaround to set the missing fields.
247+
builder.defaultRequest.messages.addAll(this.messages);
248+
builder.defaultRequest.functionCallbacks.addAll(this.functionCallbacks);
249+
250+
return builder;
251+
}
252+
226253
/* copy constructor */
227254
ChatClientRequest(ChatClientRequest ccr) {
228255
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks,
@@ -411,7 +438,11 @@ public <T> T entity(ParameterizedTypeReference<T> type) {
411438
return doSingleWithBeanOutputConverter(new BeanOutputConverter<T>(type));
412439
}
413440

414-
private <T> T doSingleWithBeanOutputConverter(BeanOutputConverter<T> boc) {
441+
public <T> T entity(StructuredOutputConverter<T> structuredOutputConverter) {
442+
return doSingleWithBeanOutputConverter(structuredOutputConverter);
443+
}
444+
445+
private <T> T doSingleWithBeanOutputConverter(StructuredOutputConverter<T> boc) {
415446
var processedUserText = this.request.userText + System.lineSeparator() + System.lineSeparator()
416447
+ "{format}";
417448
var chatResponse = doGetChatResponse(processedUserText, boc.getFormat());
@@ -435,13 +466,15 @@ private ChatResponse doGetChatResponse(String processedUserText, String formatPa
435466
userParams.put("format", formatParam);
436467
}
437468

438-
var messages = new ArrayList<Message>();
469+
var messages = new ArrayList<Message>(this.request.messages);
439470
var textsAreValid = (StringUtils.hasText(processedUserText)
440471
|| StringUtils.hasText(this.request.systemText));
441-
var messagesAreValid = !this.request.messages.isEmpty();
442-
Assert.state(!(messagesAreValid && textsAreValid), "you must specify either " + Message.class.getName()
443-
+ " instances or user/system texts, but not both");
444472
if (textsAreValid) {
473+
if (StringUtils.hasText(this.request.systemText) || !this.request.systemParams.isEmpty()) {
474+
var systemMessage = new SystemMessage(
475+
new PromptTemplate(this.request.systemText, this.request.systemParams).render());
476+
messages.add(systemMessage);
477+
}
445478
UserMessage userMessage = null;
446479
if (!CollectionUtils.isEmpty(userParams)) {
447480
userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(),
@@ -450,16 +483,8 @@ private ChatResponse doGetChatResponse(String processedUserText, String formatPa
450483
else {
451484
userMessage = new UserMessage(processedUserText, this.request.media);
452485
}
453-
if (StringUtils.hasText(this.request.systemText) || !this.request.systemParams.isEmpty()) {
454-
var systemMessage = new SystemMessage(
455-
new PromptTemplate(this.request.systemText, this.request.systemParams).render());
456-
messages.add(systemMessage);
457-
}
458486
messages.add(userMessage);
459487
}
460-
else {
461-
messages.addAll(this.request.messages);
462-
}
463488
if (this.request.chatOptions instanceof FunctionCallingOptions functionCallingOptions) {
464489
// if (this.request.chatOptions instanceof
465490
// FunctionCallingOptionsBuilder.PortableFunctionCallingOptions

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@ public ChatClientPromptRequest prompt(Prompt prompt) {
3535
return new ChatClientPromptRequest(this.chatModel, prompt);
3636
}
3737

38+
/**
39+
* Return a {@code ChatClient.Builder} to create a new {@code ChatClient} whose
40+
* settings are replicated from this {@code ChatClientRequest}.
41+
*/
42+
@Override
43+
public Builder mutate() {
44+
return this.defaultChatClientRequest.mutate();
45+
}
46+
3847
/**
3948
* use the new fluid DSL starting in {@link #prompt()}
4049
* @param prompt the {@link Prompt prompt} object

0 commit comments

Comments
 (0)