Skip to content

Bedrock API: Add option to configure API endpoint (#1018) #1063

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.springframework.ai.bedrock.anthropic.api;

import java.net.URI;
import java.time.Duration;
import java.util.List;

Expand Down Expand Up @@ -109,6 +110,21 @@ public AnthropicChatBedrockApi(String modelId, AwsCredentialsProvider credential
super(modelId, credentialsProvider, region, objectMapper, timeout);
}

/**
* Create a new AnthropicChatBedrockApi instance using the provided credentials provider, region, object mapper, timeout and endpoint override.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can also specify here that we require a new parameter, the endpoint override.

*
* @param modelId The model id to use. See the {@link AnthropicChatModel} for the supported models.
* @param credentialsProvider The credentials provider to connect to AWS.
* @param region The AWS region to use.
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
* @param timeout The timeout to use.
* @param endpointOverride The endpoint override to use.
*/
public AnthropicChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
ObjectMapper objectMapper, Duration timeout, URI endpointOverride) {
// Hand every argument, including the endpoint override, to the superclass constructor unchanged.
super(modelId, credentialsProvider, region, objectMapper, timeout, endpointOverride);
}

// https://github.com/build-on-aws/amazon-bedrock-java-examples/blob/main/example_code/bedrock-runtime/src/main/java/aws/community/examples/InvokeBedrockStreamingAsync.java

// https://docs.anthropic.com/claude/reference/complete_post
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;

import java.net.URI;
import java.time.Duration;
import java.util.List;

Expand Down Expand Up @@ -113,6 +114,21 @@ public Anthropic3ChatBedrockApi(String modelId, AwsCredentialsProvider credentia
super(modelId, credentialsProvider, region, objectMapper, timeout);
}

/**
* Create a new Anthropic3ChatBedrockApi instance using the provided credentials provider, region, object mapper, timeout and endpoint override.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

*
* @param modelId The model id to use. See the {@link AnthropicChatModel} for the supported models.
* @param credentialsProvider The credentials provider to connect to AWS.
* @param region The AWS region to use.
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
* @param timeout The timeout to use.
* @param endpointOverride The endpoint override to use.
*/
public Anthropic3ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
ObjectMapper objectMapper, Duration timeout, URI endpointOverride) {
// Hand every argument, including the endpoint override, to the superclass constructor unchanged.
super(modelId, credentialsProvider, region, objectMapper, timeout, endpointOverride);
}

// https://github.com/build-on-aws/amazon-bedrock-java-examples/blob/main/example_code/bedrock-runtime/src/main/java/aws/community/examples/InvokeBedrockStreamingAsync.java

// https://docs.anthropic.com/claude/reference/complete_post
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.bedrock.api;

import java.io.UncheckedIOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.time.Duration;

Expand Down Expand Up @@ -107,6 +108,11 @@ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProv
this(modelId, credentialsProvider, region, objectMapper, Duration.ofMinutes(5));
}

/**
* Create a new AbstractBedrockApi instance without an endpoint override.
*
* Delegates to the six-argument constructor, passing {@code null} as the endpoint override.
*
* @param modelId The model id to use.
* @param credentialsProvider The credentials provider to connect to AWS.
* @param region The AWS region to use.
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
* @param timeout The timeout to use.
*/
public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
ObjectMapper objectMapper, Duration timeout) {
// NOTE(review): null is forwarded to the client builders' endpointOverride(...) call;
// presumably the SDK treats null as "use the default endpoint" - confirm for the SDK version in use.
this(modelId, credentialsProvider, region, objectMapper, timeout, null);
}

/**
* Create a new AbstractBedrockApi instance using the provided credentials provider, region and object mapper.
*
Expand All @@ -120,7 +126,7 @@ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProv
*/
public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region,
ObjectMapper objectMapper, Duration timeout) {
this(modelId, credentialsProvider, Region.of(region), objectMapper, timeout);
this(modelId, credentialsProvider, Region.of(region), objectMapper, timeout, null);
}

/**
Expand All @@ -135,7 +141,7 @@ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProv
* all HTTP requests including retries, unmarshalling, etc. This value should always be positive, if present.
*/
public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
ObjectMapper objectMapper, Duration timeout) {
ObjectMapper objectMapper, Duration timeout, URI endpointOverride) {

Assert.hasText(modelId, "Model id must not be empty");
Assert.notNull(credentialsProvider, "Credentials provider must not be null");
Expand All @@ -152,12 +158,14 @@ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProv
.region(this.region)
.credentialsProvider(credentialsProvider)
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
.endpointOverride(endpointOverride)
.build();

this.clientStreaming = BedrockRuntimeAsyncClient.builder()
.region(this.region)
.credentialsProvider(credentialsProvider)
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
.endpointOverride(endpointOverride)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// @formatter:off
package org.springframework.ai.bedrock.cohere.api;

import java.net.URI;
import java.time.Duration;
import java.util.List;

Expand Down Expand Up @@ -108,6 +109,21 @@ public CohereChatBedrockApi(String modelId, AwsCredentialsProvider credentialsPr
super(modelId, credentialsProvider, region, objectMapper, timeout);
}

/**
* Create a new CohereChatBedrockApi instance using the provided credentials provider, region, object mapper, timeout and endpoint override.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

*
* @param modelId The model id to use. See the {@link CohereChatModel} for the supported models.
* @param credentialsProvider The credentials provider to connect to AWS.
* @param region The AWS region to use.
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
* @param timeout The timeout to use.
* @param endpointOverride The endpoint override to use.
*/
public CohereChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
ObjectMapper objectMapper, Duration timeout, URI endpointOverride) {
// Hand every argument, including the endpoint override, to the superclass constructor unchanged.
super(modelId, credentialsProvider, region, objectMapper, timeout, endpointOverride);
}

