Skip to content

Commit 8257c0d

Browse files
committed
Improve ChatClient impl. parameter cohesion
1 parent cd3b374 commit 8257c0d

File tree

2 files changed

+14
-15
lines changed

2 files changed

+14
-15
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ void customOutputConverter() {
123123

124124
logger.info("ice cream flavors" + flavors);
125125
assertThat(flavors).hasSize(5);
126-
assertThat(flavors).contains("Vanilla");
126+
assertThat(flavors).containsAnyOf("Vanilla", "vanilla");
127127
}
128128

129129
@Test

spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@
2626
import java.util.Map;
2727
import java.util.function.Consumer;
2828

29-
import org.springframework.ai.chat.model.ChatModel;
30-
import org.springframework.ai.chat.model.ChatResponse;
31-
import org.springframework.ai.chat.model.StreamingChatModel;
3229
import reactor.core.publisher.Flux;
3330

3431
import org.springframework.ai.chat.messages.Media;
3532
import org.springframework.ai.chat.messages.Message;
3633
import org.springframework.ai.chat.messages.SystemMessage;
3734
import org.springframework.ai.chat.messages.UserMessage;
35+
import org.springframework.ai.chat.model.ChatModel;
36+
import org.springframework.ai.chat.model.ChatResponse;
37+
import org.springframework.ai.chat.model.StreamingChatModel;
3838
import org.springframework.ai.chat.prompt.ChatOptions;
3939
import org.springframework.ai.chat.prompt.Prompt;
4040
import org.springframework.ai.chat.prompt.PromptTemplate;
@@ -443,9 +443,7 @@ public <T> T entity(StructuredOutputConverter<T> structuredOutputConverter) {
443443
}
444444

445445
private <T> T doSingleWithBeanOutputConverter(StructuredOutputConverter<T> boc) {
446-
var processedUserText = this.request.userText + System.lineSeparator() + System.lineSeparator()
447-
+ "{format}";
448-
var chatResponse = doGetChatResponse(processedUserText, boc.getFormat());
446+
var chatResponse = doGetChatResponse(boc.getFormat());
449447
var stringResponse = chatResponse.getResult().getOutput().getContent();
450448
return boc.convert(stringResponse);
451449
}
@@ -456,11 +454,15 @@ public <T> T entity(Class<T> type) {
456454
return doSingleWithBeanOutputConverter(boc);
457455
}
458456

459-
private ChatResponse doGetChatResponse(String processedUserText) {
460-
return this.doGetChatResponse(processedUserText, "");
457+
private ChatResponse doGetChatResponse() {
458+
return this.doGetChatResponse("");
461459
}
462460

463-
private ChatResponse doGetChatResponse(String processedUserText, String formatParam) {
461+
private ChatResponse doGetChatResponse(String formatParam) {
462+
463+
var processedUserText = StringUtils.hasText(formatParam)
464+
? this.request.userText + System.lineSeparator() + "{format}" : this.request.userText;
465+
464466
Map<String, Object> userParams = new HashMap<>(this.request.userParams);
465467
if (StringUtils.hasText(formatParam)) {
466468
userParams.put("format", formatParam);
@@ -486,9 +488,6 @@ private ChatResponse doGetChatResponse(String processedUserText, String formatPa
486488
messages.add(userMessage);
487489
}
488490
if (this.request.chatOptions instanceof FunctionCallingOptions functionCallingOptions) {
489-
// if (this.request.chatOptions instanceof
490-
// FunctionCallingOptionsBuilder.PortableFunctionCallingOptions
491-
// functionCallingOptions) {
492491
if (!this.request.functionNames.isEmpty()) {
493492
functionCallingOptions.setFunctions(new HashSet<>(this.request.functionNames));
494493
}
@@ -501,11 +500,11 @@ private ChatResponse doGetChatResponse(String processedUserText, String formatPa
501500
}
502501

503502
public ChatResponse chatResponse() {
504-
return doGetChatResponse(this.request.userText);
503+
return doGetChatResponse();
505504
}
506505

507506
public String content() {
508-
return doGetChatResponse(this.request.userText).getResult().getOutput().getContent();
507+
return doGetChatResponse().getResult().getOutput().getContent();
509508
}
510509

511510
}

0 commit comments

Comments
 (0)