Skip to content

feat(rag:etl): Add custom template support to KeywordMetadataEnricher #3252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -709,18 +709,26 @@ class MyKeywordEnricher {
}

List<Document> enrichDocuments(List<Document> documents) {
KeywordMetadataEnricher enricher = new KeywordMetadataEnricher(this.chatModel, 5);
KeywordMetadataEnricher enricher = KeywordMetadataEnricher.builder(chatModel)
.keywordCount(5)
.build();

// Alternatively, replace the declaration above with one that uses a custom template
KeywordMetadataEnricher enricher = KeywordMetadataEnricher.builder(chatModel)
.keywordsTemplate(YOUR_CUSTOM_TEMPLATE)
.build();

return enricher.apply(documents);
}
}
----

==== Constructor
==== Constructor Options

The `KeywordMetadataEnricher` constructor takes two parameters:
The `KeywordMetadataEnricher` provides two constructor options:

1. `ChatModel chatModel`: The AI model used for generating keywords.
2. `int keywordCount`: The number of keywords to extract for each document.
1. `KeywordMetadataEnricher(ChatModel chatModel, int keywordCount)`: To use the default template and extract a specified number of keywords.
2. `KeywordMetadataEnricher(ChatModel chatModel, PromptTemplate keywordsTemplate)`: To use a custom template for keyword extraction.

==== Behavior

Expand All @@ -734,7 +742,8 @@ The `KeywordMetadataEnricher` processes documents as follows:

==== Customization

The keyword extraction prompt can be customized by modifying the `KEYWORDS_TEMPLATE` constant in the class. The default template is:
You can use the default template or supply a custom one through the `keywordsTemplate` parameter (as a constructor argument or a builder option).
The default template is:

[source,java]
----
Expand All @@ -748,7 +757,14 @@ Where `+{context_str}+` is replaced with the document content, and `%s` is repla
[source,java]
----
ChatModel chatModel = // initialize your chat model
KeywordMetadataEnricher enricher = new KeywordMetadataEnricher(chatModel, 5);
KeywordMetadataEnricher enricher = KeywordMetadataEnricher.builder(chatModel)
.keywordCount(5)
.build();

// Alternatively, replace the declaration above with one that uses a custom template
KeywordMetadataEnricher enricher = KeywordMetadataEnricher.builder(chatModel)
.keywordsTemplate(new PromptTemplate("Extract 5 important keywords from the following text and separate them with commas:\n{context_str}"))
.build();

Document doc = new Document("This is a document about artificial intelligence and its applications in modern technology.");

Expand All @@ -766,6 +782,7 @@ System.out.println("Extracted keywords: " + keywords);
* The enricher adds the "excerpt_keywords" metadata field to each processed document.
* The generated keywords are returned as a comma-separated string.
* This enricher is particularly useful for improving document searchability and for generating tags or categories for documents.
* When using the builder, if `keywordsTemplate` is set, `keywordCount` is ignored and a warning is logged.

=== SummaryMetadataEnricher
The `SummaryMetadataEnricher` is a `DocumentTransformer` that uses a generative AI model to create summaries for documents and add them as metadata. It can generate summaries for the current document, as well as adjacent documents (previous and next).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import java.util.List;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
Expand All @@ -30,45 +33,113 @@
* Keyword extractor that uses a generative AI model to extract the 'excerpt_keywords' metadata field.
*
* @author Christian Tzolov
* @author YunKui Lu
*/
public class KeywordMetadataEnricher implements DocumentTransformer {

private static final Logger logger = LoggerFactory.getLogger(KeywordMetadataEnricher.class);

public static final String CONTEXT_STR_PLACEHOLDER = "context_str";

public static final String KEYWORDS_TEMPLATE = """
{context_str}. Give %s unique keywords for this
document. Format as comma separated. Keywords: """;

private static final String EXCERPT_KEYWORDS_METADATA_KEY = "excerpt_keywords";
public static final String EXCERPT_KEYWORDS_METADATA_KEY = "excerpt_keywords";

/**
* Model predictor
*/
private final ChatModel chatModel;

/**
* The number of keywords to extract.
* The prompt template to use for keyword extraction.
*/
private final int keywordCount;
private final PromptTemplate keywordsTemplate;

/**
* Create a new {@link KeywordMetadataEnricher} instance.
* @param chatModel the model predictor to use for keyword extraction.
* @param keywordCount the number of keywords to extract.
*/
public KeywordMetadataEnricher(ChatModel chatModel, int keywordCount) {
Assert.notNull(chatModel, "ChatModel must not be null");
Assert.isTrue(keywordCount >= 1, "Document count must be >= 1");
Assert.notNull(chatModel, "chatModel must not be null");
Assert.isTrue(keywordCount >= 1, "keywordCount must be >= 1");

this.chatModel = chatModel;
this.keywordsTemplate = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, keywordCount));
}

