Skip to content

Commit bcde900

Browse files
committed
Cache Azure credential obtained from environment
JAVA-4706
1 parent 1ef1b5e commit bcde900

File tree

3 files changed

+172
-18
lines changed

3 files changed

+172
-18
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mongodb.internal;
18+
19+
import com.mongodb.annotations.ThreadSafe;
20+
21+
import java.time.Duration;
22+
import java.util.Optional;
23+
24+
import static com.mongodb.assertions.Assertions.assertNotNull;
25+
import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE;
26+
27+
/**
28+
* A value associated with a lifetime.
29+
*
30+
* <p>Instances are shallowly immutable.</p>
31+
*/
32+
@ThreadSafe
33+
public final class ExpirableValue<T> {
34+
private final T value;
35+
private final long deadline;
36+
37+
public static <T> ExpirableValue<T> expired() {
38+
return new ExpirableValue<>(null, Duration.ofSeconds(-1), System.nanoTime());
39+
}
40+
41+
public static <T> ExpirableValue<T> unexpired(final T value, final Duration lifetime) {
42+
return unexpired(value, lifetime, System.nanoTime());
43+
}
44+
45+
@VisibleForTesting(otherwise = PRIVATE)
46+
static <T> ExpirableValue<T> unexpired(final T value, final Duration lifetime, final long currentNanoTime) {
47+
return new ExpirableValue<>(assertNotNull(value), lifetime, currentNanoTime);
48+
}
49+
50+
private ExpirableValue(final T value, final Duration lifetime, final long currentNanoTime) {
51+
this.value = value;
52+
deadline = currentNanoTime + lifetime.toNanos();
53+
}
54+
55+
/**
56+
* Returns {@link Optional#empty()} if the value is expired. Otherwise, returns an {@link Optional} describing the value.
57+
*/
58+
public Optional<T> getValue() {
59+
return getValue(System.nanoTime());
60+
}
61+
62+
@VisibleForTesting(otherwise = PRIVATE)
63+
Optional<T> getValue(final long currentNanoTime) {
64+
if (currentNanoTime - deadline > 0) {
65+
return Optional.empty();
66+
} else {
67+
return Optional.of(value);
68+
}
69+
}
70+
}

driver-core/src/main/com/mongodb/internal/authentication/AzureCredentialHelper.java

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@
1717
package com.mongodb.internal.authentication;
1818

1919
import com.mongodb.MongoClientException;
20+
import com.mongodb.internal.ExpirableValue;
2021
import org.bson.BsonDocument;
22+
import org.bson.BsonString;
2123
import org.bson.json.JsonParseException;
2224

25+
import java.time.Duration;
2326
import java.util.HashMap;
2427
import java.util.Map;
28+
import java.util.Optional;
2529

2630
import static com.mongodb.internal.authentication.HttpHelper.getHttpContents;
2731

