Skip to content

Commit f11c451

Browse files
authored
Merge pull request #1 from sfackler/connect-timeout
Add connect_timeout
2 parents b465cb3 + b749cdc commit f11c451

File tree

3 files changed

+163
-1
lines changed

3 files changed

+163
-1
lines changed

src/socket.rs

+53
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,29 @@ impl Socket {
6262
self.inner.connect(addr)
6363
}
6464

65+
/// Initiate a connection on this socket to the specified address, only
66+
/// only waiting for a certain period of time for the connection to be
67+
/// established.
68+
///
69+
/// Unlike many other methods on `Socket`, this does *not* correspond to a
70+
/// single C function. It sets the socket to nonblocking mode, connects via
71+
/// connect(2), and then waits for the connection to complete with poll(2)
72+
/// on Unix and select on Windows. When the connection is complete, the
73+
/// socket is set back to blocking mode. On Unix, this will loop over
74+
/// `EINTR` errors.
75+
///
76+
/// # Warnings
77+
///
78+
/// The nonblocking state of the socket is overridden by this function -
79+
/// it will be returned in blocking mode on success, and in an indeterminate
80+
/// state on failure.
81+
///
82+
/// If the connection request times out, it may still be processing in the
83+
/// background - a second call to `connect` or `connect_timeout` may fail.
84+
pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
85+
self.inner.connect_timeout(addr, timeout)
86+
}
87+
6588
/// Binds this socket to the specified address.
6689
///
6790
/// This function directly corresponds to the bind(2) function on Windows
@@ -667,3 +690,33 @@ impl From<Protocol> for i32 {
667690
a.into()
668691
}
669692
}
693+
694+
#[cfg(test)]
695+
mod test {
696+
use super::*;
697+
698+
#[test]
699+
fn connect_timeout_unrouteable() {
700+
// this IP is unroutable, so connections should always time out
701+
let addr: SocketAddr = "10.255.255.1:80".parse().unwrap();
702+
703+
let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
704+
match socket.connect_timeout(&addr, Duration::from_millis(250)) {
705+
Ok(_) => panic!("unexpected success"),
706+
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {}
707+
Err(e) => panic!("unexpected error {}", e),
708+
}
709+
}
710+
711+
#[test]
712+
fn connect_timeout_valid() {
713+
let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
714+
socket.bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
715+
socket.listen(128).unwrap();
716+
717+
let addr = socket.local_addr().unwrap();
718+
719+
let socket = Socket::new(Domain::ipv4(), Type::stream(), None).unwrap();
720+
socket.connect_timeout(&addr, Duration::from_millis(250)).unwrap();
721+
}
722+
}

src/sys/unix/mod.rs

+62-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, SocketAddr}
1818
use std::ops::Neg;
1919
use std::os::unix::prelude::*;
2020
use std::sync::atomic::{AtomicBool, Ordering, ATOMIC_BOOL_INIT};
21-
use std::time::Duration;
21+
use std::time::{Duration, Instant};
2222

2323
use libc::{self, c_void, c_int, sockaddr_in, sockaddr_storage, sockaddr_in6};
2424
use libc::{sockaddr, socklen_t, AF_INET, AF_INET6, ssize_t};
@@ -118,6 +118,67 @@ impl Socket {
118118
}
119119
}
120120

121+
pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
122+
self.set_nonblocking(true)?;
123+
let r = self.connect(addr);
124+
self.set_nonblocking(false)?;
125+
126+
match r {
127+
Ok(()) => return Ok(()),
128+
// there's no io::ErrorKind conversion registered for EINPROGRESS :(
129+
Err(ref e) if e.raw_os_error() == Some(libc::EINPROGRESS) => {}
130+
Err(e) => return Err(e),
131+
}
132+
133+
let mut pollfd = libc::pollfd {
134+
fd: self.fd,
135+
events: libc::POLLOUT,
136+
revents: 0,
137+
};
138+
139+
if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
140+
return Err(io::Error::new(io::ErrorKind::InvalidInput,
141+
"cannot set a 0 duration timeout"));
142+
}
143+
144+
let start = Instant::now();
145+
146+
loop {
147+
let elapsed = start.elapsed();
148+
if elapsed >= timeout {
149+
return Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out"));
150+
}
151+
152+
let timeout = timeout - elapsed;
153+
let mut timeout = timeout.as_secs()
154+
.saturating_mul(1_000)
155+
.saturating_add(timeout.subsec_nanos() as u64 / 1_000_000);
156+
if timeout == 0 {
157+
timeout = 1;
158+
}
159+
160+
let timeout = cmp::min(timeout, c_int::max_value() as u64) as c_int;
161+
162+
match unsafe { libc::poll(&mut pollfd, 1, timeout) } {
163+
-1 => {
164+
let err = io::Error::last_os_error();
165+
if err.kind() != io::ErrorKind::Interrupted {
166+
return Err(err);
167+
}
168+
}
169+
0 => return Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out")),
170+
_ => {
171+
if pollfd.revents & libc::POLLOUT == 0 {
172+
if let Some(e) = self.take_error()? {
173+
return Err(e);
174+
}
175+
}
176+
return Ok(());
177+
}
178+
}
179+
}
180+
}
181+
121182
pub fn local_addr(&self) -> io::Result<SocketAddr> {
122183
unsafe {
123184
let mut storage: libc::sockaddr_storage = mem::zeroed();

src/sys/windows.rs

+48
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,54 @@ impl Socket {
109109
}
110110
}
111111

112+
pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
113+
self.set_nonblocking(true)?;
114+
let r = self.connect(addr);
115+
self.set_nonblocking(true)?;
116+
117+
match r {
118+
Ok(()) => return Ok(()),
119+
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
120+
Err(e) => return Err(e),
121+
}
122+
123+
if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
124+
return Err(io::Error::new(io::ErrorKind::InvalidInput,
125+
"cannot set a 0 duration timeout"));
126+
}
127+
128+
let mut timeout = timeval {
129+
tv_sec: timeout.as_secs() as c_long,
130+
tv_usec: (timeout.subsec_nanos() / 1000) as c_long,
131+
};
132+
if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
133+
timeout.tv_usec = 1;
134+
}
135+
136+
let fds = unsafe {
137+
let mut fds = mem::zeroed::<fd_set>();
138+
fds.fd_count = 1;
139+
fds.fd_array[0] = self.socket;
140+
fds
141+
};
142+
143+
let mut writefds = fds;
144+
let mut errorfds = fds;
145+
146+
match unsafe { ws2_32::select(1, ptr::null_mut(), &mut writefds, &mut errorfds, &timeout) } {
147+
SOCKET_ERROR => return Err(io::Error::last_os_error()),
148+
0 => return Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out")),
149+
_ => {
150+
if writefds.fd_count != 1 {
151+
if let Some(e) = self.take_error()? {
152+
return Err(e);
153+
}
154+
}
155+
Ok(())
156+
}
157+
}
158+
}
159+
112160
pub fn local_addr(&self) -> io::Result<SocketAddr> {
113161
unsafe {
114162
let mut storage: SOCKADDR_STORAGE = mem::zeroed();

0 commit comments

Comments
 (0)