Skip to content

Commit 7ff8d61

Browse files
committed
feat(client): add proxy::Tunnel legacy util
1 parent 4c4e062 commit 7ff8d61

File tree

6 files changed

+320
-0
lines changed

6 files changed

+320
-0
lines changed

src/client/legacy/connect/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ pub mod dns;
8080
#[cfg(feature = "tokio")]
8181
mod http;
8282

83+
pub mod proxy;
84+
8385
pub(crate) mod capture;
8486
pub use capture::{capture_connection, CaptureConnection};
8587

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
//! Proxy helpers
2+
3+
mod tunnel;
4+
5+
pub use self::tunnel::Tunnel;
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
use std::error::Error as StdError;
2+
use std::future::Future;
3+
use std::marker::{PhantomData, Unpin};
4+
use std::pin::Pin;
5+
use std::task::{self, Poll};
6+
7+
use http::{HeaderMap, HeaderValue, Uri};
8+
use hyper::rt::{Read, Write};
9+
use pin_project_lite::pin_project;
10+
use tower_service::Service;
11+
12+
/// Tunnel Proxy via HTTP CONNECT
13+
#[derive(Debug)]
14+
pub struct Tunnel<C> {
15+
headers: Headers,
16+
inner: C,
17+
proxy_dst: Uri,
18+
}
19+
20+
#[derive(Clone, Debug)]
21+
enum Headers {
22+
Empty,
23+
Auth(HeaderValue),
24+
Extra(HeaderMap),
25+
}
26+
27+
#[derive(Debug)]
28+
pub enum TunnelError {
29+
Inner(Box<dyn StdError + Send + Sync>),
30+
Io(std::io::Error),
31+
MissingHost,
32+
ProxyAuthRequired,
33+
ProxyHeadersTooLong,
34+
TunnelUnexpectedEof,
35+
TunnelUnsuccessful,
36+
}
37+
38+
pin_project! {
39+
// Not publicly exported (so missing_docs doesn't trigger).
40+
//
41+
// We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
42+
// so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
43+
// (and thus we can change the type in the future).
44+
#[must_use = "futures do nothing unless polled"]
45+
#[allow(missing_debug_implementations)]
46+
pub struct Tunneling<F, T> {
47+
#[pin]
48+
fut: BoxTunneling<T>,
49+
_marker: PhantomData<F>,
50+
}
51+
}
52+
53+
type BoxTunneling<T> = Pin<Box<dyn Future<Output = Result<T, TunnelError>> + Send>>;
54+
55+
impl<C> Tunnel<C> {
56+
/// Create a new Tunnel service.
57+
pub fn new(proxy_dst: Uri, connector: C) -> Self {
58+
Self {
59+
headers: Headers::Empty,
60+
inner: connector,
61+
proxy_dst,
62+
}
63+
}
64+
65+
/// Add `proxy-authorization` header value to the CONNECT request.
66+
pub fn with_auth(mut self, mut auth: HeaderValue) -> Self {
67+
// just in case the user forgot
68+
auth.set_sensitive(true);
69+
match self.headers {
70+
Headers::Empty => {
71+
self.headers = Headers::Auth(auth);
72+
},
73+
Headers::Auth(ref mut existing) => {
74+
*existing = auth;
75+
},
76+
Headers::Extra(ref mut extra) => {
77+
extra.insert(http::header::PROXY_AUTHORIZATION, auth);
78+
}
79+
}
80+
81+
self
82+
}
83+
84+
/// Add extra headers to be sent with the CONNECT request.
85+
///
86+
/// If existing headers have been set, these will be merged.
87+
pub fn with_headers(mut self, mut headers: HeaderMap) -> Self {
88+
match self.headers {
89+
Headers::Empty => {
90+
self.headers = Headers::Extra(headers);
91+
},
92+
Headers::Auth(auth) => {
93+
headers.entry(http::header::PROXY_AUTHORIZATION).or_insert(auth);
94+
self.headers = Headers::Extra(headers);
95+
},
96+
Headers::Extra(ref mut extra) => {
97+
extra.extend(headers);
98+
}
99+
}
100+
101+
self
102+
}
103+
}
104+
105+
impl<C> Service<Uri> for Tunnel<C>
106+
where
107+
C: Service<Uri>,
108+
C::Future: Send + 'static,
109+
C::Response: Read + Write + Unpin + Send + 'static,
110+
C::Error: Into<Box<dyn StdError + Send + Sync>>,
111+
{
112+
type Response = C::Response;
113+
type Error = TunnelError;
114+
type Future = Tunneling<C::Future, C::Response>;
115+
116+
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
117+
futures_util::ready!(self.inner.poll_ready(cx)).map_err(|e| TunnelError::Inner(e.into()))?;
118+
Poll::Ready(Ok(()))
119+
}
120+
121+
fn call(&mut self, dst: Uri) -> Self::Future {
122+
let connecting = self.inner.call(self.proxy_dst.clone());
123+
let headers = self.headers.clone();
124+
125+
Tunneling {
126+
fut: Box::pin(async move {
127+
let conn = connecting.await.map_err(|e| TunnelError::Inner(e.into()))?;
128+
tunnel(
129+
conn,
130+
dst.host().ok_or(TunnelError::MissingHost)?,
131+
dst.port().map(|p| p.as_u16()).unwrap_or(443),
132+
&headers,
133+
)
134+
.await
135+
}),
136+
_marker: PhantomData,
137+
}
138+
}
139+
}
140+
141+
impl<F, T, E> Future for Tunneling<F, T>
142+
where
143+
F: Future<Output = Result<T, E>>,
144+
{
145+
type Output = Result<T, TunnelError>;
146+
147+
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
148+
self.project().fut.poll(cx)
149+
}
150+
}
151+
152+
async fn tunnel<T>(
153+
mut conn: T,
154+
host: &str,
155+
port: u16,
156+
headers: &Headers,
157+
) -> Result<T, TunnelError>
158+
where
159+
T: Read + Write + Unpin,
160+
{
161+
let mut buf = format!(
162+
"\
163+
CONNECT {host}:{port} HTTP/1.1\r\n\
164+
Host: {host}:{port}\r\n\
165+
"
166+
)
167+
.into_bytes();
168+
169+
match headers {
170+
Headers::Auth(auth) => {
171+
buf.extend_from_slice(b"Proxy-Authorization: ");
172+
buf.extend_from_slice(auth.as_bytes());
173+
buf.extend_from_slice(b"\r\n");
174+
},
175+
Headers::Extra(extra) => {
176+
for (name, value) in extra {
177+
buf.extend_from_slice(name.as_str().as_bytes());
178+
buf.extend_from_slice(b": ");
179+
buf.extend_from_slice(value.as_bytes());
180+
buf.extend_from_slice(b"\r\n");
181+
}
182+
183+
},
184+
Headers::Empty => (),
185+
}
186+
187+
// headers end
188+
buf.extend_from_slice(b"\r\n");
189+
190+
crate::rt::write_all(&mut conn, &buf)
191+
.await
192+
.map_err(TunnelError::Io)?;
193+
194+
let mut buf = [0; 8192];
195+
let mut pos = 0;
196+
197+
loop {
198+
let n = crate::rt::read(&mut conn, &mut buf[pos..])
199+
.await
200+
.map_err(TunnelError::Io)?;
201+
202+
if n == 0 {
203+
return Err(TunnelError::TunnelUnexpectedEof);
204+
}
205+
pos += n;
206+
207+
let recvd = &buf[..pos];
208+
if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
209+
if recvd.ends_with(b"\r\n\r\n") {
210+
return Ok(conn);
211+
}
212+
if pos == buf.len() {
213+
return Err(TunnelError::ProxyHeadersTooLong);
214+
}
215+
// else read more
216+
} else if recvd.starts_with(b"HTTP/1.1 407") {
217+
return Err(TunnelError::ProxyAuthRequired);
218+
} else {
219+
return Err(TunnelError::TunnelUnsuccessful);
220+
}
221+
}
222+
}
223+
224+
impl std::fmt::Display for TunnelError {
225+
fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226+
todo!("Display for TunnelError");
227+
}
228+
}
229+
230+
impl std::error::Error for TunnelError {
231+
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
232+
match self {
233+
TunnelError::Io(ref e) => Some(e),
234+
TunnelError::Inner(ref e) => Some(&**e),
235+
_ => None,
236+
}
237+
}
238+
}

