Skip to content

Commit 3a3e086

Browse files
committed
Merge pull request #771 from hyperium/allow-proxy
feat(client): add Proxy support
2 parents 4828437 + 25010fc commit 3a3e086

File tree

4 files changed

+115
-16
lines changed

4 files changed

+115
-16
lines changed

src/client/mod.rs

+56-8
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@
5555
//! clone2.post("http://example.domain/post").body("foo=bar").send().unwrap();
5656
//! });
5757
//! ```
58+
use std::borrow::Cow;
5859
use std::default::Default;
5960
use std::io::{self, copy, Read};
60-
use std::iter::Extend;
6161
use std::fmt;
6262

6363
use std::time::Duration;
@@ -66,7 +66,7 @@ use url::Url;
6666
use url::ParseError as UrlError;
6767

6868
use header::{Headers, Header, HeaderFormat};
69-
use header::{ContentLength, Location};
69+
use header::{ContentLength, Host, Location};
7070
use method::Method;
7171
use net::{NetworkConnector, NetworkStream};
7272
use Error;
@@ -90,6 +90,7 @@ pub struct Client {
9090
redirect_policy: RedirectPolicy,
9191
read_timeout: Option<Duration>,
9292
write_timeout: Option<Duration>,
93+
proxy: Option<(Cow<'static, str>, Cow<'static, str>, u16)>
9394
}
9495

9596
impl fmt::Debug for Client {
@@ -98,6 +99,7 @@ impl fmt::Debug for Client {
9899
.field("redirect_policy", &self.redirect_policy)
99100
.field("read_timeout", &self.read_timeout)
100101
.field("write_timeout", &self.write_timeout)
102+
.field("proxy", &self.proxy)
101103
.finish()
102104
}
103105
}
@@ -127,6 +129,7 @@ impl Client {
127129
redirect_policy: Default::default(),
128130
read_timeout: None,
129131
write_timeout: None,
132+
proxy: None,
130133
}
131134
}
132135

@@ -145,6 +148,12 @@ impl Client {
145148
self.write_timeout = dur;
146149
}
147150

