Skip to content

Commit 377b5ff

Browse files
deepakn27markpollack
authored andcommitted
Use WebClient.Builder and RestClient.Builder as ctor args in OpenAiAutoConfiguration
Fixes #609
1 parent 8257c0d commit 377b5ff

File tree

4 files changed

+55
-42
lines changed

4 files changed

+55
-42
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public OpenAiApi(String openAiToken) {
7272
* @param openAiToken OpenAI apiKey.
7373
*/
7474
public OpenAiApi(String baseUrl, String openAiToken) {
75-
this(baseUrl, openAiToken, RestClient.builder());
75+
this(baseUrl, openAiToken, RestClient.builder(), WebClient.builder());
7676
}
7777

7878
/**
@@ -82,8 +82,8 @@ public OpenAiApi(String baseUrl, String openAiToken) {
8282
* @param openAiToken OpenAI apiKey.
8383
* @param restClientBuilder RestClient builder.
8484
*/
85-
public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) {
86-
this(baseUrl, openAiToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
85+
public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder) {
86+
this(baseUrl, openAiToken, restClientBuilder, webClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
8787
}
8888

8989
/**
@@ -94,15 +94,15 @@ public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClie
9494
* @param restClientBuilder RestClient builder.
9595
* @param responseErrorHandler Response error handler.
9696
*/
97-
public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
97+
public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
9898

9999
this.restClient = restClientBuilder
100100
.baseUrl(baseUrl)
101101
.defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken))
102102
.defaultStatusHandler(responseErrorHandler)
103103
.build();
104104

105-
this.webClient = WebClient.builder()
105+
this.webClient = webClientBuilder
106106
.baseUrl(baseUrl)
107107
.defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken))
108108
.build();

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.springframework.http.MediaType;
4040
import org.springframework.test.web.client.MockRestServiceServer;
4141
import org.springframework.web.client.RestClient;
42+
import org.springframework.web.reactive.function.client.WebClient;
4243

