Skip to content

Implement TcpStream::connect_timeout #43062

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 7, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/libstd/net/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,24 @@ impl TcpStream {
super::each_addr(addr, net_imp::TcpStream::connect).map(TcpStream)
}

/// Opens a TCP connection to a remote host with a timeout.
///
/// Unlike `connect`, `connect_timeout` takes a single [`SocketAddr`] since
/// timeout must be applied to individual addresses.
///
/// It is an error to pass a zero `Duration` to this function.
///
/// Unlike other methods on `TcpStream`, this does not correspond to a
/// single system call. It instead calls `connect` in nonblocking mode and
/// then uses an OS-specific mechanism to await the completion of the
/// connection request.
///
/// [`SocketAddr`]: ../../std/net/enum.SocketAddr.html
#[unstable(feature = "tcpstream_connect_timeout", issue = "43709")]
pub fn connect_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
net_imp::TcpStream::connect_timeout(addr, timeout).map(TcpStream)
}

/// Returns the socket address of the remote peer of this TCP connection.
///
/// # Examples
Expand Down Expand Up @@ -1509,4 +1527,19 @@ mod tests {
t!(txdone.send(()));
})
}

#[test]
fn connect_timeout_unroutable() {
// this IP is unroutable, so connections should always time out.
let addr = "10.255.255.1:80".parse().unwrap();
let e = TcpStream::connect_timeout(&addr, Duration::from_millis(250)).unwrap_err();
assert_eq!(e.kind(), io::ErrorKind::TimedOut);
}

#[test]
fn connect_timeout_valid() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
TcpStream::connect_timeout(&addr, Duration::from_secs(2)).unwrap();
}
}
4 changes: 4 additions & 0 deletions src/libstd/sys/redox/net/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ impl TcpStream {
Ok(TcpStream(File::open(&Path::new(path.as_str()), &options)?))
}

pub fn connect_timeout(_addr: &SocketAddr, _timeout: Duration) -> Result<()> {
Err(Error::new(ErrorKind::Other, "TcpStream::connect_timeout not implemented"))
}

pub fn duplicate(&self) -> Result<TcpStream> {
Ok(TcpStream(self.0.dup(&[])?))
}
Expand Down
67 changes: 66 additions & 1 deletion src/libstd/sys/unix/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ use str;
use sys::fd::FileDesc;
use sys_common::{AsInner, FromInner, IntoInner};
use sys_common::net::{getsockopt, setsockopt, sockaddr_to_addr};
use time::Duration;
use time::{Duration, Instant};
use cmp;

pub use sys::{cvt, cvt_r};
pub extern crate libc as netc;
Expand Down Expand Up @@ -122,6 +123,70 @@ impl Socket {
}
}

pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
self.set_nonblocking(true)?;
let r = unsafe {
let (addrp, len) = addr.into_inner();
cvt(libc::connect(self.0.raw(), addrp, len))
};
self.set_nonblocking(false)?;

match r {
Ok(_) => return Ok(()),
// there's no ErrorKind for EINPROGRESS :(
Err(ref e) if e.raw_os_error() == Some(libc::EINPROGRESS) => {}
Err(e) => return Err(e),
}

let mut pollfd = libc::pollfd {
fd: self.0.raw(),
events: libc::POLLOUT,
revents: 0,
};

if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
return Err(io::Error::new(io::ErrorKind::InvalidInput,
"cannot set a 0 duration timeout"));
}

let start = Instant::now();

loop {
let elapsed = start.elapsed();
if elapsed >= timeout {
return Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out"));
}

let timeout = timeout - elapsed;
let mut timeout = timeout.as_secs()
.saturating_mul(1_000)
.saturating_add(timeout.subsec_nanos() as u64 / 1_000_000);
if timeout == 0 {
timeout = 1;
}

let timeout = cmp::min(timeout, c_int::max_value() as u64) as c_int;

match unsafe { libc::poll(&mut pollfd, 1, timeout) } {
-1 => {
let err = io::Error::last_os_error();
if err.kind() != io::ErrorKind::Interrupted {
return Err(err);
}
}
0 => {}
_ => {
if pollfd.revents & libc::POLLOUT == 0 {
if let Some(e) = self.take_error()? {
return Err(e);
}
}
return Ok(());
}
}
}
}

