Skip to content

Commit 7cb72d2

Browse files
committed
fix(server): send 400 responses on parse errors before closing connection
1 parent 44c34ce commit 7cb72d2

File tree

4 files changed

+109
-2
lines changed

4 files changed

+109
-2
lines changed

src/proto/conn.rs

+25-2
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ where I: AsyncRead + AsyncWrite,
186186
let was_mid_parse = !self.io.read_buf().is_empty();
187187
return if was_mid_parse || must_error {
188188
debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len());
189-
Err(e)
189+
self.on_parse_error(e)
190+
.map(|()| Async::NotReady)
190191
} else {
191192
debug!("read eof");
192193
Ok(Async::Ready(None))
@@ -213,7 +214,8 @@ where I: AsyncRead + AsyncWrite,
213214
Err(e) => {
214215
debug!("decoder error = {:?}", e);
215216
self.state.close_read();
216-
return Err(e);
217+
return self.on_parse_error(e)
218+
.map(|()| Async::NotReady);
217219
}
218220
};
219221

@@ -548,6 +550,27 @@ where I: AsyncRead + AsyncWrite,
548550
Ok(AsyncSink::Ready)
549551
}
550552

553+
// When we get a parse error, depending on what side we are, we might be able
554+
// to write a response before closing the connection.
555+
//
556+
// - Client: there is nothing we can do
557+
// - Server: if Response hasn't been written yet, we can send a 4xx response
558+
fn on_parse_error(&mut self, err: ::Error) -> ::Result<()> {
559+
match self.state.writing {
560+
Writing::Init => {
561+
if let Some(msg) = T::on_error(&err) {
562+
self.write_head(msg, false);
563+
self.state.error = Some(err);
564+
return Ok(());
565+
}
566+
}
567+
_ => (),
568+
}
569+
570+
// fallback is pass the error back up
571+
Err(err)
572+
}
573+
551574
fn write_queued(&mut self) -> Poll<(), io::Error> {
552575
trace!("Conn::write_queued()");
553576
let state = match self.state.writing {

src/proto/h1/parse.rs

+25
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,26 @@ impl Http1Transaction for ServerTransaction {
150150
ret
151151
}
152152

153+
fn on_error(err: &::Error) -> Option<MessageHead<Self::Outgoing>> {
154+
let status = match err {
155+
&::Error::Method |
156+
&::Error::Version |
157+
&::Error::Header |
158+
&::Error::Uri(_) => {
159+
StatusCode::BadRequest
160+
},
161+
&::Error::TooLarge => {
162+
StatusCode::RequestHeaderFieldsTooLarge
163+
}
164+
_ => return None,
165+
};
166+
167+
debug!("sending automatic response ({}) for parse error", status);
168+
let mut msg = MessageHead::default();
169+
msg.subject = status;
170+
Some(msg)
171+
}
172+
153173
fn should_error_on_parse_eof() -> bool {
154174
false
155175
}
@@ -317,6 +337,11 @@ impl Http1Transaction for ClientTransaction {
317337
Ok(body)
318338
}
319339

340+
fn on_error(_err: &::Error) -> Option<MessageHead<Self::Outgoing>> {
341+
// we can't tell the server about any errors it creates
342+
None
343+
}
344+
320345
fn should_error_on_parse_eof() -> bool {
321346
true
322347
}

src/proto/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ pub trait Http1Transaction {
149149
fn parse(bytes: &mut BytesMut) -> ParseResult<Self::Incoming>;
150150
fn decoder(head: &MessageHead<Self::Incoming>, method: &mut Option<::Method>) -> ::Result<Option<h1::Decoder>>;
151151
fn encode(head: MessageHead<Self::Outgoing>, has_body: bool, method: &mut Option<Method>, dst: &mut Vec<u8>) -> ::Result<h1::Encoder>;
152+
fn on_error(err: &::Error) -> Option<MessageHead<Self::Outgoing>>;
152153

153154
fn should_error_on_parse_eof() -> bool;
154155
fn should_read_first() -> bool;

tests/server.rs

+58
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,64 @@ fn returning_1xx_response_is_error() {
900900
core.run(fut).unwrap_err();
901901
}
902902

903+
#[test]
904+
fn parse_errors_send_4xx_response() {
905+
let mut core = Core::new().unwrap();
906+
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
907+
let addr = listener.local_addr().unwrap();
908+
909+
thread::spawn(move || {
910+
let mut tcp = connect(&addr);
911+
tcp.write_all(b"GE T / HTTP/1.1\r\n\r\n").unwrap();
912+
let mut buf = [0; 256];
913+
tcp.read(&mut buf).unwrap();
914+
915+
let expected = "HTTP/1.1 400 ";
916+
assert_eq!(s(&buf[..expected.len()]), expected);
917+
});
918+
919+
let fut = listener.incoming()
920+
.into_future()
921+
.map_err(|_| unreachable!())
922+
.and_then(|(item, _incoming)| {
923+
let (socket, _) = item.unwrap();
924+
Http::<hyper::Chunk>::new()
925+
.serve_connection(socket, HelloWorld)
926+
.map(|_| ())
927+
});
928+
929+
core.run(fut).unwrap_err();
930+
}
931+
932+
#[test]
933+
fn illegal_request_length_returns_400_response() {
934+
let mut core = Core::new().unwrap();
935+
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
936+
let addr = listener.local_addr().unwrap();
937+
938+
thread::spawn(move || {
939+
let mut tcp = connect(&addr);
940+
tcp.write_all(b"POST / HTTP/1.1\r\nContent-Length: foo\r\n\r\n").unwrap();
941+
let mut buf = [0; 256];
942+
tcp.read(&mut buf).unwrap();
943+
944+
let expected = "HTTP/1.1 400 ";
945+
assert_eq!(s(&buf[..expected.len()]), expected);
946+
});
947+
948+
let fut = listener.incoming()
949+
.into_future()
950+
.map_err(|_| unreachable!())
951+
.and_then(|(item, _incoming)| {
952+
let (socket, _) = item.unwrap();
953+
Http::<hyper::Chunk>::new()
954+
.serve_connection(socket, HelloWorld)
955+
.map(|_| ())
956+
});
957+
958+
core.run(fut).unwrap_err();
959+
}
960+
903961
#[test]
904962
fn remote_addr() {
905963
let server = serve();

0 commit comments

Comments
 (0)