Skip to content

Commit 404eaf2

Browse files
committed
Add operations retry capabilities for the AiClient, EmbeddingClient and VectorStore clients
- Refactor the EmbeddingResponse and Embedding from “class” into “record” types. Fix the affected getter calls in the code. - Refactor the TransformersEmbeddingClient to avoid the telescopic public method call anti-pattern. - Add spring-retry dependency to spring-ai-core. - Implement RetryAiClient, RetryEmbeddingClient and RetryVectorStore as decorators around AiClient, EmbeddingClient and VectorStore instances. The retry decorators use RetryTemplate to wrap all delegate method calls. A default RetryTemplate is provided with each instance. - Add tests for all Retry decorators. - Add missing javadocs to existing core classes. Resolves #123
1 parent ad7af60 commit 404eaf2

File tree

20 files changed

+783
-117
lines changed

20 files changed

+783
-117
lines changed

embedding-clients/spring-ai-postgresml-embedding-client/src/test/java/org/springframework/ai/embedding/PostgresMlEmbeddingClientIT.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,15 +108,15 @@ void embedForResponse(String vectorType) {
108108
EmbeddingResponse embeddingResponse = embeddingClient
109109
.embedForResponse(List.of("Hello World!", "Spring AI!", "LLM!"));
110110
assertThat(embeddingResponse).isNotNull();
111-
assertThat(embeddingResponse.getData()).hasSize(3);
112-
assertThat(embeddingResponse.getMetadata()).containsExactlyEntriesOf(
111+
assertThat(embeddingResponse.data()).hasSize(3);
112+
assertThat(embeddingResponse.metadata()).containsExactlyEntriesOf(
113113
Map.of("transformer", "distilbert-base-uncased", "vector-type", vectorType, "kwargs", "{}"));
114-
assertThat(embeddingResponse.getData().get(0).getIndex()).isEqualTo(0);
115-
assertThat(embeddingResponse.getData().get(0).getEmbedding()).hasSize(768);
116-
assertThat(embeddingResponse.getData().get(1).getIndex()).isEqualTo(1);
117-
assertThat(embeddingResponse.getData().get(1).getEmbedding()).hasSize(768);
118-
assertThat(embeddingResponse.getData().get(2).getIndex()).isEqualTo(2);
119-
assertThat(embeddingResponse.getData().get(2).getEmbedding()).hasSize(768);
114+
assertThat(embeddingResponse.data().get(0).index()).isEqualTo(0);
115+
assertThat(embeddingResponse.data().get(0).embedding()).hasSize(768);
116+
assertThat(embeddingResponse.data().get(1).index()).isEqualTo(1);
117+
assertThat(embeddingResponse.data().get(1).embedding()).hasSize(768);
118+
assertThat(embeddingResponse.data().get(2).index()).isEqualTo(2);
119+
assertThat(embeddingResponse.data().get(2).embedding()).hasSize(768);
120120
// embeddingClient.dropPgmlExtension();
121121
}
122122

embedding-clients/transformers-embedding/src/main/java/org/springframework/ai/embedding/TransformersEmbeddingClient.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,18 +194,19 @@ private Resource getCachedResource(Resource resource) {
194194

195195
@Override
196196
public List<Double> embed(String text) {
197-
return embed(List.of(text)).get(0);
197+
return this.internalEmbed(List.of(text)).get(0);
198198
}
199199

200200
@Override
201201
public List<Double> embed(Document document) {
202-
return this.embed(document.getFormattedContent(this.metadataMode));
202+
String content = document.getFormattedContent(this.metadataMode);
203+
return this.internalEmbed(List.of(content)).get(0);
203204
}
204205

205206
@Override
206207
public EmbeddingResponse embedForResponse(List<String> texts) {
207208
List<Embedding> data = new ArrayList<>();
208-
List<List<Double>> embed = this.embed(texts);
209+
List<List<Double>> embed = this.internalEmbed(texts);
209210
for (int i = 0; i < embed.size(); i++) {
210211
data.add(new Embedding(embed.get(i), i));
211212
}
@@ -214,6 +215,10 @@ public EmbeddingResponse embedForResponse(List<String> texts) {
214215

215216
@Override
216217
public List<List<Double>> embed(List<String> texts) {
218+
return this.internalEmbed(texts);
219+
}
220+
221+
private List<List<Double>> internalEmbed(List<String> texts) {
217222

218223
List<List<Double>> resultEmbeddings = new ArrayList<>();
219224

embedding-clients/transformers-embedding/src/test/java/org/springframework/ai/embedding/TransformersEmbeddingClientTests.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,16 @@ void embedForResponse() throws Exception {
7575
TransformersEmbeddingClient embeddingClient = new TransformersEmbeddingClient();
7676
embeddingClient.afterPropertiesSet();
7777
EmbeddingResponse embed = embeddingClient.embedForResponse(List.of("Hello world", "World is big"));
78-
assertThat(embed.getData()).hasSize(2);
79-
assertThat(embed.getMetadata()).isEmpty();
78+
assertThat(embed.data()).hasSize(2);
79+
assertThat(embed.metadata()).isEmpty();
8080

81-
assertThat(embed.getData().get(0).getEmbedding()).hasSize(384);
82-
assertThat(DF.format(embed.getData().get(0).getEmbedding().get(0))).isEqualTo(DF.format(-0.19744634628295898));
83-
assertThat(DF.format(embed.getData().get(0).getEmbedding().get(383))).isEqualTo(DF.format(0.17298996448516846));
81+
assertThat(embed.data().get(0).embedding()).hasSize(384);
82+
assertThat(DF.format(embed.data().get(0).embedding().get(0))).isEqualTo(DF.format(-0.19744634628295898));
83+
assertThat(DF.format(embed.data().get(0).embedding().get(383))).isEqualTo(DF.format(0.17298996448516846));
8484

85-
assertThat(embed.getData().get(1).getEmbedding()).hasSize(384);
86-
assertThat(DF.format(embed.getData().get(1).getEmbedding().get(0))).isEqualTo(DF.format(0.4293745160102844));
87-
assertThat(DF.format(embed.getData().get(1).getEmbedding().get(383))).isEqualTo(DF.format(0.05501303821802139));
85+
assertThat(embed.data().get(1).embedding()).hasSize(384);
86+
assertThat(DF.format(embed.data().get(1).embedding().get(0))).isEqualTo(DF.format(0.4293745160102844));
87+
assertThat(DF.format(embed.data().get(1).embedding().get(383))).isEqualTo(DF.format(0.05501303821802139));
8888
}
8989

9090
@Test

pom.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
<azure-open-ai-client.version>1.0.0-beta.3</azure-open-ai-client.version>
8888
<jtokkit.version>0.6.1</jtokkit.version>
8989
<victools.version>4.31.1</victools.version>
90+
<spring-retry.version>2.0.4</spring-retry.version>
9091

9192
<!-- readers/writer/stores dependencies-->
9293
<pdfbox.version>3.0.0</pdfbox.version>

spring-ai-core/pom.xml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@
2626

2727
<dependencies>
2828
<!-- production dependencies -->
29+
<dependency>
30+
<groupId>org.springframework.retry</groupId>
31+
<artifactId>spring-retry</artifactId>
32+
<version>${spring-retry.version}</version>
33+
</dependency>
34+
<!-- <dependency>
35+
<groupId>org.springframework</groupId>
36+
<artifactId>spring-aspects</artifactId>
37+
<version>6.0.14</version>
38+
</dependency> -->
2939
<dependency>
3040
<groupId>org.antlr</groupId>
3141
<artifactId>stringtemplate</artifactId>

spring-ai-core/src/main/java/org/springframework/ai/client/AiResponse.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
*/
1616
package org.springframework.ai.client;
1717

18-
import java.util.Collections;
1918
import java.util.List;
20-
import java.util.Map;
2119

2220
import org.springframework.ai.metadata.GenerationMetadata;
2321
import org.springframework.ai.metadata.PromptMetadata;
@@ -56,8 +54,6 @@ public Generation getGeneration() {
5654
}
5755

5856
/**
59-
* Returns {@link GenerationMetadata} containing information about the use of the AI
60-
* provider's API.
6157
* @return {@link GenerationMetadata} containing information about the use of the AI
6258
* provider's API.
6359
*/
@@ -66,8 +62,6 @@ public GenerationMetadata getGenerationMetadata() {
6662
}
6763

6864
/**
69-
* Returns {@link PromptMetadata} containing information on prompt processing by the
70-
* AI.
7165
* @return {@link PromptMetadata} containing information on prompt processing by the
7266
* AI.
7367
*/
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright 2023-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.client;
18+
19+
import java.time.Duration;
20+
21+
import org.springframework.ai.prompt.Prompt;
22+
import org.springframework.retry.support.RetryTemplate;
23+
24+
/**
25+
* The {@link RetryAiClient} is an {@link AiClient} decorator that automatically re-invokes
26+
* the failed generate operations according to pre-configured retry policies. This is
27+
* helpful for transient errors such as a momentary network glitch.
28+
*
29+
* @author Christian Tzolov
30+
*/
31+
public class RetryAiClient implements AiClient {
32+
33+
private final RetryTemplate retryTemplate;
34+
35+
private final AiClient delegate;
36+
37+
public RetryAiClient(AiClient delegate) {
38+
this(RetryTemplate.builder()
39+
.maxAttempts(10)
40+
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
41+
.build(), delegate);
42+
}
43+
44+
public RetryAiClient(RetryTemplate retryTemplate, AiClient delegate) {
45+
this.retryTemplate = retryTemplate;
46+
this.delegate = delegate;
47+
}
48+
49+
@Override
50+
public AiResponse generate(Prompt prompt) {
51+
return this.retryTemplate.execute(ctx -> {
52+
return this.delegate.generate(prompt);
53+
});
54+
}
55+
56+
}
Lines changed: 19 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,24 @@
1+
/*
2+
* Copyright 2023-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
116
package org.springframework.ai.embedding;
217

318
import java.util.List;
4-
import java.util.Objects;
5-
6-
public class Embedding {
7-
8-
private List<Double> embedding;
9-
10-
private Integer index;
11-
12-
public Embedding(List<Double> embedding, Integer index) {
13-
this.embedding = embedding;
14-
this.index = index;
15-
}
16-
17-
public List<Double> getEmbedding() {
18-
return embedding;
19-
}
20-
21-
public Integer getIndex() {
22-
return index;
23-
}
24-
25-
@Override
26-
public boolean equals(Object o) {
27-
if (this == o)
28-
return true;
29-
if (o == null || getClass() != o.getClass())
30-
return false;
31-
Embedding embedding1 = (Embedding) o;
32-
return Objects.equals(embedding, embedding1.embedding) && Objects.equals(index, embedding1.index);
33-
}
34-
35-
@Override
36-
public int hashCode() {
37-
return Objects.hash(embedding, index);
38-
}
39-
40-
@Override
41-
public String toString() {
42-
String message = this.embedding.size() == 0 ? "<empty>" : "<has data>";
43-
return "Embedding{" + "embedding=" + message + ", index=" + index + '}';
44-
}
4519

20+
/**
21+
* Represents a single embedding response.
22+
*/
23+
public record Embedding(List<Double> embedding, Integer index) {
4624
}

spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingClient.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,63 @@
1+
/*
2+
* Copyright 2023-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
116
package org.springframework.ai.embedding;
217

318
import org.springframework.ai.document.Document;
419

520
import java.util.List;
621

22+
/**
23+
* Converts a {@link Document} text into its vector representation (e.g. embedding).
24+
*/
725
public interface EmbeddingClient {
826

27+
/**
28+
* Computes embedding for provided, raw, text.
29+
* @param text Input text to compute the embedding for.
30+
* @return Returns a raw, double, list of the embedding representation of the input
31+
* text.
32+
*/
933
List<Double> embed(String text);
1034

35+
/**
36+
* Computes embedding for provided document.
37+
* @param document Document to compute the embedding for.
38+
* @return Returns a raw, double, list of the embedding representation of the input
39+
* text.
40+
*/
1141
List<Double> embed(Document document);
1242

43+
/**
44+
* Computes embeddings for provided list of text.
45+
* @param texts list of input text to compute embeddings for.
46+
* @return Returns a list of embeddings. The order corresponds to the input text list.
47+
*/
1348
List<List<Double>> embed(List<String> texts);
1449

50+
/**
51+
* Computes embeddings for provided list of text.
52+
* @param texts list of input text to compute embeddings for.
53+
* @return Returns an {@link EmbeddingResponse} containing the embeddings data and related metadata.
54+
*/
1555
EmbeddingResponse embedForResponse(List<String> texts);
1656

57+
/**
58+
* Retrieves embedding model's dimensions.
59+
* @return Returns the vector dimensions for the configured embedding model.
60+
*/
1761
default int dimensions() {
1862
return embed("Test String").size();
1963
}
Lines changed: 20 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,26 @@
1+
/*
2+
* Copyright 2023-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
116
package org.springframework.ai.embedding;
217

3-
import java.util.HashMap;
418
import java.util.List;
519
import java.util.Map;
6-
import java.util.Objects;
7-
8-
public class EmbeddingResponse {
9-
10-
private List<Embedding> data;
11-
12-
private Map<String, Object> metadata = new HashMap<>();
13-
14-
public EmbeddingResponse(List<Embedding> data, Map<String, Object> metadata) {
15-
this.data = data;
16-
this.metadata = metadata;
17-
}
18-
19-
public List<Embedding> getData() {
20-
return data;
21-
}
22-
23-
public Map<String, Object> getMetadata() {
24-
return metadata;
25-
}
26-
27-
@Override
28-
public boolean equals(Object o) {
29-
if (this == o)
30-
return true;
31-
if (o == null || getClass() != o.getClass())
32-
return false;
33-
EmbeddingResponse that = (EmbeddingResponse) o;
34-
return Objects.equals(data, that.data) && Objects.equals(metadata, that.metadata);
35-
}
36-
37-
@Override
38-
public int hashCode() {
39-
return Objects.hash(data, metadata);
40-
}
41-
42-
@Override
43-
public String toString() {
44-
return "EmbeddingResult{" + "data=" + data + ", metadata=" + metadata + '}';
45-
}
4620

21+
/**
22+
* Structured embedding response, containing list of {@link Embedding}s data and related
23+
* metadata.
24+
*/
25+
public record EmbeddingResponse(List<Embedding> data, Map<String, Object> metadata) {
4726
}

0 commit comments

Comments
 (0)