pub fn accept(&self, storage: *mut sockaddr, len: *mut socklen_t)
-> io::Result<Socket> {
// Unfortunately the only known way right now to accept a socket and
Expand Down
27 changes: 27 additions & 0 deletions src/libstd/sys/windows/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ pub const PIPE_TYPE_BYTE: DWORD = 0x00000000;
pub const PIPE_REJECT_REMOTE_CLIENTS: DWORD = 0x00000008;
pub const PIPE_READMODE_BYTE: DWORD = 0x00000000;

pub const FD_SETSIZE: usize = 64;

#[repr(C)]
#[cfg(target_arch = "x86")]
pub struct WSADATA {
Expand Down Expand Up @@ -837,6 +839,26 @@ pub struct CONSOLE_READCONSOLE_CONTROL {
}
pub type PCONSOLE_READCONSOLE_CONTROL = *mut CONSOLE_READCONSOLE_CONTROL;

#[repr(C)]
#[derive(Copy)]
pub struct fd_set {
pub fd_count: c_uint,
pub fd_array: [SOCKET; FD_SETSIZE],
}

impl Clone for fd_set {
fn clone(&self) -> fd_set {
*self
}
}

#[repr(C)]
#[derive(Copy, Clone)]
pub struct timeval {
pub tv_sec: c_long,
pub tv_usec: c_long,
}

extern "system" {
pub fn WSAStartup(wVersionRequested: WORD,
lpWSAData: LPWSADATA) -> c_int;
Expand Down Expand Up @@ -1125,6 +1147,11 @@ extern "system" {
lpOverlapped: LPOVERLAPPED,
lpNumberOfBytesTransferred: LPDWORD,
bWait: BOOL) -> BOOL;
pub fn select(nfds: c_int,
readfds: *mut fd_set,
writefds: *mut fd_set,
exceptfds: *mut fd_set,
timeout: *const timeval) -> c_int;
}

// Functions that aren't available on Windows XP, but we still use them and just
Expand Down
56 changes: 55 additions & 1 deletion src/libstd/sys/windows/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

use cmp;
use io::{self, Read};
use libc::{c_int, c_void, c_ulong};
use libc::{c_int, c_void, c_ulong, c_long};
use mem;
use net::{SocketAddr, Shutdown};
use ptr;
Expand Down Expand Up @@ -115,6 +115,60 @@ impl Socket {
Ok(socket)
}

pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
self.set_nonblocking(true)?;
let r = unsafe {
let (addrp, len) = addr.into_inner();
cvt(c::connect(self.0, addrp, len))
};
self.set_nonblocking(false)?;

match r {
Ok(_) => return Ok(()),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) => return Err(e),
}

if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
return Err(io::Error::new(io::ErrorKind::InvalidInput,
"cannot set a 0 duration timeout"));
}

let mut timeout = c::timeval {
tv_sec: timeout.as_secs() as c_long,
tv_usec: (timeout.subsec_nanos() / 1000) as c_long,
};
if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
timeout.tv_usec = 1;
}

let fds = unsafe {
let mut fds = mem::zeroed::<c::fd_set>();
fds.fd_count = 1;
fds.fd_array[0] = self.0;
fds
};

let mut writefds = fds;
let mut errorfds = fds;

let n = unsafe {
cvt(c::select(1, ptr::null_mut(), &mut writefds, &mut errorfds, &timeout))?
};

match n {
0 => Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out")),
_ => {
if writefds.fd_count != 1 {
if let Some(e) = self.take_error()? {
return Err(e);
}
}
Ok(())
}
}
}

pub fn accept(&self, storage: *mut c::SOCKADDR,
len: *mut c_int) -> io::Result<Socket> {
let socket = unsafe {
Expand Down
8 changes: 8 additions & 0 deletions src/libstd/sys_common/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,14 @@ impl TcpStream {
Ok(TcpStream { inner: sock })
}

pub fn connect_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
init();

let sock = Socket::new(addr, c::SOCK_STREAM)?;
sock.connect_timeout(addr, timeout)?;
Ok(TcpStream { inner: sock })
}

pub fn socket(&self) -> &Socket { &self.inner }

pub fn into_socket(self) -> Socket { self.inner }
Expand Down