src/rt/io.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
use std::marker::Unpin;
2+
use std::pin::Pin;
3+
use std::task::Poll;
4+
5+
use futures_util::future;
6+
use futures_util::ready;
7+
use hyper::rt::{Read, ReadBuf, Write};
8+
9+
pub(crate) async fn read<T>(io: &mut T, buf: &mut [u8]) -> Result<usize, std::io::Error>
10+
where
11+
T: Read + Unpin,
12+
{
13+
future::poll_fn(move |cx| {
14+
let mut buf = ReadBuf::new(buf);
15+
ready!(Pin::new(&mut *io).poll_read(cx, buf.unfilled()))?;
16+
Poll::Ready(Ok(buf.filled().len()))
17+
})
18+
.await
19+
}
20+
21+
pub(crate) async fn write_all<T>(io: &mut T, buf: &[u8]) -> Result<(), std::io::Error>
22+
where
23+
T: Write + Unpin,
24+
{
25+
let mut n = 0;
26+
future::poll_fn(move |cx| {
27+
while n < buf.len() {
28+
n += ready!(Pin::new(&mut *io).poll_write(cx, &buf[n..])?);
29+
}
30+
Poll::Ready(Ok(()))
31+
})
32+
.await
33+
}

src/rt/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
//! Runtime utilities
22
3+
#[cfg(feature = "client-legacy")]
4+
mod io;
5+
#[cfg(feature = "client-legacy")]
6+
pub(crate) use self::io::{read, write_all};
7+
38
#[cfg(feature = "tokio")]
49
pub mod tokio;
510

