Skip to content

Commit 719c558

Browse files
committed
Add OpenAI auto-configuraiton integraton tests
1 parent d04d80d commit 719c558

File tree

2 files changed

+110
-0
lines changed

2 files changed

+110
-0
lines changed

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import com.azure.ai.openai.OpenAIClient;
2222
import com.azure.ai.openai.OpenAIClientBuilder;
2323
import com.azure.core.credential.AzureKeyCredential;
24+
2425
import org.springframework.ai.azure.openai.client.AzureOpenAiClient;
2526
import org.springframework.ai.azure.openai.embedding.AzureOpenAiEmbeddingClient;
2627
import org.springframework.ai.client.AiClient;
@@ -60,6 +61,7 @@ public OpenAIClient msoftSdkOpenAiClient(AzureOpenAiProperties azureOpenAiProper
6061
}
6162

6263
@Bean
64+
@ConditionalOnMissingBean
6365
public AiClient azureOpenAiClient(OpenAIClient msoftSdkOpenAiClient, AzureOpenAiProperties azureOpenAiProperties,
6466
RetryTemplate retryTemplate) {
6567
AzureOpenAiClient azureOpenAiClient = new AzureOpenAiClient(msoftSdkOpenAiClient);
@@ -71,6 +73,7 @@ public AiClient azureOpenAiClient(OpenAIClient msoftSdkOpenAiClient, AzureOpenAi
7173
}
7274

7375
@Bean
76+
@ConditionalOnMissingBean
7477
public EmbeddingClient azureOpenAiEmbeddingClient(OpenAIClient msoftSdkOpenAiClient,
7578
AzureOpenAiProperties azureOpenAiProperties, RetryTemplate retryTemplate) {
7679
var embeddingClient = new AzureOpenAiEmbeddingClient(msoftSdkOpenAiClient,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright 2023-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.autoconfigure.openai;
18+
19+
import java.util.List;
20+
21+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
22+
import org.junit.jupiter.params.ParameterizedTest;
23+
import org.junit.jupiter.params.provider.ValueSource;
24+
25+
import org.springframework.ai.client.AiClient;
26+
import org.springframework.ai.client.AiResponse;
27+
import org.springframework.ai.client.RetryAiClient;
28+
import org.springframework.ai.document.Document;
29+
import org.springframework.ai.embedding.EmbeddingClient;
30+
import org.springframework.ai.embedding.EmbeddingResponse;
31+
import org.springframework.ai.embedding.RetryEmbeddingClient;
32+
import org.springframework.ai.openai.client.OpenAiClient;
33+
import org.springframework.ai.openai.embedding.OpenAiEmbeddingClient;
34+
import org.springframework.ai.prompt.Prompt;
35+
import org.springframework.boot.autoconfigure.AutoConfigurations;
36+
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
37+
38+
import static org.assertj.core.api.Assertions.assertThat;
39+
40+
/**
41+
* @author Christian Tzolov
42+
*/
43+
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
44+
public class OpenAiIT {
45+
46+
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
47+
.withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class))
48+
.withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"));
49+
50+
@ParameterizedTest(name = "{0} : {displayName} ")
51+
@ValueSource(booleans = { false, true })
52+
public void embeddingClient(boolean retryEnabled) {
53+
contextRunner.withPropertyValues("spring.ai.openai.retryEnabled=" + retryEnabled).run(context -> {
54+
OpenAiProperties properties = context.getBean(OpenAiProperties.class);
55+
assertThat(properties.isRetryEnabled()).isEqualTo(retryEnabled);
56+
57+
EmbeddingClient embeddingClient = context.getBean(EmbeddingClient.class);
58+
if (retryEnabled) {
59+
assertThat(embeddingClient).isInstanceOf(RetryEmbeddingClient.class);
60+
}
61+
else {
62+
assertThat(embeddingClient).isInstanceOf(OpenAiEmbeddingClient.class);
63+
}
64+
65+
List<List<Double>> embeddings = embeddingClient.embed(List.of("Spring Framework", "Spring AI"));
66+
67+
assertThat(embeddings.size()).isEqualTo(2); // batch size
68+
assertThat(embeddings.get(0).size()).isEqualTo(embeddingClient.dimensions()); // dimensions
69+
70+
List<Double> embedding = embeddingClient.embed(new Document("test"));
71+
assertThat(embedding).hasSize(embeddingClient.dimensions());
72+
73+
embedding = embeddingClient.embed("test");
74+
assertThat(embedding).hasSize(embeddingClient.dimensions());
75+
76+
EmbeddingResponse response = embeddingClient.embedForResponse(List.of("test1", "test2"));
77+
78+
assertThat(response).isNotNull();
79+
assertThat(response.data()).hasSize(2);
80+
assertThat(response.data().get(0).embedding()).hasSize(embeddingClient.dimensions());
81+
assertThat(response.data().get(1).embedding()).hasSize(embeddingClient.dimensions());
82+
});
83+
}
84+
85+
@ParameterizedTest(name = "{0} : {displayName} ")
86+
@ValueSource(booleans = { false, true })
87+
public void aiClient(boolean retryEnabled) {
88+
contextRunner.withPropertyValues("spring.ai.openai.retryEnabled=" + retryEnabled).run(context -> {
89+
OpenAiProperties properties = context.getBean(OpenAiProperties.class);
90+
assertThat(properties.isRetryEnabled()).isEqualTo(retryEnabled);
91+
92+
AiClient aiClient = context.getBean(AiClient.class);
93+
if (retryEnabled) {
94+
assertThat(aiClient).isInstanceOf(RetryAiClient.class);
95+
}
96+
else {
97+
assertThat(aiClient).isInstanceOf(OpenAiClient.class);
98+
}
99+
100+
AiResponse response = aiClient.generate(new Prompt("content"));
101+
102+
assertThat(response).isNotNull();
103+
assertThat(response.getGeneration()).isNotNull();
104+
});
105+
}
106+
107+
}

0 commit comments

Comments
 (0)