/**
* CohereChatRequest encapsulates the request parameters for the Cohere command model.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// @formatter:off
package org.springframework.ai.bedrock.cohere.api;

import java.net.URI;
import java.time.Duration;
import java.util.List;

Expand Down Expand Up @@ -109,6 +110,22 @@ public CohereEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credenti
super(modelId, credentialsProvider, region, objectMapper, timeout);
}

/**
* Create a new CohereEmbeddingBedrockApi instance using the provided credentials provider, region and object
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

* mapper.
*
* @param modelId The model id to use. See the {@link CohereEmbeddingModel} for the supported models.
* @param credentialsProvider The credentials provider to connect to AWS.
* @param region The AWS region to use.
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
* @param timeout The timeout to use.
* @param endpointOverride The endpoint override to use.
*/
public CohereEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
ObjectMapper objectMapper, Duration timeout, URI endpointOverride) {
// Hand every argument, including the endpoint override, to the superclass constructor unchanged.
super(modelId, credentialsProvider, region, objectMapper, timeout, endpointOverride);
}

/**
* The Cohere Embed model request.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// @formatter:off
package org.springframework.ai.bedrock.jurassic2.api;

import java.net.URI;
import java.time.Duration;
import java.util.List;

Expand Down Expand Up @@ -109,6 +110,21 @@ public Ai21Jurassic2ChatBedrockApi(String modelId, AwsCredentialsProvider creden
super(modelId, credentialsProvider, region, objectMapper, timeout);
}

/**
* Create a new Ai21Jurassic2ChatBedrockApi instance.
*
* @param modelId The model id to use. See the {@link Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatModel} for the supported models.
* @param credentialsProvider The credentials provider to connect to AWS.
* @param region The AWS region to use.
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
* @param timeout The timeout to use.
* @param endpointOverride The endpoint override to use.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"The endpoint override to use."

*/
public Ai21Jurassic2ChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
ObjectMapper objectMapper, Duration timeout, URI endpointOverride) {
// Hand every argument, including the endpoint override, to the superclass constructor unchanged.
super(modelId, credentialsProvider, region, objectMapper, timeout, endpointOverride);
}

/**
* AI21 Jurassic2 chat request parameters.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse;
import org.springframework.ai.model.ChatModelDescription;

import java.net.URI;
import java.time.Duration;

// @formatter:off
Expand Down Expand Up @@ -106,6 +107,21 @@ public LlamaChatBedrockApi(String modelId, AwsCredentialsProvider credentialsPro
super(modelId, credentialsProvider, region, objectMapper, timeout);
}

/**
* Create a new LlamaChatBedrockApi instance using the provided credentials provider, region, object mapper, timeout and endpoint override.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same small changes here and also all places similarly set below.

*
* @param modelId The model id to use. See the {@link LlamaChatModel} for the supported models.
* @param credentialsProvider The credentials provider to connect to AWS.
* @param region The AWS region to use.
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
* @param timeout The timeout to use.
* @param endpointOverride The endpoint override to use.
*/
public LlamaChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
ObjectMapper objectMapper, Duration timeout, URI endpointOverride) {
// Hand every argument, including the endpoint override, to the superclass constructor unchanged.
super(modelId, credentialsProvider, region, objectMapper, timeout, endpointOverride);
}

/**
* LlamaChatRequest encapsulates the request parameters for the Meta Llama chat model.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.springframework.ai.bedrock.titan.api;

import java.net.URI;
import java.time.Duration;
import java.util.List;

Expand Down Expand Up @@ -109,6 +110,21 @@ public TitanChatBedrockApi(String modelId, AwsCredentialsProvider credentialsPro
super(modelId, credentialsProvider, region, objectMapper, timeout);
}

/**
* Create a new TitanChatBedrockApi instance using the provided credentials provider, region, object mapper,
* timeout and endpoint override.
*
* @param modelId The model id to use. See the {@link TitanChatModel} for the supported models.
* @param credentialsProvider The credentials provider to connect to AWS.
* @param region The AWS region to use.
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
* @param timeout The timeout to use.
* @param endpointOverride The endpoint override to use.
*/
public TitanChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
ObjectMapper objectMapper, Duration timeout, URI endpointOverride) {
// Hand every argument, including the endpoint override, to the superclass constructor unchanged.
super(modelId, credentialsProvider, region, objectMapper, timeout, endpointOverride);
}

/**
* TitanChatRequest encapsulates the request parameters for the Titan chat model.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.springframework.ai.bedrock.titan.api;

import java.net.URI;
import java.time.Duration;
import java.util.List;

Expand Down Expand Up @@ -81,6 +82,21 @@ public TitanEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credentia
super(modelId, credentialsProvider, region, objectMapper, timeout);
}

/**
* Create a new TitanEmbeddingBedrockApi instance using the provided credentials provider, region, object mapper,
* timeout and endpoint override.
*
* @param modelId The model id to use. See the {@link TitanEmbeddingModel} for the supported models.
* @param credentialsProvider The credentials provider to connect to AWS.
* @param region The AWS region to use.
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
* @param timeout The timeout to use.
* @param endpointOverride The endpoint override to use.
*/
public TitanEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
ObjectMapper objectMapper, Duration timeout, URI endpointOverride) {
// Hand every argument, including the endpoint override, to the superclass constructor unchanged.
super(modelId, credentialsProvider, region, objectMapper, timeout, endpointOverride);
}

/**
* Titan Embedding request parameters.
*
Expand Down