@@ -31,25 +35,44 @@
3135
* <p>This class should not be considered a part of the public API.</p>
3236
*/
3337
public final class AzureCredentialHelper {
34-
public static BsonDocument obtainFromEnvironment() {
35-
String endpoint = "http://" + "169.254.169.254:80"
36-
+ "/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https://vault.azure.net";
37-
38-
Map<String, String> headers = new HashMap<>();
39-
headers.put("Metadata", "true");
40-
headers.put("Accept", "application/json");
41-
42-
String response = getHttpContents("GET", endpoint, headers);
43-
try {
44-
BsonDocument responseDocument = BsonDocument.parse(response);
45-
if (responseDocument.containsKey("access_token")) {
46-
return new BsonDocument("accessToken", responseDocument.get("access_token"));
47-
} else {
48-
throw new MongoClientException("The access_token is missing from Azure IMDS metadata response.");
38+
private static final String ACCESS_TOKEN_FIELD = "access_token";
39+
private static final String EXPIRES_IN_FIELD = "expires_in";
40+
41+
private static ExpirableValue<String> cachedAccessToken = ExpirableValue.expired();
42+
43+
public static synchronized BsonDocument obtainFromEnvironment() {
44+
String accessToken;
45+
Optional<String> cachedValue = cachedAccessToken.getValue();
46+
if (cachedValue.isPresent()) {
47+
accessToken = cachedValue.get();
48+
} else {
49+
String endpoint = "http://" + "169.254.169.254:80"
50+
+ "/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https://vault.azure.net";
51+
52+
Map<String, String> headers = new HashMap<>();
53+
headers.put("Metadata", "true");
54+
headers.put("Accept", "application/json");
55+
56+
BsonDocument responseDocument;
57+
try {
58+
responseDocument = BsonDocument.parse(getHttpContents("GET", endpoint, headers));
59+
} catch (JsonParseException e) {
60+
throw new MongoClientException("Exception parsing JSON from Azure IMDS metadata response.", e);
61+
}
62+
63+
if (!responseDocument.isString(ACCESS_TOKEN_FIELD)) {
64+
throw new MongoClientException(String.format(
65+
"The %s field from Azure IMDS metadata response is missing or is not a string", ACCESS_TOKEN_FIELD));
66+
}
67+
if (!responseDocument.isString(EXPIRES_IN_FIELD)) {
68+
throw new MongoClientException(String.format(
69+
"The %s field from Azure IMDS metadata response is missing or is not a string", EXPIRES_IN_FIELD));
4970
}
50-
} catch (JsonParseException e) {
51-
throw new MongoClientException("Exception parsing JSON from Azure IMDS metadata response.", e);
52-
}
71+
accessToken = responseDocument.getString(ACCESS_TOKEN_FIELD).getValue();
72+
int expiresInSeconds = Integer.parseInt(responseDocument.getString(EXPIRES_IN_FIELD).getValue());
73+
cachedAccessToken = ExpirableValue.unexpired(accessToken, Duration.ofSeconds(expiresInSeconds).minus(Duration.ofMinutes(1)));
74+
}
75+
return new BsonDocument("accessToken", new BsonString(accessToken));
5376
}
5477

5578
private AzureCredentialHelper() {
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mongodb.internal;
18+
19+
import org.junit.jupiter.api.Test;
20+
21+
import java.time.Duration;
22+
23+
import static com.mongodb.internal.ExpirableValue.expired;
24+
import static com.mongodb.internal.ExpirableValue.unexpired;
25+
import static org.junit.jupiter.api.Assertions.assertAll;
26+
import static org.junit.jupiter.api.Assertions.assertEquals;
27+
import static org.junit.jupiter.api.Assertions.assertFalse;
28+
29+
class ExpirableValueTest {
30+
31+
@Test
32+
void testExpired() {
33+
assertFalse(expired().getValue().isPresent());
34+
}
35+
36+
@SuppressWarnings("OptionalGetWithoutIsPresent")
37+
@Test
38+
void testUnexpired() {
39+
assertAll(
40+
() -> assertFalse(unexpired(1, Duration.ZERO).getValue().isPresent()),
41+
() -> assertEquals(1, unexpired(1, Duration.ofSeconds(1)).getValue().get()),
42+
() -> {
43+
ExpirableValue<Integer> expirableValue = unexpired(1, Duration.ofNanos(1));
44+
Thread.sleep(1);
45+
assertFalse(expirableValue.getValue().isPresent());
46+
},
47+
() -> {
48+
ExpirableValue<Integer> expirableValue = unexpired(1, Duration.ofMinutes(60), Long.MAX_VALUE);
49+
assertEquals(1, expirableValue.getValue(Long.MAX_VALUE + Duration.ofMinutes(30).toNanos()).get());
50+
},
51+
() -> {
52+
ExpirableValue<Integer> expirableValue = unexpired(1, Duration.ofMinutes(60), Long.MAX_VALUE);
53+
assertEquals(1, expirableValue.getValue(Long.MAX_VALUE + Duration.ofMinutes(30).toNanos()).get());
54+
assertFalse(expirableValue.getValue(Long.MAX_VALUE + Duration.ofMinutes(61).toNanos()).isPresent());
55+
},
56+
() -> {
57+
ExpirableValue<Integer> expirableValue = unexpired(1, Duration.ofNanos(10), Long.MAX_VALUE - 20);
58+
assertFalse(expirableValue.getValue(Long.MAX_VALUE - 20 + Duration.ofNanos(30).toNanos()).isPresent());
59+
});
60+
}
61+
}

0 commit comments

Comments
 (0)