/**
* Create a new {@link KeywordMetadataEnricher} instance.
* @param chatModel the model predictor to use for keyword extraction.
* @param keywordsTemplate the prompt template to use for keyword extraction.
*/
public KeywordMetadataEnricher(ChatModel chatModel, PromptTemplate keywordsTemplate) {
Assert.notNull(chatModel, "chatModel must not be null");
Assert.notNull(keywordsTemplate, "keywordsTemplate must not be null");

this.chatModel = chatModel;
this.keywordCount = keywordCount;
this.keywordsTemplate = keywordsTemplate;
}

@Override
public List<Document> apply(List<Document> documents) {
for (Document document : documents) {

var template = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, this.keywordCount));
Prompt prompt = template.create(Map.of(CONTEXT_STR_PLACEHOLDER, document.getText()));
Prompt prompt = this.keywordsTemplate.create(Map.of(CONTEXT_STR_PLACEHOLDER, document.getText()));
String keywords = this.chatModel.call(prompt).getResult().getOutput().getText();
document.getMetadata().putAll(Map.of(EXCERPT_KEYWORDS_METADATA_KEY, keywords));
document.getMetadata().put(EXCERPT_KEYWORDS_METADATA_KEY, keywords);
}
return documents;
}

/**
 * Return the prompt template this enricher uses for keyword extraction.
 * <p>
 * Package-private on purpose: exposed for testing purposes only.
 * @return the prompt template held by this instance.
 */
PromptTemplate getKeywordsTemplate() {
return this.keywordsTemplate;
}

/**
 * Create a new {@link Builder} for a {@link KeywordMetadataEnricher}.
 * @param chatModel the model predictor to use for keyword extraction; rejected by
 * the {@link Builder} constructor when {@code null}.
 * @return a new builder bound to the given chat model.
 */
public static Builder builder(ChatModel chatModel) {
return new Builder(chatModel);
}

/**
 * Fluent builder for {@link KeywordMetadataEnricher}.
 * <p>
 * Either a keyword count (default template) or a custom keywords template can be
 * configured. When both are set, the custom template wins and the keyword count is
 * ignored.
 */
public static class Builder {

    /**
     * Model predictor handed to the enricher being built.
     */
    private final ChatModel chatModel;

    /**
     * Number of keywords to extract; stays at 0 until {@link #keywordCount(int)} is
     * called, which is how {@link #build()} detects that it was never configured.
     */
    private int keywordCount;

    /**
     * Custom prompt template; {@code null} until {@link #keywordsTemplate} is called.
     */
    private PromptTemplate keywordsTemplate;

    /**
     * Create a new builder.
     * @param chatModel the model predictor to use for keyword extraction.
     */
    public Builder(ChatModel chatModel) {
        Assert.notNull(chatModel, "The chatModel must not be null");
        this.chatModel = chatModel;
    }

    /**
     * Set the number of keywords to extract with the default template.
     * @param keywordCount the number of keywords to extract, must be >= 1.
     * @return this builder.
     */
    public Builder keywordCount(int keywordCount) {
        Assert.isTrue(keywordCount >= 1, "The keywordCount must be >= 1");
        this.keywordCount = keywordCount;
        return this;
    }

    /**
     * Set a custom prompt template for keyword extraction.
     * @param keywordsTemplate the prompt template to use, must not be null.
     * @return this builder.
     */
    public Builder keywordsTemplate(PromptTemplate keywordsTemplate) {
        Assert.notNull(keywordsTemplate, "The keywordsTemplate must not be null");
        this.keywordsTemplate = keywordsTemplate;
        return this;
    }