tests/proxy.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use tokio::io::{AsyncReadExt, AsyncWriteExt};
2+
use tokio::net::TcpListener;
3+
use tower_service::Service;
4+
5+
use hyper_util::client::legacy::connect::{proxy::Tunnel, HttpConnector};
6+
7+
#[cfg(not(miri))]
8+
#[tokio::test]
9+
async fn test_tunnel_works() {
10+
let tcp = TcpListener::bind("127.0.0.1:0").await.expect("bind");
11+
let addr = tcp.local_addr().expect("local_addr");
12+
13+
let proxy_dst = format!("http://{}", addr).parse().expect("uri");
14+
let mut connector = Tunnel::new(proxy_dst, HttpConnector::new());
15+
let t1 = tokio::spawn(async move {
16+
let _conn = connector
17+
.call("https://hyper.rs".parse().unwrap())
18+
.await
19+
.expect("tunnel");
20+
});
21+
22+
let t2 = tokio::spawn(async move {
23+
let (mut io, _) = tcp.accept().await.expect("accept");
24+
let mut buf = [0u8; 64];
25+
let n = io.read(&mut buf).await.expect("read 1");
26+
assert_eq!(
27+
&buf[..n],
28+
b"CONNECT hyper.rs:443 HTTP/1.1\r\nHost: hyper.rs:443\r\n\r\n"
29+
);
30+
io.write_all(b"HTTP/1.1 200 OK\r\n\r\n")
31+
.await
32+
.expect("write 1");
33+
});
34+
35+
t1.await.expect("task 1");
36+
t2.await.expect("task 2");
37+
}

0 commit comments

Comments
 (0)