Skip to content

Commit

Permalink
Merge pull request #569 from jmlaka/add_conn_timeout
Browse files Browse the repository at this point in the history
- add connection timeout
  • Loading branch information
Enet4 authored Oct 12, 2024
2 parents 6eee09a + 13d847e commit e2162f2
Showing 1 changed file with 55 additions and 9 deletions.
64 changes: 55 additions & 9 deletions ul/src/association/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use std::{
borrow::Cow,
convert::TryInto,
io::Write,
net::{TcpStream, ToSocketAddrs}, time::Duration,
net::{TcpStream, ToSocketAddrs},
time::Duration,
};

use crate::{
Expand All @@ -34,20 +35,31 @@ pub enum Error {
/// missing abstract syntax to begin negotiation
MissingAbstractSyntax { backtrace: Backtrace },

/// could not convert to socket address
ToAddress {
source: std::io::Error,
backtrace: Backtrace,
},

/// converted SocketAddrs iterator did not yield
#[snafu(display("not a single tcp addreess provided"))]
#[non_exhaustive]
NoAddress { backtrace: Backtrace },

/// could not connect to server
Connect {
source: std::io::Error,
backtrace: Backtrace,
},

/// Could not set tcp read timeout
SetReadTimeout{
SetReadTimeout {
source: std::io::Error,
backtrace: Backtrace,
},

/// Could not set tcp write timeout
SetWriteTimeout{
SetWriteTimeout {
source: std::io::Error,
backtrace: Backtrace,
},
Expand Down Expand Up @@ -193,6 +205,8 @@ pub struct ClientAssociationOptions<'a> {
read_timeout: Option<Duration>,
/// TCP write timeout
write_timeout: Option<Duration>,
/// TCP connection timeout
connection_timeout: Option<Duration>,
}

impl<'a> Default for ClientAssociationOptions<'a> {
Expand All @@ -216,6 +230,7 @@ impl<'a> Default for ClientAssociationOptions<'a> {
jwt: None,
read_timeout: None,
write_timeout: None,
connection_timeout: None,
}
}
}
Expand Down Expand Up @@ -465,6 +480,14 @@ impl<'a> ClientAssociationOptions<'a> {
}
}

/// Set the connection timeout for the underlying TCP socket
pub fn connection_timeout(self, timeout: Duration) -> Self {
Self {
connection_timeout: Some(timeout),
..self
}
}

fn establish_impl<T>(self, ae_address: AeAddr<T>) -> Result<ClientAssociation>
where
T: ToSocketAddrs,
Expand All @@ -483,7 +506,8 @@ impl<'a> ClientAssociationOptions<'a> {
saml_assertion,
jwt,
read_timeout,
write_timeout
write_timeout,
connection_timeout,
} = self;

// fail if no presentation contexts were provided: they represent intent,
Expand Down Expand Up @@ -546,11 +570,33 @@ impl<'a> ClientAssociationOptions<'a> {
user_variables,
});

let mut socket = std::net::TcpStream::connect(ae_address)
.context(ConnectSnafu)?;
socket.set_read_timeout(read_timeout)
let conn_result: Result<TcpStream> = if let Some(timeout) = connection_timeout {
let mut addresses = ae_address.to_socket_addrs().context(ToAddressSnafu)?;

if addresses.by_ref().count() == 0 {
return NoAddressSnafu.fail();
}

let mut result: Result<TcpStream, std::io::Error> =
Result::Err(std::io::Error::from(std::io::ErrorKind::NotConnected));

for address in addresses {
result = std::net::TcpStream::connect_timeout(&address, timeout);
if result.is_ok() {
break;
}
}
result.context(ConnectSnafu)
} else {
std::net::TcpStream::connect(ae_address).context(ConnectSnafu)
};

let mut socket = conn_result?;
socket
.set_read_timeout(read_timeout)
.context(SetReadTimeoutSnafu)?;
socket.set_write_timeout(write_timeout)
socket
.set_write_timeout(write_timeout)
.context(SetWriteTimeoutSnafu)?;
let mut buffer: Vec<u8> = Vec::with_capacity(max_pdu_length as usize);
// send request
Expand Down

0 comments on commit e2162f2

Please sign in to comment.