151+
/// Set a proxy for requests of this Client.
152+
pub fn set_proxy<S, H>(&mut self, scheme: S, host: H, port: u16)
153+
where S: Into<Cow<'static, str>>, H: Into<Cow<'static, str>> {
154+
self.proxy = Some((scheme.into(), host.into(), port));
155+
}
156+
148157
/// Build a Get request.
149158
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder {
150159
self.request(Method::Get, url)
@@ -247,7 +256,7 @@ impl<'a> RequestBuilder<'a> {
247256
pub fn send(self) -> ::Result<Response> {
248257
let RequestBuilder { client, method, url, headers, body } = self;
249258
let mut url = try!(url);
250-
trace!("send {:?} {:?}", method, url);
259+
trace!("send method={:?}, url={:?}, client={:?}", method, url, client);
251260

252261
let can_have_body = match method {
253262
Method::Get | Method::Head => false,
@@ -261,12 +270,25 @@ impl<'a> RequestBuilder<'a> {
261270
};
262271

263272
loop {
264-
let message = {
265-
let (host, port) = try!(get_host_and_port(&url));
266-
try!(client.protocol.new_message(&host, port, url.scheme()))
273+
let mut req = {
274+
let (scheme, host, port) = match client.proxy {
275+
Some(ref proxy) => (proxy.0.as_ref(), proxy.1.as_ref(), proxy.2),
276+
None => {
277+
let hp = try!(get_host_and_port(&url));
278+
(url.scheme(), hp.0, hp.1)
279+
}
280+
};
281+
let mut headers = match headers {
282+
Some(ref headers) => headers.clone(),
283+
None => Headers::new(),
284+
};
285+
headers.set(Host {
286+
hostname: host.to_owned(),
287+
port: Some(port),
288+
});
289+
let message = try!(client.protocol.new_message(&host, port, scheme));
290+
Request::with_headers_and_message(method.clone(), url.clone(), headers, message)
267291
};
268-
let mut req = try!(Request::with_message(method.clone(), url.clone(), message));
269-
headers.as_ref().map(|headers| req.headers_mut().extend(headers.iter()));
270292

271293
try!(req.set_write_timeout(client.write_timeout));
272294
try!(req.set_read_timeout(client.read_timeout));
@@ -456,6 +478,8 @@ fn get_host_and_port(url: &Url) -> ::Result<(&str, u16)> {
456478
mod tests {
457479
use std::io::Read;
458480
use header::Server;
481+
use http::h1::Http11Message;
482+
use mock::{MockStream};
459483
use super::{Client, RedirectPolicy};
460484
use super::pool::Pool;
461485
use url::Url;
@@ -477,6 +501,30 @@ mod tests {
477501
"
478502
});
479503

504+
505+
#[test]
506+
fn test_proxy() {
507+
use super::pool::PooledStream;
508+
mock_connector!(ProxyConnector {
509+
b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"
510+
});
511+
let mut client = Client::with_connector(Pool::with_connector(Default::default(), ProxyConnector));
512+
client.set_proxy("http", "example.proxy", 8008);
513+
let mut dump = vec![];
514+
client.get("http://127.0.0.1/foo/bar").send().unwrap().read_to_end(&mut dump).unwrap();
515+
516+
{
517+
let box_message = client.protocol.new_message("example.proxy", 8008, "http").unwrap();
518+
let message = box_message.downcast::<Http11Message>().unwrap();
519+
let stream = message.into_inner().downcast::<PooledStream<MockStream>>().unwrap().into_inner();
520+
let s = ::std::str::from_utf8(&stream.write).unwrap();
521+
let request_line = "GET http://127.0.0.1/foo/bar HTTP/1.1\r\n";
522+
assert_eq!(&s[..request_line.len()], request_line);
523+
assert!(s.contains("Host: example.proxy:8008\r\n"));
524+
}
525+
526+
}
527+
480528
#[test]
481529
fn test_redirect_followall() {
482530
let mut client = Client::with_connector(MockRedirectPolicy);

src/client/pool.rs

+10-4
Original file line numberDiff line numberDiff line change
@@ -102,22 +102,21 @@ impl<C: NetworkConnector<Stream=S>, S: NetworkStream + Send> NetworkConnector fo
102102
type Stream = PooledStream<S>;
103103
fn connect(&self, host: &str, port: u16, scheme: &str) -> ::Result<PooledStream<S>> {
104104
let key = key(host, port, scheme);
105-
let mut locked = self.inner.lock().unwrap();
106105
let mut should_remove = false;
107-
let inner = match locked.conns.get_mut(&key) {
106+
let inner = match self.inner.lock().unwrap().conns.get_mut(&key) {
108107
Some(ref mut vec) => {
109108
trace!("Pool had connection, using");
110109
should_remove = vec.len() == 1;
111110
vec.pop().unwrap()
112111
}
113-
_ => PooledStreamInner {
112+
None => PooledStreamInner {
114113
key: key.clone(),
115114
stream: try!(self.connector.connect(host, port, scheme)),
116115
previous_response_expected_no_content: false,
117116
}
118117
};
119118
if should_remove {
120-
locked.conns.remove(&key);
119+
self.inner.lock().unwrap().conns.remove(&key);
121120
}
122121
Ok(PooledStream {
123122
inner: Some(inner),
@@ -134,6 +133,13 @@ pub struct PooledStream<S> {
134133
pool: Arc<Mutex<PoolImpl<S>>>,
135134
}
136135

136+
impl<S: NetworkStream> PooledStream<S> {
137+
/// Take the wrapped stream out of the pool completely.
138+
pub fn into_inner(mut self) -> S {
139+
self.inner.take().expect("PooledStream lost its inner stream").stream
140+
}
141+
}
142+
137143
#[derive(Debug)]
138144
struct PooledStreamInner<S> {
139145
key: Key,

src/client/request.rs

+36-2
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,20 @@ impl Request<Fresh> {
7070
});
7171
}
7272

73-
Ok(Request {
73+
Ok(Request::with_headers_and_message(method, url, headers, message))
74+
}
75+
76+
#[doc(hidden)]
77+
pub fn with_headers_and_message(method: Method, url: Url, headers: Headers, message: Box<HttpMessage>)
78+
-> Request<Fresh> {
79+
Request {
7480
method: method,
7581
headers: headers,
7682
url: url,
7783
version: version::HttpVersion::Http11,
7884
message: message,
7985
_marker: PhantomData,
80-
})
86+
}
8187
}
8288

8389
/// Create a new client request.
@@ -129,6 +135,8 @@ impl Request<Fresh> {
129135
pub fn headers_mut(&mut self) -> &mut Headers { &mut self.headers }
130136
}
131137

138+
139+
132140
impl Request<Streaming> {
133141
/// Completes writing the request, and returns a response to read from.
134142
///
@@ -246,6 +254,32 @@ mod tests {
246254
assert!(!s.contains("Content-Length:"));
247255
}
248256

257+
#[test]
258+
fn test_host_header() {
259+
let url = Url::parse("http://example.dom").unwrap();
260+
let req = Request::with_connector(
261+
Get, url, &mut MockConnector
262+
).unwrap();
263+
let bytes = run_request(req);
264+
let s = from_utf8(&bytes[..]).unwrap();
265+
assert!(s.contains("Host: example.dom"));
266+
}
267+
268+
#[test]
269+
fn test_proxy() {
270+
let url = Url::parse("http://example.dom").unwrap();
271+
let proxy_url = Url::parse("http://pro.xy").unwrap();
272+
let mut req = Request::with_connector(
273+
Get, proxy_url, &mut MockConnector
274+
).unwrap();
275+
req.url = url;
276+
let bytes = run_request(req);
277+
let s = from_utf8(&bytes[..]).unwrap();
278+
let request_line = "GET http://example.dom/ HTTP/1.1";
279+
assert_eq!(&s[..request_line.len()], request_line);
280+
assert!(s.contains("Host: pro.xy"));
281+
}
282+
249283
#[test]
250284
fn test_post_chunked_with_encoding() {
251285
let url = Url::parse("http://example.dom").unwrap();

src/http/h1.rs

+13-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use url::Position as UrlPosition;
1111

1212
use buffer::BufReader;
1313
use Error;
14-
use header::{Headers, ContentLength, TransferEncoding};
14+
use header::{Headers, Host, ContentLength, TransferEncoding};
1515
use header::Encoding::Chunked;
1616
use method::{Method};
1717
use net::{NetworkConnector, NetworkStream};
@@ -144,7 +144,18 @@ impl HttpMessage for Http11Message {
144144
let mut stream = BufWriter::new(stream);
145145

146146
{
147-
let uri = &head.url[UrlPosition::BeforePath..UrlPosition::AfterQuery];
147+
let uri = match head.headers.get::<Host>() {
148+
Some(host)
149+
if Some(&*host.hostname) == head.url.host_str()
150+
&& host.port == head.url.port_or_known_default() => {
151+
&head.url[UrlPosition::BeforePath..UrlPosition::AfterQuery]
152+
},
153+
_ => {
154+
trace!("url and host header dont match, using absolute uri form");
155+
head.url.as_ref()
156+
}
157+
158+
};
148159

149160
let version = version::HttpVersion::Http11;
150161
debug!("request line: {:?} {:?} {:?}", head.method, uri, version);

0 commit comments

Comments
 (0)