    /**
     * Build the enricher from the configured options.
     * <p>
     * Without a custom template the count-based constructor is used, so its
     * validation of the keyword count applies (including the unset value 0).
     * @return a new {@link KeywordMetadataEnricher}.
     */
    public KeywordMetadataEnricher build() {
        // No custom template: fall back to the default template with the configured count.
        if (this.keywordsTemplate == null) {
            return new KeywordMetadataEnricher(this.chatModel, this.keywordCount);
        }

        // A non-zero count means the caller also configured keywordCount explicitly.
        if (this.keywordCount != 0) {
            logger.warn("keywordCount will be ignored as keywordsTemplate is set.");
        }

        return new KeywordMetadataEnricher(this.chatModel, this.keywordsTemplate);
    }

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package org.springframework.ai.model.transformer;

import java.util.List;
import java.util.Map;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.*;
import static org.springframework.ai.model.transformer.KeywordMetadataEnricher.*;

/**
* @author YunKui Lu
*/
@ExtendWith(MockitoExtension.class)
class KeywordMetadataEnricherTest {

    private static final String CUSTOM_TEMPLATE = "Custom template: {context_str}";

    @Mock
    private ChatModel chatModel;

    @Captor
    private ArgumentCaptor<Prompt> promptCaptor;

    /**
     * The count-based constructor must render the default template once per document
     * and store each model answer under the keywords metadata key.
     */
    @Test
    void testUseWithDefaultTemplate() {
        // 1. Prepare test data
        List<Document> documents = threeDocuments();
        int keywordCount = 3;

        // 2. Mock
        stubKeywordResponses();

        // 3. Create instance
        KeywordMetadataEnricher keywordMetadataEnricher = new KeywordMetadataEnricher(this.chatModel, keywordCount);

        // 4. Apply
        keywordMetadataEnricher.apply(documents);

        // 5. Assert
        verify(this.chatModel, times(3)).call(this.promptCaptor.capture());
        List<Prompt> prompts = this.promptCaptor.getAllValues();

        assertThat(prompts.get(0).getUserMessage().getText())
            .isEqualTo(getDefaultTemplatePromptText(keywordCount, "content1"));
        assertThat(prompts.get(1).getUserMessage().getText())
            .isEqualTo(getDefaultTemplatePromptText(keywordCount, "content2"));
        assertThat(prompts.get(2).getUserMessage().getText())
            .isEqualTo(getDefaultTemplatePromptText(keywordCount, "content3"));

        assertKeywordsMetadata(documents);
    }

    /**
     * The template-based constructor must render the custom template once per
     * document and store each model answer under the keywords metadata key.
     */
    @Test
    void testUseCustomTemplate() {
        // 1. Prepare test data
        List<Document> documents = threeDocuments();
        PromptTemplate promptTemplate = new PromptTemplate(CUSTOM_TEMPLATE);

        // 2. Mock
        stubKeywordResponses();

        // 3. Create instance
        KeywordMetadataEnricher keywordMetadataEnricher = new KeywordMetadataEnricher(this.chatModel, promptTemplate);

        // 4. Apply
        keywordMetadataEnricher.apply(documents);

        // 5. Assert
        verify(this.chatModel, times(documents.size())).call(this.promptCaptor.capture());
        List<Prompt> prompts = this.promptCaptor.getAllValues();

        assertThat(prompts.get(0).getUserMessage().getText()).isEqualTo("Custom template: content1");
        assertThat(prompts.get(1).getUserMessage().getText()).isEqualTo("Custom template: content2");
        assertThat(prompts.get(2).getUserMessage().getText()).isEqualTo("Custom template: content3");

        assertKeywordsMetadata(documents);
    }

    /**
     * Both constructors must reject invalid arguments with the documented message.
     * <p>
     * The third argument of {@code assertThrows} is only the failure message used when
     * nothing is thrown, so the exception message is checked on the returned exception.
     */
    @Test
    void testConstructorThrowsException() {
        assertThat(assertThrows(IllegalArgumentException.class, () -> new KeywordMetadataEnricher(null, 3)))
            .hasMessage("chatModel must not be null");

        assertThat(assertThrows(IllegalArgumentException.class, () -> new KeywordMetadataEnricher(this.chatModel, 0)))
            .hasMessage("keywordCount must be >= 1");

        assertThat(
                assertThrows(IllegalArgumentException.class, () -> new KeywordMetadataEnricher(this.chatModel, null)))
            .hasMessage("keywordsTemplate must not be null");
    }

