Skip to content

Make STSCredentialsProvider prefetch and stale times configurable #1995

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/next-release/feature-AmazonSTS-289d9e7.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"category": "Amazon STS",
"type": "feature",
"description": "Make the STSCredentialsProvider stale and prefetch times configurable so clients can control when session credentials are refreshed"
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.time.Duration;
import java.time.Instant;
import java.util.Optional;
import java.util.function.Function;
import software.amazon.awssdk.annotations.NotThreadSafe;
import software.amazon.awssdk.annotations.SdkInternalApi;
Expand All @@ -31,6 +32,7 @@
import software.amazon.awssdk.utils.cache.NonBlocking;
import software.amazon.awssdk.utils.cache.RefreshResult;


/**
* An implementation of {@link AwsCredentialsProvider} that is extended within this package to provide support for periodically-
* updating session credentials. When credentials get close to expiration, this class will attempt to update them asynchronously
Expand All @@ -40,6 +42,10 @@
@ThreadSafe
@SdkInternalApi
abstract class StsCredentialsProvider implements AwsCredentialsProvider, SdkAutoCloseable {

private static final Duration DEFAULT_STALE_TIME = Duration.ofMinutes(1);
private static final Duration DEFAULT_PREFETCH_TIME = Duration.ofMinutes(5);

/**
* The STS client that should be used for periodically updating the session credentials in the background.
*/
Expand All @@ -50,9 +56,15 @@ abstract class StsCredentialsProvider implements AwsCredentialsProvider, SdkAuto
*/
private final CachedSupplier<SessionCredentialsHolder> sessionCache;

private final Duration staleTime;
private final Duration prefetchTime;

