Skip to content

Commit 7976023

Browse files
committed
fix(client): don't error on read before writing request
1 parent 5ce269a commit 7976023

File tree

2 files changed

+77
-18
lines changed

2 files changed

+77
-18
lines changed

src/proto/conn.rs

+16-2
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,21 @@ where I: AsyncRead + AsyncWrite,
238238
ret
239239
}
240240

241-
pub fn maybe_park_read(&mut self) {
241+
pub fn read_keep_alive(&mut self) -> Result<(), ::Error> {
242+
debug_assert!(!self.can_read_head() && !self.can_read_body());
243+
244+
trace!("Conn::read_keep_alive");
245+
246+
if T::should_read_first() || !self.state.is_idle() {
247+
self.maybe_park_read();
248+
} else {
249+
self.try_empty_read()?;
250+
}
251+
252+
Ok(())
253+
}
254+
255+
fn maybe_park_read(&mut self) {
242256
if !self.io.is_read_blocked() {
243257
// the Io object is ready to read, which means it will never alert
244258
// us that it is ready until we drain it. However, we're currently
@@ -258,7 +272,7 @@ where I: AsyncRead + AsyncWrite,
258272
//
259273
// This should only be called for Clients wanting to enter the idle
260274
// state.
261-
pub fn try_empty_read(&mut self) -> io::Result<()> {
275+
fn try_empty_read(&mut self) -> io::Result<()> {
262276
assert!(!self.can_read_head() && !self.can_read_body());
263277

264278
if !self.io.read_buf().is_empty() {

src/proto/dispatch.rs

+61-16
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,20 @@ where
6666
self.conn.disable_keep_alive()
6767
}
6868

69+
fn poll2(&mut self) -> Poll<(), ::Error> {
70+
self.poll_read()?;
71+
self.poll_write()?;
72+
self.poll_flush()?;
73+
74+
if self.is_done() {
75+
try_ready!(self.conn.shutdown());
76+
trace!("Dispatch::poll done");
77+
Ok(Async::Ready(()))
78+
} else {
79+
Ok(Async::NotReady)
80+
}
81+
}
82+
6983
fn poll_read(&mut self) -> Poll<(), ::Error> {
7084
loop {
7185
if self.is_closing {
@@ -163,12 +177,8 @@ where
163177
} else {
164178
// just drop, the body will close automatically
165179
}
166-
} else if !T::should_read_first() {
167-
self.conn.try_empty_read()?;
168-
return Ok(Async::NotReady);
169180
} else {
170-
self.conn.maybe_park_read();
171-
return Ok(Async::Ready(()));
181+
return self.conn.read_keep_alive().map(Async::Ready);
172182
}
173183
}
174184
}
@@ -266,17 +276,13 @@ where
266276
#[inline]
267277
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
268278
trace!("Dispatcher::poll");
269-
self.poll_read()?;
270-
self.poll_write()?;
271-
self.poll_flush()?;
272-
273-
if self.is_done() {
274-
try_ready!(self.conn.shutdown());
275-
trace!("Dispatch::poll done");
276-
Ok(Async::Ready(()))
277-
} else {
278-
Ok(Async::NotReady)
279-
}
279+
self.poll2().or_else(|e| {
280+
// An error means we're shutting down either way.
281+
// We just try to give the error to the user,
282+
// and close the connection with an Ok. If we
283+
// cannot give it to the user, then return the Err.
284+
self.dispatch.recv_msg(Err(e)).map(Async::Ready)
285+
})
280286
}
281287
}
282288

@@ -399,6 +405,9 @@ where
399405
if let Some(cb) = self.callback.take() {
400406
let _ = cb.send(Err(err));
401407
Ok(())
408+
} else if let Ok(Async::Ready(Some(ClientMsg::Request(_, _, cb)))) = self.rx.poll() {
409+
let _ = cb.send(Err(err));
410+
Ok(())
402411
} else {
403412
Err(err)
404413
}
@@ -424,3 +433,39 @@ where
424433
self.callback.is_none()
425434
}
426435
}
436+
437+
#[cfg(test)]
438+
mod tests {
439+
use futures::Sink;
440+
441+
use super::*;
442+
use mock::AsyncIo;
443+
use proto::ClientTransaction;
444+
445+
#[test]
446+
fn client_read_response_before_writing_request() {
447+
extern crate pretty_env_logger;
448+
let _ = pretty_env_logger::try_init();
449+
::futures::lazy(|| {
450+
let io = AsyncIo::new_buf(b"HTTP/1.1 200 OK\r\n\r\n".to_vec(), 100);
451+
let (mut tx, rx) = mpsc::channel(0);
452+
let conn = Conn::<_, ::Chunk, ClientTransaction>::new(io, Default::default());
453+
let mut dispatcher = Dispatcher::new(Client::new(rx), conn);
454+
455+
let req = RequestHead {
456+
version: ::HttpVersion::Http11,
457+
subject: ::proto::RequestLine::default(),
458+
headers: Default::default(),
459+
};
460+
let (res_tx, res_rx) = oneshot::channel();
461+
tx.start_send(ClientMsg::Request(req, None::<::Body>, res_tx)).unwrap();
462+
463+
dispatcher.poll().expect("dispatcher poll 1");
464+
dispatcher.poll().expect("dispatcher poll 2");
465+
let _res = res_rx.wait()
466+
.expect("callback poll")
467+
.expect("callback response");
468+
Ok::<(), ()>(())
469+
}).wait().unwrap();
470+
}
471+
}

0 commit comments

Comments
 (0)