    /**
     * The builder must reject invalid arguments with the documented message; see
     * {@link #testConstructorThrowsException()} for why the message is asserted on the
     * returned exception.
     */
    @Test
    void testBuilderThrowsException() {
        assertThat(assertThrows(IllegalArgumentException.class, () -> KeywordMetadataEnricher.builder(null)))
            .hasMessage("The chatModel must not be null");

        Builder enricherBuilder = KeywordMetadataEnricher.builder(this.chatModel);

        assertThat(assertThrows(IllegalArgumentException.class, () -> enricherBuilder.keywordCount(0)))
            .hasMessage("The keywordCount must be >= 1");

        assertThat(assertThrows(IllegalArgumentException.class, () -> enricherBuilder.keywordsTemplate(null)))
            .hasMessage("The keywordsTemplate must not be null");
    }

    /**
     * A builder configured with only a keyword count must produce the default
     * template formatted with that count.
     */
    @Test
    void testBuilderWithKeywordCount() {
        int keywordCount = 3;
        KeywordMetadataEnricher enricher = KeywordMetadataEnricher.builder(this.chatModel)
            .keywordCount(keywordCount)
            .build();

        assertThat(enricher.getKeywordsTemplate().getTemplate())
            .isEqualTo(String.format(KEYWORDS_TEMPLATE, keywordCount));
    }

    /**
     * A builder configured with a custom template must hand that exact template and
     * chat model to the enricher.
     */
    @Test
    void testBuilderWithKeywordsTemplate() {
        PromptTemplate template = new PromptTemplate(CUSTOM_TEMPLATE);
        KeywordMetadataEnricher enricher = KeywordMetadataEnricher.builder(this.chatModel)
            .keywordsTemplate(template)
            .build();

        assertThat(enricher).extracting("chatModel", "keywordsTemplate").containsExactly(this.chatModel, template);
    }

    /**
     * Create the three-document fixture shared by the apply tests.
     */
    private static List<Document> threeDocuments() {
        return List.of(new Document("content1"), new Document("content2"), new Document("content3"));
    }

    /**
     * Wrap a keyword string in the response structure the enricher reads its result from.
     */
    private static ChatResponse keywordResponse(String keywords) {
        return new ChatResponse(List.of(new Generation(new AssistantMessage(keywords))));
    }

    /**
     * Stub the chat model to answer three consecutive prompt calls with distinct
     * keyword lists, one per fixture document.
     */
    private void stubKeywordResponses() {
        given(this.chatModel.call(any(Prompt.class))).willReturn(
                keywordResponse("keyword1-1, keyword1-2, keyword1-3"),
                keywordResponse("keyword2-1, keyword2-2, keyword2-3"),
                keywordResponse("keyword3-1, keyword3-2, keyword3-3"));
    }

    /**
     * Assert that each fixture document received the stubbed keywords, in call order.
     */
    private static void assertKeywordsMetadata(List<Document> documents) {
        assertThat(documents.get(0).getMetadata()).containsEntry(EXCERPT_KEYWORDS_METADATA_KEY,
                "keyword1-1, keyword1-2, keyword1-3");
        assertThat(documents.get(1).getMetadata()).containsEntry(EXCERPT_KEYWORDS_METADATA_KEY,
                "keyword2-1, keyword2-2, keyword2-3");
        assertThat(documents.get(2).getMetadata()).containsEntry(EXCERPT_KEYWORDS_METADATA_KEY,
                "keyword3-1, keyword3-2, keyword3-3");
    }

    /**
     * Render the default keywords template for the given count and document content,
     * mirroring how the count-based constructor builds its template.
     */
    private String getDefaultTemplatePromptText(int keywordCount, String documentContent) {
        PromptTemplate promptTemplate = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, keywordCount));
        Prompt prompt = promptTemplate.create(Map.of(CONTEXT_STR_PLACEHOLDER, documentContent));
        return prompt.getContents();
    }

}