Skip to content

add aws bedrock prompt-caching support #3213

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,6 @@

package org.springframework.ai.bedrock.converse;

import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLConnection;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Set;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
Expand Down Expand Up @@ -73,6 +61,8 @@
import software.amazon.awssdk.services.bedrockruntime.model.VideoBlock;
import software.amazon.awssdk.services.bedrockruntime.model.VideoFormat;
import software.amazon.awssdk.services.bedrockruntime.model.VideoSource;
import software.amazon.awssdk.services.bedrockruntime.model.CachePointBlock;
import software.amazon.awssdk.services.bedrockruntime.model.CachePointType;

import org.springframework.ai.bedrock.converse.api.BedrockMediaFormat;
import org.springframework.ai.bedrock.converse.api.ConverseApiUtils;
Expand All @@ -96,17 +86,33 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.content.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.model.tool.*;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StreamUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;
import reactor.core.publisher.Sinks.EmitFailureHandler;
import reactor.core.scheduler.Schedulers;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.*;

import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLConnection;
import java.time.Duration;
import java.util.*;

/**
* A {@link ChatModel} implementation that uses the Amazon Bedrock Converse API to
Expand All @@ -127,12 +133,15 @@
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
* <p>
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
* <p>
* https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
*
* @author Christian Tzolov
* @author Wei Jiang
* @author Alexandros Pappas
* @author Jihoon Kim
* @author Soby Chacko
* @author Brave Lin
* @since 1.0.0
*/
public class BedrockProxyChatModel implements ChatModel {
Expand All @@ -150,34 +159,34 @@ public class BedrockProxyChatModel implements ChatModel {
private ToolCallingChatOptions defaultOptions;

/**
* Observation registry used for instrumentation.
*/
* Observation registry used for instrumentation.
*/
private final ObservationRegistry observationRegistry;

private final ToolCallingManager toolCallingManager;

/**
* The tool execution eligibility predicate used to determine if a tool can be
* executed.
*/
* The tool execution eligibility predicate used to determine if a tool can be
* executed.
*/
private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate;

/**
* Conventions to use for generating observations.
*/
* Conventions to use for generating observations.
*/
private ChatModelObservationConvention observationConvention;

public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions,
ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager) {
BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions,
ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager) {
this(bedrockRuntimeClient, bedrockRuntimeAsyncClient, defaultOptions, observationRegistry, toolCallingManager,
new DefaultToolExecutionEligibilityPredicate());
}

public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions,
ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager,
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions,
ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager,
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {

Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null");
Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null");
Expand All @@ -203,13 +212,14 @@ private static ToolCallingChatOptions from(ChatOptions options) {
}

/**
* Invoke the model and return the response.
*
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
* https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient.html#converse
* @return The model invocation response.
*/
* Invoke the model and return the response.
* <p>
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
* https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient.html#converse
*
* @return The model invocation response.
*/
@Override
public ChatResponse call(Prompt prompt) {
Prompt requestPrompt = buildRequestPrompt(prompt);
Expand Down Expand Up @@ -384,7 +394,19 @@ else if (message.getMessageType() == MessageType.TOOL) {
List<SystemContentBlock> systemMessages = prompt.getInstructions()
.stream()
.filter(m -> m.getMessageType() == MessageType.SYSTEM)
.map(sysMessage -> SystemContentBlock.builder().text(sysMessage.getText()).build())
.map(sysMessage -> {
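/*
 * Prompt caching: a system message whose metadata contains the cache-point key
 * is sent as a CachePointBlock instead of as a text block.
 * See https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
 */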
if(sysMessage.getMetadata()!=null&&sysMessage.getMetadata().get(ConverseApiUtils.CACHE_POINT)!=null){
return SystemContentBlock.fromCachePoint(CachePointBlock.builder()
.type(CachePointType.DEFAULT)
.build());
}else{
return SystemContentBlock.builder().text(sysMessage.getText()).build();
}
})
.toList();

ToolCallingChatOptions updatedRuntimeOptions = prompt.getOptions().copy();
Expand Down Expand Up @@ -551,12 +573,13 @@ else if (mediaData instanceof URL url) {
}

/**
* Convert {@link ConverseResponse} to {@link ChatResponse} includes model output,
* stopReason, usage, metrics etc.
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_ResponseSyntax
* @param response The Bedrock Converse response.
* @return The ChatResponse entity.
*/
* Convert {@link ConverseResponse} to {@link ChatResponse} includes model output,
* stopReason, usage, metrics etc.
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_ResponseSyntax
*
* @param response The Bedrock Converse response.
* @return The ChatResponse entity.
*/
private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perviousChatResponse) {

Assert.notNull(response, "'response' must not be null.");
Expand Down Expand Up @@ -630,13 +653,14 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv
}

/**
* Invoke the model and return the response stream.
*
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
* https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream
* @return The model invocation response stream.
*/
* Invoke the model and return the response stream.
* <p>
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
* https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream
*
* @return The model invocation response stream.
*/
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
Prompt requestPrompt = buildRequestPrompt(prompt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,8 @@

package org.springframework.ai.bedrock.converse.api;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDelta;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStart;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStopEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetrics;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput.EventType;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler.Visitor;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamTrace;
import software.amazon.awssdk.services.bedrockruntime.model.MessageStartEvent;
import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;
import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlockStart;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
Expand All @@ -56,17 +28,37 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.services.bedrockruntime.model.*;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput.EventType;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler.Visitor;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

/**
* Amazon Bedrock Converse API utils.
*
* @author Wei Jiang
* @author Christian Tzolov
* @author Alexandros Pappas
* @author Brave Lin
* @since 1.0.0
*/
public final class ConverseApiUtils {

// Metadata key that marks a system message as a Bedrock prompt-caching cache point; see https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
public static final String CACHE_POINT="cachePoint";

public static final ChatResponse EMPTY_CHAT_RESPONSE = ChatResponse.builder()
.generations(List.of())
.metadata("empty", true)
Expand All @@ -76,6 +68,17 @@ private ConverseApiUtils() {

}

/**
 * Builds a marker {@link SystemMessage} that stands for an Amazon Bedrock
 * prompt-caching cache point.
 * <p>
 * The message text is {@link #CACHE_POINT} and its metadata maps
 * {@link #CACHE_POINT} to {@link CachePointType#DEFAULT}.
 * <p>
 * See https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
 * @return a new system message carrying the cache-point metadata entry
 */
public static SystemMessage buildCachePointMesssage() {
	SystemMessage cachePointMarker = new SystemMessage(CACHE_POINT);
	// The metadata entry, not the text, is what identifies this message as a cache point.
	Map<String, Object> markerMetadata = cachePointMarker.getMetadata();
	markerMetadata.put(CACHE_POINT, CachePointType.DEFAULT);
	return cachePointMarker;
}

public static boolean isToolUseStart(ConverseStreamOutput event) {
if (event == null || event.sdkEventType() == null || event.sdkEventType() != EventType.CONTENT_BLOCK_START) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,17 @@

package org.springframework.ai.bedrock.converse;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

import org.springframework.ai.bedrock.converse.api.ConverseApiUtils;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
Expand All @@ -55,6 +48,14 @@
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.util.MimeTypeUtils;
import reactor.core.publisher.Flux;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down Expand Up @@ -347,4 +348,37 @@ record ActorsFilmsRecord(String actor, List<String> movies) {

}

/**
 * Sends a prompt made of a long system message, a cache-point marker message and a
 * user question, then checks that exactly one generation comes back, that its text
 * mentions "Blackbeard" and that the finish reason is "end_turn".
 * <p>
 * See https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
 * @author Brave Lin
 * @param modelName the model id passed to the chat options for this run
 */
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "us.anthropic.claude-3-7-sonnet-20250219-v1:0" })
void cachePointTest(String modelName) {
	String baseInstructions = """
			You are a helpful AI assistant. Your name is spring.
			You are an AI assistant that helps people find information.
			Your name is spring
			You should reply to the user's request with your name and also in the style of a pirate.
			""";

	// Repeat the instructions 50 times so the system prompt is long enough to
	// reach the minimum token count required for caching.
	Message longSystemMessage = new SystemMessage(baseInstructions.repeat(50));

	// Marker message that designates the cache point after the system text.
	Message cachePointMarker = ConverseApiUtils.buildCachePointMesssage();

	UserMessage question = new UserMessage(
			"Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.");

	ToolCallingChatOptions options = ToolCallingChatOptions.builder().model(modelName).build();
	Prompt prompt = new Prompt(List.of(longSystemMessage, cachePointMarker, question), options);

	ChatResponse response = this.chatModel.call(prompt);

	assertThat(response.getResults()).hasSize(1);
	Generation firstGeneration = response.getResults().get(0);
	assertThat(firstGeneration.getOutput().getText()).contains("Blackbeard");
	assertThat(firstGeneration.getMetadata().getFinishReason()).isEqualTo("end_turn");
	logger.info(response.toString());
}

}