Skip to content

RestTemplateBuilder.basicAuth causes the entire body to be read into memory #17010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
Expand All @@ -41,6 +38,8 @@

import org.springframework.beans.BeanInstantiationException;
import org.springframework.beans.BeanUtils;
import org.springframework.boot.web.client.BasicAuthentication;
import org.springframework.boot.web.client.BasicAuthenticationClientHttpRequestFactory;
import org.springframework.boot.web.client.ClientHttpRequestFactorySupplier;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.boot.web.client.RootUriTemplateHandler;
Expand All @@ -50,12 +49,11 @@
import org.springframework.http.HttpMethod;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.InterceptingClientHttpRequestFactory;
import org.springframework.http.client.support.BasicAuthenticationInterceptor;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.client.DefaultResponseErrorHandler;
Expand Down Expand Up @@ -86,6 +84,7 @@
* @author Phillip Webb
* @author Andy Wilkinson
* @author Kristine Jetzke
* @author Dmytro Nosan
* @since 1.4.0
*/
public class TestRestTemplate {
Expand Down Expand Up @@ -154,31 +153,35 @@ private TestRestTemplate(RestTemplate restTemplate, String username, String pass

private Class<? extends ClientHttpRequestFactory> getRequestFactoryClass(
RestTemplate restTemplate) {
return getRequestFactory(restTemplate).getClass();
}

private ClientHttpRequestFactory getRequestFactory(RestTemplate restTemplate) {
ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory();
if (InterceptingClientHttpRequestFactory.class
.isAssignableFrom(requestFactory.getClass())) {
Field requestFactoryField = ReflectionUtils.findField(RestTemplate.class,
"requestFactory");
ReflectionUtils.makeAccessible(requestFactoryField);
requestFactory = (ClientHttpRequestFactory) ReflectionUtils
.getField(requestFactoryField, restTemplate);
while (requestFactory instanceof InterceptingClientHttpRequestFactory
|| requestFactory instanceof BasicAuthenticationClientHttpRequestFactory) {
requestFactory = unwrapRequestFactory(
((AbstractClientHttpRequestFactoryWrapper) requestFactory));
}
return requestFactory.getClass();
return requestFactory;
}

private ClientHttpRequestFactory unwrapRequestFactory(
AbstractClientHttpRequestFactoryWrapper requestFactory) {
Field field = ReflectionUtils.findField(
AbstractClientHttpRequestFactoryWrapper.class, "requestFactory");
ReflectionUtils.makeAccessible(field);
return (ClientHttpRequestFactory) ReflectionUtils.getField(field, requestFactory);
}

private void addAuthentication(RestTemplate restTemplate, String username,
String password) {
if (username == null) {
if (username == null || password == null) {
return;
}
List<ClientHttpRequestInterceptor> interceptors = restTemplate.getInterceptors();
if (interceptors == null) {
interceptors = Collections.emptyList();
}
interceptors = new ArrayList<>(interceptors);
interceptors.removeIf(BasicAuthenticationInterceptor.class::isInstance);
interceptors.add(new BasicAuthenticationInterceptor(username, password));
restTemplate.setInterceptors(interceptors);
ClientHttpRequestFactory requestFactory = getRequestFactory(restTemplate);
restTemplate.setRequestFactory(new BasicAuthenticationClientHttpRequestFactory(
new BasicAuthentication(username, password), requestFactory));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.net.URI;
import java.util.List;

import org.apache.http.client.config.RequestConfig;
import org.junit.jupiter.api.Test;

import org.springframework.boot.test.web.client.TestRestTemplate.CustomHttpComponentsClientHttpRequestFactory;
import org.springframework.boot.test.web.client.TestRestTemplate.HttpClientOption;
import org.springframework.boot.web.client.BasicAuthenticationClientHttpRequestFactory;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
Expand All @@ -35,12 +35,9 @@
import org.springframework.http.RequestEntity;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.InterceptingClientHttpRequestFactory;
import org.springframework.http.client.OkHttp3ClientHttpRequestFactory;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.http.client.support.BasicAuthenticationInterceptor;
import org.springframework.mock.env.MockEnvironment;
import org.springframework.mock.http.client.MockClientHttpRequest;
import org.springframework.mock.http.client.MockClientHttpResponse;
Expand Down Expand Up @@ -150,7 +147,7 @@ public void getRootUriRootUriNotSet() {
public void authenticated() {
assertThat(new TestRestTemplate("user", "password").getRestTemplate()
.getRequestFactory())
.isInstanceOf(InterceptingClientHttpRequestFactory.class);
.isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class);
}

@Test
Expand Down Expand Up @@ -227,43 +224,42 @@ private Object mockArgument(Class<?> type) throws Exception {
}

@Test
public void withBasicAuthAddsBasicAuthInterceptorWhenNotAlreadyPresent() {
public void withBasicAuthAddsBasicAuthClientFactoryWhenNotAlreadyPresent() {
TestRestTemplate originalTemplate = new TestRestTemplate();
TestRestTemplate basicAuthTemplate = originalTemplate.withBasicAuth("user",
"password");
assertThat(basicAuthTemplate.getRestTemplate().getMessageConverters())
.containsExactlyElementsOf(
originalTemplate.getRestTemplate().getMessageConverters());
assertThat(basicAuthTemplate.getRestTemplate().getRequestFactory())
.isInstanceOf(InterceptingClientHttpRequestFactory.class);
.isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class);
assertThat(ReflectionTestUtils.getField(
basicAuthTemplate.getRestTemplate().getRequestFactory(),
"requestFactory"))
.isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class);
assertThat(basicAuthTemplate.getRestTemplate().getUriTemplateHandler())
.isSameAs(originalTemplate.getRestTemplate().getUriTemplateHandler());
assertThat(basicAuthTemplate.getRestTemplate().getInterceptors()).hasSize(1);
assertBasicAuthorizationInterceptorCredentials(basicAuthTemplate, "user",
"password");
assertThat(basicAuthTemplate.getRestTemplate().getInterceptors()).isEmpty();
assertBasicAuthorizationCredentials(basicAuthTemplate, "user", "password");
}

@Test
public void withBasicAuthReplacesBasicAuthInterceptorWhenAlreadyPresent() {
public void withBasicAuthReplacesBasicAuthClientFactoryWhenAlreadyPresent() {
TestRestTemplate original = new TestRestTemplate("foo", "bar")
.withBasicAuth("replace", "replace");
TestRestTemplate basicAuth = original.withBasicAuth("user", "password");
assertThat(basicAuth.getRestTemplate().getMessageConverters())
.containsExactlyElementsOf(
original.getRestTemplate().getMessageConverters());
assertThat(basicAuth.getRestTemplate().getRequestFactory())
.isInstanceOf(InterceptingClientHttpRequestFactory.class);
.isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class);
assertThat(ReflectionTestUtils.getField(
basicAuth.getRestTemplate().getRequestFactory(), "requestFactory"))
.isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class);
assertThat(basicAuth.getRestTemplate().getUriTemplateHandler())
.isSameAs(original.getRestTemplate().getUriTemplateHandler());
assertThat(basicAuth.getRestTemplate().getInterceptors()).hasSize(1);
assertBasicAuthorizationInterceptorCredentials(basicAuth, "user", "password");
assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty();
assertBasicAuthorizationCredentials(basicAuth, "user", "password");
}

@Test
Expand Down Expand Up @@ -394,17 +390,14 @@ private void verifyRelativeUriHandling(TestRestTemplateCallback callback)
verify(requestFactory).createRequest(eq(absoluteUri), any(HttpMethod.class));
}

private void assertBasicAuthorizationInterceptorCredentials(
TestRestTemplate testRestTemplate, String username, String password) {
@SuppressWarnings("unchecked")
List<ClientHttpRequestInterceptor> requestFactoryInterceptors = (List<ClientHttpRequestInterceptor>) ReflectionTestUtils
.getField(testRestTemplate.getRestTemplate().getRequestFactory(),
"interceptors");
assertThat(requestFactoryInterceptors).hasSize(1);
ClientHttpRequestInterceptor interceptor = requestFactoryInterceptors.get(0);
assertThat(interceptor).isInstanceOf(BasicAuthenticationInterceptor.class);
assertThat(interceptor).hasFieldOrPropertyWithValue("username", username);
assertThat(interceptor).hasFieldOrPropertyWithValue("password", password);
private void assertBasicAuthorizationCredentials(TestRestTemplate testRestTemplate,
String username, String password) {
ClientHttpRequestFactory requestFactory = testRestTemplate.getRestTemplate()
.getRequestFactory();
Object authentication = ReflectionTestUtils.getField(requestFactory,
"authentication");
assertThat(authentication).hasFieldOrPropertyWithValue("username", username);
assertThat(authentication).hasFieldOrPropertyWithValue("password", password);

}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright 2012-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.boot.web.client;

import java.nio.charset.Charset;

import org.springframework.util.Assert;

/**
* Basic authentication properties which are used by
* {@link BasicAuthenticationClientHttpRequestFactory}.
*
* @author Dmytro Nosan
* @since 2.2.0
* @see BasicAuthenticationClientHttpRequestFactory
*/
public class BasicAuthentication {

private final String username;

private final String password;

private final Charset charset;

/**
* Create a new {@link BasicAuthentication}.
* @param username the username to use
* @param password the password to use
*/
public BasicAuthentication(String username, String password) {
this(username, password, null);
}

/**
* Create a new {@link BasicAuthentication}.
* @param username the username to use
* @param password the password to use
* @param charset the charset to use
*/
public BasicAuthentication(String username, String password, Charset charset) {
Assert.notNull(username, "Username must not be null");
Assert.notNull(password, "Password must not be null");
this.username = username;
this.password = password;
this.charset = charset;
}

/**
* The username to use.
* @return the username, never {@code null} or {@code empty}.
*/
public String getUsername() {
return this.username;
}

/**
* The password to use.
* @return the password, never {@code null} or {@code empty}.
*/
public String getPassword() {
return this.password;
}

/**
* The charset to use.
* @return the charset, or {@code null}.
*/
public Charset getCharset() {
return this.charset;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright 2012-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.boot.web.client;

import java.io.IOException;
import java.net.URI;

import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.util.Assert;

/**
* {@link ClientHttpRequestFactory} to apply a given HTTP Basic Authentication
* username/password pair, unless a custom Authorization header has been set before.
*
* @author Dmytro Nosan
* @since 2.2.0
*/
public class BasicAuthenticationClientHttpRequestFactory
extends AbstractClientHttpRequestFactoryWrapper {

private final BasicAuthentication authentication;

/**
* Create a new {@link BasicAuthenticationClientHttpRequestFactory} which adds
* {@link HttpHeaders#AUTHORIZATION} header for the given authentication.
* @param authentication the authentication to use
* @param clientHttpRequestFactory the factory to use
*/
public BasicAuthenticationClientHttpRequestFactory(BasicAuthentication authentication,
ClientHttpRequestFactory clientHttpRequestFactory) {
super(clientHttpRequestFactory);
Assert.notNull(authentication, "Authentication must not be null");
this.authentication = authentication;
}

@Override
protected ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod,
ClientHttpRequestFactory requestFactory) throws IOException {
BasicAuthentication authentication = this.authentication;
ClientHttpRequest request = requestFactory.createRequest(uri, httpMethod);
HttpHeaders headers = request.getHeaders();
if (!headers.containsKey(HttpHeaders.AUTHORIZATION)) {
headers.setBasicAuth(authentication.getUsername(),
authentication.getPassword(), authentication.getCharset());
}
return request;
}

}
Loading