Skip to content

Commit e75d82e

Browse files
committed
feat: Add MessageType.DEVELOPER and support in message classes, models, and tests
- Included DEVELOPER in the MessageType enum with value "developer" - Updated AbstractMessage constructor to accept DEVELOPER messages without validation errors - Created DeveloperMessage class with constructors, copy(), mutate(), and builder methods - Added validation to ensure textContent is not null or empty when creating DeveloperMessage - Modified message creation methods (e.g., Neo4jChatMemoryRepositoryIT) to handle MessageType.DEVELOPER - Updated ChatModel and ChatClient methods to support DEVELOPER message type - Developed unit tests for DeveloperMessage, validating exception cases and proper inclusion in prompts - Adjusted existing SystemMessage validation tests to include DEVELOPER for consistency Signed-off-by: Andres da Silva Santos <[email protected]>
1 parent 62cc6f1 commit e75d82e

File tree

20 files changed

+1327
-31
lines changed

20 files changed

+1327
-31
lines changed

memory/spring-ai-model-chat-memory-neo4j/src/test/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepositoryIT.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,8 @@
2525
import org.neo4j.driver.Result;
2626
import org.neo4j.driver.Session;
2727
import org.springframework.ai.chat.memory.ChatMemoryRepository;
28-
import org.springframework.ai.chat.messages.AssistantMessage;
29-
import org.springframework.ai.chat.messages.Message;
30-
import org.springframework.ai.chat.messages.MessageType;
31-
import org.springframework.ai.chat.messages.SystemMessage;
32-
import org.springframework.ai.chat.messages.ToolResponseMessage;
28+
import org.springframework.ai.chat.messages.*;
3329
import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse;
34-
import org.springframework.ai.chat.messages.UserMessage;
3530
import org.springframework.ai.chat.messages.SystemMessage;
3631
import org.springframework.ai.content.Media;
3732
import org.springframework.util.MimeType;
@@ -402,6 +397,7 @@ private Message createMessageByType(String content, MessageType messageType) {
402397
case ASSISTANT -> new AssistantMessage(content);
403398
case USER -> new UserMessage(content);
404399
case SYSTEM -> new SystemMessage(content);
400+
case DEVELOPER -> new DeveloperMessage(content);
405401
case TOOL -> new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData")));
406402
};
407403
}

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import com.azure.ai.openai.models.ChatRequestAssistantMessage;
5050
import com.azure.ai.openai.models.ChatRequestMessage;
5151
import com.azure.ai.openai.models.ChatRequestSystemMessage;
52+
import com.azure.ai.openai.models.ChatRequestDeveloperMessage;
5253
import com.azure.ai.openai.models.ChatRequestToolMessage;
5354
import com.azure.ai.openai.models.ChatRequestUserMessage;
5455
import com.azure.ai.openai.models.CompletionsFinishReason;
@@ -575,6 +576,8 @@ private List<ChatRequestMessage> fromSpringAiMessage(Message message) {
575576
return List.of(new ChatRequestUserMessage(items));
576577
case SYSTEM:
577578
return List.of(new ChatRequestSystemMessage(message.getText()));
579+
case DEVELOPER:
580+
return List.of(new ChatRequestDeveloperMessage(message.getText()));
578581
case ASSISTANT:
579582
AssistantMessage assistantMessage = (AssistantMessage) message;
580583
List<ChatCompletionsToolCall> toolCalls = null;

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
* @author Ilayaperumal Gopinathan
105105
* @author Alexandros Pappas
106106
* @author Soby Chacko
107+
* @author Andres da Silva Santos
107108
* @see ChatModel
108109
* @see StreamingChatModel
109110
* @see OpenAiApi
@@ -552,7 +553,8 @@ private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHead
552553
ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
553554

554555
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> {
555-
if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) {
556+
if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM
557+
|| message.getMessageType() == MessageType.DEVELOPER) {
556558
Object content = message.getText();
557559
if (message instanceof UserMessage userMessage) {
558560
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
* @author Josh Long
5454
* @author Arjen Poutsma
5555
* @author Thomas Vitale
56+
* @author Andres da Silva Santos
5657
* @since 1.0.0
5758
*/
5859
public interface ChatClient {
@@ -133,6 +134,23 @@ interface PromptSystemSpec {
133134

134135
}
135136

137+
/**
138+
* Specification for a prompt developer.
139+
*/
140+
interface PromptDeveloperSpec {
141+
142+
PromptDeveloperSpec text(String text);
143+
144+
PromptDeveloperSpec text(Resource text, Charset charset);
145+
146+
PromptDeveloperSpec text(Resource text);
147+
148+
PromptDeveloperSpec params(Map<String, Object> p);
149+
150+
PromptDeveloperSpec param(String k, Object v);
151+
152+
}
153+
136154
interface AdvisorSpec {
137155

138156
AdvisorSpec param(String k, Object v);
@@ -232,6 +250,14 @@ interface ChatClientRequestSpec {
232250

233251
ChatClientRequestSpec toolContext(Map<String, Object> toolContext);
234252

253+
ChatClientRequestSpec developer(String text);
254+
255+
ChatClientRequestSpec developer(Resource textResource, Charset charset);
256+
257+
ChatClientRequestSpec developer(Resource text);
258+
259+
ChatClientRequestSpec developer(Consumer<PromptDeveloperSpec> consumer);
260+
235261
ChatClientRequestSpec system(String text);
236262

237263
ChatClientRequestSpec system(Resource textResource, Charset charset);
@@ -277,6 +303,14 @@ interface Builder {
277303

278304
Builder defaultUser(Consumer<PromptUserSpec> userSpecConsumer);
279305

306+
Builder defaultDeveloper(String text);
307+
308+
Builder defaultDeveloper(Resource text, Charset charset);
309+
310+
Builder defaultDeveloper(Resource text);
311+
312+
Builder defaultDeveloper(Consumer<PromptDeveloperSpec> developerSpecConsumer);
313+
280314
Builder defaultSystem(String text);
281315

282316
Builder defaultSystem(Resource text, Charset charset);

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 126 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
* @author Soby Chacko
7676
* @author Dariusz Jedrzejczyk
7777
* @author Thomas Vitale
78+
* @author Andres da Silva Santos
7879
* @since 1.0.0
7980
*/
8081
public class DefaultChatClient implements ChatClient {
@@ -288,6 +289,68 @@ protected Map<String, Object> params() {
288289

289290
}
290291

292+
public static class DefaultPromptDeveloperSpec implements PromptDeveloperSpec {
293+
294+
private final Map<String, Object> params = new HashMap<>();
295+
296+
@Nullable
297+
private String text;
298+
299+
@Override
300+
public PromptDeveloperSpec text(String text) {
301+
Assert.hasText(text, "text cannot be null or empty");
302+
this.text = text;
303+
return this;
304+
}
305+
306+
@Override
307+
public PromptDeveloperSpec text(Resource text, Charset charset) {
308+
Assert.notNull(text, "text cannot be null");
309+
Assert.notNull(charset, "charset cannot be null");
310+
try {
311+
this.text(text.getContentAsString(charset));
312+
}
313+
catch (IOException e) {
314+
throw new RuntimeException(e);
315+
}
316+
return this;
317+
}
318+
319+
@Override
320+
public PromptDeveloperSpec text(Resource text) {
321+
Assert.notNull(text, "text cannot be null");
322+
this.text(text, Charset.defaultCharset());
323+
return this;
324+
}
325+
326+
@Override
327+
public PromptDeveloperSpec param(String key, Object value) {
328+
Assert.hasText(key, "key cannot be null or empty");
329+
Assert.notNull(value, "value cannot be null");
330+
this.params.put(key, value);
331+
return this;
332+
}
333+
334+
@Override
335+
public PromptDeveloperSpec params(Map<String, Object> params) {
336+
Assert.notNull(params, "params cannot be null");
337+
Assert.noNullElements(params.keySet(), "param keys cannot contain null elements");
338+
Assert.noNullElements(params.values(), "param values cannot contain null elements");
339+
this.params.putAll(params);
340+
return this;
341+
}
342+
343+
@Nullable
344+
protected String text() {
345+
return this.text;
346+
}
347+
348+
protected Map<String, Object> params() {
349+
return this.params;
350+
}
351+
352+
}
353+
291354
public static class DefaultAdvisorSpec implements AdvisorSpec {
292355

293356
private final List<Advisor> advisors = new ArrayList<>();
@@ -577,6 +640,8 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
577640

578641
private final Map<String, Object> systemParams = new HashMap<>();
579642

643+
private final Map<String, Object> developerParams = new HashMap<>();
644+
580645
private final List<Advisor> advisors = new ArrayList<>();
581646

582647
private final Map<String, Object> advisorParams = new HashMap<>();
@@ -591,27 +656,32 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
591656
@Nullable
592657
private String systemText;
593658

659+
@Nullable
660+
private String developerText;
661+
594662
@Nullable
595663
private ChatOptions chatOptions;
596664

597665
/* copy constructor */
598666
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
599-
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.toolCallbacks,
600-
ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
601-
ccr.observationRegistry, ccr.observationConvention, ccr.toolContext, ccr.templateRenderer);
667+
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.developerText,
668+
ccr.developerParams, ccr.toolCallbacks, ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions,
669+
ccr.advisors, ccr.advisorParams, ccr.observationRegistry, ccr.observationConvention,
670+
ccr.toolContext, ccr.templateRenderer);
602671
}
603672

604673
public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
605674
Map<String, Object> userParams, @Nullable String systemText, Map<String, Object> systemParams,
606-
List<ToolCallback> toolCallbacks, List<Message> messages, List<String> toolNames, List<Media> media,
607-
@Nullable ChatOptions chatOptions, List<Advisor> advisors, Map<String, Object> advisorParams,
608-
ObservationRegistry observationRegistry,
675+
@Nullable String developerText, Map<String, Object> developerParams, List<ToolCallback> toolCallbacks,
676+
List<Message> messages, List<String> toolNames, List<Media> media, @Nullable ChatOptions chatOptions,
677+
List<Advisor> advisors, Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
609678
@Nullable ChatClientObservationConvention observationConvention, Map<String, Object> toolContext,
610679
@Nullable TemplateRenderer templateRenderer) {
611680

612681
Assert.notNull(chatModel, "chatModel cannot be null");
613682
Assert.notNull(userParams, "userParams cannot be null");
614683
Assert.notNull(systemParams, "systemParams cannot be null");
684+
Assert.notNull(developerParams, "developerParams cannot be null");
615685
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
616686
Assert.notNull(messages, "messages cannot be null");
617687
Assert.notNull(toolNames, "toolNames cannot be null");
@@ -629,6 +699,8 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe
629699
this.userParams.putAll(userParams);
630700
this.systemText = systemText;
631701
this.systemParams.putAll(systemParams);
702+
this.developerText = developerText;
703+
this.developerParams.putAll(developerParams);
632704

633705
this.toolNames.addAll(toolNames);
634706
this.toolCallbacks.addAll(toolCallbacks);
@@ -661,6 +733,15 @@ public Map<String, Object> getSystemParams() {
661733
return this.systemParams;
662734
}
663735

736+
@Nullable
737+
public String getDeveloperText() {
738+
return this.developerText;
739+
}
740+
741+
public Map<String, Object> getDeveloperParams() {
742+
return this.developerParams;
743+
}
744+
664745
@Nullable
665746
public ChatOptions getChatOptions() {
666747
return this.chatOptions;
@@ -719,6 +800,10 @@ public Builder mutate() {
719800
builder.defaultSystem(s -> s.text(this.systemText).params(this.systemParams));
720801
}
721802

803+
if (StringUtils.hasText(this.developerText)) {
804+
builder.defaultDeveloper(s -> s.text(this.developerText).params(this.developerParams));
805+
}
806+
722807
if (this.chatOptions != null) {
723808
builder.defaultOptions(this.chatOptions);
724809
}
@@ -821,6 +906,41 @@ public ChatClientRequestSpec toolContext(Map<String, Object> toolContext) {
821906
return this;
822907
}
823908

909+
public ChatClientRequestSpec developer(String text) {
910+
Assert.hasText(text, "text cannot be null or empty");
911+
this.developerText = text;
912+
return this;
913+
}
914+
915+
public ChatClientRequestSpec developer(Resource text, Charset charset) {
916+
Assert.notNull(text, "text cannot be null");
917+
Assert.notNull(charset, "charset cannot be null");
918+
919+
try {
920+
this.developerText = text.getContentAsString(charset);
921+
}
922+
catch (IOException e) {
923+
throw new RuntimeException(e);
924+
}
925+
return this;
926+
}
927+
928+
public ChatClientRequestSpec developer(Resource text) {
929+
Assert.notNull(text, "text cannot be null");
930+
return this.developer(text, Charset.defaultCharset());
931+
}
932+
933+
public ChatClientRequestSpec developer(Consumer<PromptDeveloperSpec> consumer) {
934+
Assert.notNull(consumer, "consumer cannot be null");
935+
936+
var developerSpec = new DefaultPromptDeveloperSpec();
937+
consumer.accept(developerSpec);
938+
this.developerText = StringUtils.hasText(developerSpec.text()) ? developerSpec.text() : this.developerText;
939+
this.developerParams.putAll(developerSpec.params());
940+
941+
return this;
942+
}
943+
824944
public ChatClientRequestSpec system(String text) {
825945
Assert.hasText(text, "text cannot be null or empty");
826946
this.systemText = text;

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import org.springframework.ai.chat.client.ChatClient.Builder;
2828
import org.springframework.ai.chat.client.ChatClient.PromptSystemSpec;
29+
import org.springframework.ai.chat.client.ChatClient.PromptDeveloperSpec;
2930
import org.springframework.ai.chat.client.ChatClient.PromptUserSpec;
3031
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
3132
import org.springframework.ai.chat.client.advisor.api.Advisor;
@@ -50,6 +51,7 @@
5051
* @author Josh Long
5152
* @author Arjen Poutsma
5253
* @author Thomas Vitale
54+
* @author Andres da Silva Santos
5355
* @since 1.0.0
5456
*/
5557
public class DefaultChatClientBuilder implements Builder {
@@ -64,8 +66,8 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa
6466
@Nullable ChatClientObservationConvention customObservationConvention) {
6567
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
6668
Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null");
67-
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), null, Map.of(), List.of(),
68-
List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry,
69+
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), null, Map.of(), null,
70+
Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry,
6971
customObservationConvention, Map.of(), null);
7072
}
7173

@@ -149,6 +151,32 @@ public Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer) {
149151
return this;
150152
}
151153

154+
public Builder defaultDeveloper(String text) {
155+
this.defaultRequest.developer(text);
156+
return this;
157+
}
158+
159+
public Builder defaultDeveloper(Resource text, Charset charset) {
160+
Assert.notNull(text, "text cannot be null");
161+
Assert.notNull(charset, "charset cannot be null");
162+
try {
163+
this.defaultRequest.developer(text.getContentAsString(charset));
164+
}
165+
catch (IOException e) {
166+
throw new RuntimeException(e);
167+
}
168+
return this;
169+
}
170+
171+
public Builder defaultDeveloper(Resource text) {
172+
return this.defaultDeveloper(text, Charset.defaultCharset());
173+
}
174+
175+
public Builder defaultDeveloper(Consumer<PromptDeveloperSpec> developerSpecConsumer) {
176+
this.defaultRequest.developer(developerSpecConsumer);
177+
return this;
178+
}
179+
152180
@Override
153181
public Builder defaultToolNames(String... toolNames) {
154182
this.defaultRequest.toolNames(toolNames);

0 commit comments

Comments
 (0)