4344
import static org.assertj.core.api.Assertions.assertThat;
4445
import static org.springframework.test.web.client.match.MockRestRequestMatchers.header;
@@ -166,8 +167,8 @@ private String getJson() {
166167
static class Config {
167168

168169
@Bean
169-
public OpenAiApi chatCompletionApi(RestClient.Builder builder) {
170-
return new OpenAiApi("", TEST_API_KEY, builder);
170+
public OpenAiApi chatCompletionApi(RestClient.Builder builder, WebClient.Builder webClientBuilder) {
171+
return new OpenAiApi("", TEST_API_KEY, builder, webClientBuilder);
171172
}
172173

173174
@Bean

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.springframework.util.StringUtils;
4040
import org.springframework.web.client.ResponseErrorHandler;
4141
import org.springframework.web.client.RestClient;
42+
import org.springframework.web.reactive.function.client.WebClient;
4243

4344
/**
4445
* @author Christian Tzolov
@@ -56,11 +57,13 @@ public class OpenAiAutoConfiguration {
5657
matchIfMissing = true)
5758
public OpenAiChatModel openAiChatModel(OpenAiConnectionProperties commonProperties,
5859
OpenAiChatProperties chatProperties, RestClient.Builder restClientBuilder,
59-
List<FunctionCallback> toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext,
60-
RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) {
60+
WebClient.Builder webClientBuilder, List<FunctionCallback> toolFunctionCallbacks,
61+
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate,
62+
ResponseErrorHandler responseErrorHandler) {
6163

6264
var openAiApi = openAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(),
63-
chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler);
65+
chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, webClientBuilder,
66+
responseErrorHandler);
6467

6568
if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
6669
chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks);
@@ -75,25 +78,29 @@ public OpenAiChatModel openAiChatModel(OpenAiConnectionProperties commonProperti
7578
matchIfMissing = true)
7679
public OpenAiEmbeddingModel openAiEmbeddingModel(OpenAiConnectionProperties commonProperties,
7780
OpenAiEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder,
78-
RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) {
81+
WebClient.Builder webClientBuilder, RetryTemplate retryTemplate,
82+
ResponseErrorHandler responseErrorHandler) {
7983

8084
var openAiApi = openAiApi(embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(),
81-
embeddingProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler);
85+
embeddingProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, webClientBuilder,
86+
responseErrorHandler);
8287

8388
return new OpenAiEmbeddingModel(openAiApi, embeddingProperties.getMetadataMode(),
8489
embeddingProperties.getOptions(), retryTemplate);
8590
}
8691

8792
private OpenAiApi openAiApi(String baseUrl, String commonBaseUrl, String apiKey, String commonApiKey,
88-
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
93+
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
94+
ResponseErrorHandler responseErrorHandler) {
8995

9096
String resolvedBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl;
9197
Assert.hasText(resolvedBaseUrl, "OpenAI base URL must be set");
9298

9399
String resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey;
94100
Assert.hasText(resolvedApiKey, "OpenAI API key must be set");
95101

96-
return new OpenAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder, responseErrorHandler);
102+
return new OpenAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder, webClientBuilder,
103+
responseErrorHandler);
97104
}
98105

99106
@Bean
@@ -122,6 +129,7 @@ public OpenAiImageModel openAiImageModel(OpenAiConnectionProperties commonProper
122129
@ConditionalOnMissingBean
123130
public OpenAiAudioTranscriptionModel openAiAudioTranscriptionModel(OpenAiConnectionProperties commonProperties,
124131
OpenAiAudioTranscriptionProperties transcriptionProperties, RetryTemplate retryTemplate,
132+
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
125133
ResponseErrorHandler responseErrorHandler) {
126134

127135
String apiKey = StringUtils.hasText(transcriptionProperties.getApiKey()) ? transcriptionProperties.getApiKey()
@@ -133,7 +141,8 @@ public OpenAiAudioTranscriptionModel openAiAudioTranscriptionModel(OpenAiConnect
133141
Assert.hasText(apiKey, "OpenAI API key must be set");
134142
Assert.hasText(baseUrl, "OpenAI base URL must be set");
135143

136-
var openAiAudioApi = new OpenAiAudioApi(baseUrl, apiKey, RestClient.builder(), responseErrorHandler);
144+
var openAiAudioApi = new OpenAiAudioApi(baseUrl, apiKey, restClientBuilder, webClientBuilder,
145+
responseErrorHandler);
137146

138147
OpenAiAudioTranscriptionModel openAiChatModel = new OpenAiAudioTranscriptionModel(openAiAudioApi,
139148
transcriptionProperties.getOptions(), retryTemplate);
@@ -143,8 +152,9 @@ public OpenAiAudioTranscriptionModel openAiAudioTranscriptionModel(OpenAiConnect
143152

144153
@Bean
145154
@ConditionalOnMissingBean
146-
public OpenAiAudioSpeechModel openAiAudioSpeechModel(OpenAiConnectionProperties commonProperties,
147-
OpenAiAudioSpeechProperties speechProperties, ResponseErrorHandler responseErrorHandler) {
155+
public OpenAiAudioSpeechModel openAiAudioSpeechClient(OpenAiConnectionProperties commonProperties,
156+
OpenAiAudioSpeechProperties speechProperties, RestClient.Builder restClientBuilder,
157+
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
148158

149159
String apiKey = StringUtils.hasText(speechProperties.getApiKey()) ? speechProperties.getApiKey()
150160
: commonProperties.getApiKey();
@@ -155,7 +165,8 @@ public OpenAiAudioSpeechModel openAiAudioSpeechModel(OpenAiConnectionProperties
155165
Assert.hasText(apiKey, "OpenAI API key must be set");
156166
Assert.hasText(baseUrl, "OpenAI base URL must be set");
157167

158-
var openAiAudioApi = new OpenAiAudioApi(baseUrl, apiKey, RestClient.builder(), responseErrorHandler);
168+
var openAiAudioApi = new OpenAiAudioApi(baseUrl, apiKey, restClientBuilder, webClientBuilder,
169+
responseErrorHandler);
159170

160171
OpenAiAudioSpeechModel openAiSpeechModel = new OpenAiAudioSpeechModel(openAiAudioApi,
161172
speechProperties.getOptions());

0 commit comments

Comments
 (0)