protected StsCredentialsProvider(BaseBuilder<?, ?> builder, String asyncThreadName) {
this.stsClient = Validate.notNull(builder.stsClient, "STS client must not be null.");

this.staleTime = Optional.ofNullable(builder.staleTime).orElse(DEFAULT_STALE_TIME);
this.prefetchTime = Optional.ofNullable(builder.prefetchTime).orElse(DEFAULT_PREFETCH_TIME);

CachedSupplier.Builder<SessionCredentialsHolder> cacheBuilder = CachedSupplier.builder(this::updateSessionCredentials);
if (builder.asyncCredentialUpdateEnabled) {
cacheBuilder.prefetchStrategy(new NonBlocking(asyncThreadName));
Expand All @@ -67,9 +79,10 @@ protected StsCredentialsProvider(BaseBuilder<?, ?> builder, String asyncThreadNa
private RefreshResult<SessionCredentialsHolder> updateSessionCredentials() {
SessionCredentialsHolder credentials = new SessionCredentialsHolder(getUpdatedCredentials(stsClient));
Instant actualTokenExpiration = credentials.getSessionCredentialsExpiration().toInstant();

return RefreshResult.builder(credentials)
.staleTime(actualTokenExpiration.minus(Duration.ofMinutes(1)))
.prefetchTime(actualTokenExpiration.minus(Duration.ofMinutes(5)))
.staleTime(actualTokenExpiration.minus(staleTime))
.prefetchTime(actualTokenExpiration.minus(prefetchTime))
.build();
}

Expand All @@ -83,6 +96,21 @@ public void close() {
sessionCache.close();
}

/**
* The amount of time, relative to STS token expiration, that the cached credentials are considered stale and should no longer be used.
* All threads will block until the value is updated.
*/
public Duration staleTime() {
return staleTime;
}

/**
* The amount of time, relative to STS token expiration, that the cached credentials are considered close to stale and should be updated.
*/
public Duration prefetchTime() {
return prefetchTime;
}

/**
* Implemented by a child class to call STS and get a new set of credentials to be used by this provider.
*/
Expand All @@ -97,6 +125,8 @@ protected abstract static class BaseBuilder<B extends BaseBuilder<B, T>, T> {

private Boolean asyncCredentialUpdateEnabled = false;
private StsClient stsClient;
private Duration staleTime;
private Duration prefetchTime;

protected BaseBuilder(Function<B, T> providerConstructor) {
this.providerConstructor = providerConstructor;
Expand Down Expand Up @@ -127,6 +157,31 @@ public B asyncCredentialUpdateEnabled(Boolean asyncCredentialUpdateEnabled) {
return (B) this;
}

/**
* Configure the amount of time, relative to STS token expiration, that the cached credentials are considered stale and should no longer be used.
* All threads will block until the value is updated.
*
* <p>By default, this is 1 minute.</p>
*/
@SuppressWarnings("unchecked")
public B staleTime(Duration staleTime) {
this.staleTime = staleTime;
return (B) this;
}

/**
* Configure the amount of time, relative to STS token expiration, that the cached credentials are considered close to stale and should be updated.
* See {@link #asyncCredentialUpdateEnabled}.
*
* <p>By default, this is 5 minutes.</p>
*/
@SuppressWarnings("unchecked")
public B prefetchTime(Duration prefetchTime) {
this.prefetchTime = prefetchTime;
return (B) this;
}


/**
* Build the credentials provider using the configuration applied to this builder.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,43 @@ public abstract class StsCredentialsProviderTestBase<RequestT, ResponseT> {

@Test
public void cachingDoesNotApplyToExpiredSession() {
callClientWithCredentialsProvider(Instant.now().minus(Duration.ofSeconds(5)), 2);
callClientWithCredentialsProvider(Instant.now().minus(Duration.ofSeconds(5)), 2, false);
callClient(verify(stsClient, times(2)), Mockito.any());
}

@Test
public void cachingDoesNotApplyToExpiredSession_OverridePrefetchAndStaleTimes() {
callClientWithCredentialsProvider(Instant.now().minus(Duration.ofSeconds(5)), 2, true);
callClient(verify(stsClient, times(2)), Mockito.any());
}

@Test
public void cachingAppliesToNonExpiredSession() {
callClientWithCredentialsProvider(Instant.now().plus(Duration.ofHours(5)), 2);
callClientWithCredentialsProvider(Instant.now().plus(Duration.ofHours(5)), 2, false);
callClient(verify(stsClient, times(1)), Mockito.any());
}

@Test
public void cachingAppliesToNonExpiredSession_OverridePrefetchAndStaleTimes() {
callClientWithCredentialsProvider(Instant.now().plus(Duration.ofHours(5)), 2, true);
callClient(verify(stsClient, times(1)), Mockito.any());
}

@Test
public void distantExpiringCredentialsUpdatedInBackground() throws InterruptedException {
callClientWithCredentialsProvider(Instant.now().plusSeconds(90), 2);
callClientWithCredentialsProvider(Instant.now().plusSeconds(90), 2, false);

Instant endCheckTime = Instant.now().plus(Duration.ofSeconds(5));
while (Mockito.mockingDetails(stsClient).getInvocations().size() < 2 && endCheckTime.isAfter(Instant.now())) {
Thread.sleep(100);
}

callClient(verify(stsClient, times(2)), Mockito.any());
}

@Test
public void distantExpiringCredentialsUpdatedInBackground_OverridePrefetchAndStaleTimes() throws InterruptedException {
callClientWithCredentialsProvider(Instant.now().plusSeconds(90), 2, true);

Instant endCheckTime = Instant.now().plus(Duration.ofSeconds(5));
while (Mockito.mockingDetails(stsClient).getInvocations().size() < 2 && endCheckTime.isAfter(Instant.now())) {
Expand All @@ -72,14 +96,32 @@ public void distantExpiringCredentialsUpdatedInBackground() throws InterruptedEx

protected abstract ResponseT callClient(StsClient client, RequestT request);

public void callClientWithCredentialsProvider(Instant credentialsExpirationDate, int numTimesInvokeCredentialsProvider) {
public void callClientWithCredentialsProvider(Instant credentialsExpirationDate, int numTimesInvokeCredentialsProvider, boolean overrideStaleAndPrefetchTimes) {
Credentials credentials = Credentials.builder().accessKeyId("a").secretAccessKey("b").sessionToken("c").expiration(credentialsExpirationDate).build();
RequestT request = getRequest();
ResponseT response = getResponse(credentials);

when(callClient(stsClient, request)).thenReturn(response);

try (StsCredentialsProvider credentialsProvider = createCredentialsProviderBuilder(request).stsClient(stsClient).build()) {
StsCredentialsProvider.BaseBuilder<?, ? extends StsCredentialsProvider> credentialsProviderBuilder = createCredentialsProviderBuilder(request);

if(overrideStaleAndPrefetchTimes) {
//do the same values as we would do without overriding the stale and prefetch times
credentialsProviderBuilder.staleTime(Duration.ofMinutes(2));
credentialsProviderBuilder.prefetchTime(Duration.ofMinutes(4));
}

try (StsCredentialsProvider credentialsProvider = credentialsProviderBuilder.stsClient(stsClient).build()) {
if(overrideStaleAndPrefetchTimes) {
//validate that we actually stored the override values in the build provider
assertThat(credentialsProvider.staleTime()).as("stale time").isEqualTo(Duration.ofMinutes(2));
assertThat(credentialsProvider.prefetchTime()).as("prefetch time").isEqualTo(Duration.ofMinutes(4));
} else {
//validate that the default values are used
assertThat(credentialsProvider.staleTime()).as("stale time").isEqualTo(Duration.ofMinutes(1));
assertThat(credentialsProvider.prefetchTime()).as("prefetch time").isEqualTo(Duration.ofMinutes(5));
}

for (int i = 0; i < numTimesInvokeCredentialsProvider; ++i) {
AwsSessionCredentials providedCredentials = (AwsSessionCredentials) credentialsProvider.resolveCredentials();
assertThat(providedCredentials.accessKeyId()).isEqualTo("a");
Expand Down