Skip to content

Commit

Permalink
simplify client io types. (#1021)
Browse files Browse the repository at this point in the history
  • Loading branch information
fakeshadow authored Apr 14, 2024
1 parent 9d31f9c commit d6610b1
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 132 deletions.
44 changes: 27 additions & 17 deletions client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ impl Client {
let conn = self.make_tcp(connect, timer).await?;

if matches!(connect.uri, Uri::Tcp(_)) {
return Ok((ConnectionExclusive::Tcp(conn), expected_version));
return Ok((conn, expected_version));
}

timer
Expand All @@ -275,22 +275,25 @@ impl Client {

let (conn, version) = self
.connector
.call((connect.hostname(), Box::new(conn)))
.call((connect.hostname(), conn))
.timeout(timer.as_mut())
.await
.map_err(|_| TimeoutError::TlsHandshake)??;

Ok((ConnectionExclusive::Tls(conn), version))
}
#[cfg(unix)]
Uri::Unix(_) => {
let conn = self.make_unix(connect, timer).await?;
Ok((ConnectionExclusive::Unix(conn), expected_version))
Ok((conn, version))
}
Uri::Unix(_) => self
.make_unix(connect, timer)
.await
.map(|conn| (conn, expected_version)),
}
}

async fn make_tcp(&self, connect: &mut Connect<'_>, timer: &mut Pin<Box<Sleep>>) -> Result<TcpStream, Error> {
async fn make_tcp(
&self,
connect: &mut Connect<'_>,
timer: &mut Pin<Box<Sleep>>,
) -> Result<ConnectionExclusive, Error> {
self.resolver
.call(connect)
.timeout(timer.as_mut())
Expand All @@ -310,7 +313,7 @@ impl Client {
// TODO: make nodelay configurable?
let _ = stream.set_nodelay(true);

Ok(stream)
Ok(Box::new(stream))
}

async fn make_tcp_inner(&self, connect: &Connect<'_>) -> Result<TcpStream, Error> {
Expand Down Expand Up @@ -353,12 +356,11 @@ impl Client {
}
}

#[cfg(unix)]
async fn make_unix(
&self,
connect: &Connect<'_>,
timer: &mut Pin<Box<Sleep>>,
) -> Result<xitca_io::net::UnixStream, Error> {
) -> Result<ConnectionExclusive, Error> {
timer
.as_mut()
.reset(Instant::now() + self.timeout_config.connect_timeout);
Expand All @@ -369,12 +371,20 @@ impl Client {
connect.uri.path_and_query().unwrap().as_str()
);

let stream = xitca_io::net::UnixStream::connect(path)
.timeout(timer.as_mut())
.await
.map_err(|_| TimeoutError::Connect)??;
#[cfg(unix)]
{
let stream = xitca_io::net::UnixStream::connect(path)
.timeout(timer.as_mut())
.await
.map_err(|_| TimeoutError::Connect)??;

Ok(Box::new(stream))
}

Ok(stream)
#[cfg(not(unix))]
{
unimplemented!("only unix supports unix domain socket")
}
}
}

Expand Down
112 changes: 2 additions & 110 deletions client/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
use core::{
hash::{Hash, Hasher},
pin::Pin,
task::{Context, Poll},
};

use std::io;
use core::hash::{Hash, Hasher};

use xitca_http::http::uri::{Authority, PathAndQuery};
use xitca_io::{
io::{AsyncIo, Interest, Ready},
net::TcpStream,
};

#[cfg(unix)]
use xitca_io::net::UnixStream;

use super::{tls::TlsStream, uri::Uri};

Expand All @@ -26,102 +13,7 @@ pub type H1ConnectionWithKey<'a> = crate::pool::exclusive::Conn<'a, ConnectionKe
pub type H1ConnectionWithoutKey = crate::pool::exclusive::PooledConn<ConnectionExclusive>;

/// exclusive connection for http1 and in certain case they can be upgraded to [ConnectionShared]
#[allow(clippy::large_enum_variant)]
#[non_exhaustive]
pub enum ConnectionExclusive {
Tcp(TcpStream),
Tls(TlsStream),
#[cfg(unix)]
Unix(UnixStream),
}

impl AsyncIo for ConnectionExclusive {
async fn ready(&mut self, interest: Interest) -> io::Result<Ready> {
match self {
Self::Tcp(ref mut io) => io.ready(interest).await,
Self::Tls(ref mut io) => io.ready(interest).await,
#[cfg(unix)]
Self::Unix(ref mut io) => io.ready(interest).await,
}
}

fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
match self {
Self::Tcp(ref mut io) => io.poll_ready(interest, cx),
Self::Tls(ref mut io) => io.poll_ready(interest, cx),
#[cfg(unix)]
Self::Unix(ref mut io) => io.poll_ready(interest, cx),
}
}

fn is_vectored_write(&self) -> bool {
match self {
Self::Tcp(ref io) => io.is_vectored_write(),
Self::Tls(ref io) => io.is_vectored_write(),
#[cfg(unix)]
Self::Unix(ref io) => io.is_vectored_write(),
}
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Tcp(io) => Pin::new(io).poll_shutdown(cx),
Self::Tls(io) => Pin::new(io).poll_shutdown(cx),
#[cfg(unix)]
Self::Unix(io) => Pin::new(io).poll_shutdown(cx),
}
}
}

impl io::Read for ConnectionExclusive {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Tcp(ref mut io) => io.read(buf),
Self::Tls(ref mut io) => io.read(buf),
#[cfg(unix)]
Self::Unix(ref mut io) => io.read(buf),
}
}
}

impl io::Write for ConnectionExclusive {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Tcp(ref mut io) => io.write(buf),
Self::Tls(ref mut io) => io.write(buf),
#[cfg(unix)]
Self::Unix(ref mut io) => io.write(buf),
}
}

fn flush(&mut self) -> io::Result<()> {
match self {
Self::Tcp(ref mut io) => io.flush(),
Self::Tls(ref mut io) => io.flush(),
#[cfg(unix)]
Self::Unix(ref mut io) => io.flush(),
}
}
}

impl From<TcpStream> for ConnectionExclusive {
fn from(tcp: TcpStream) -> Self {
Self::Tcp(tcp)
}
}

impl From<TlsStream> for ConnectionExclusive {
fn from(io: TlsStream) -> Self {
Self::Tls(io)
}
}

#[cfg(unix)]
impl From<UnixStream> for ConnectionExclusive {
fn from(unix: UnixStream) -> Self {
Self::Unix(unix)
}
}
pub type ConnectionExclusive = TlsStream;

/// high level shared connection that support multiplexing over single socket
/// used for http2 and http3
Expand Down
2 changes: 1 addition & 1 deletion client/src/tls/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pub(crate) mod connector;

pub type TlsStream = Box<dyn xitca_io::io::AsyncIoDyn + Send + Sync>;
pub type TlsStream = Box<dyn xitca_io::io::AsyncIoDyn + Send>;
4 changes: 0 additions & 4 deletions client/src/uri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use crate::{error::InvalidUri, http::uri};
pub enum Uri<'a> {
Tcp(&'a uri::Uri),
Tls(&'a uri::Uri),
#[cfg(unix)]
Unix(&'a uri::Uri),
}

Expand All @@ -19,7 +18,6 @@ impl Deref for Uri<'_> {
match *self {
Self::Tcp(uri) => uri,
Self::Tls(uri) => uri,
#[cfg(unix)]
Self::Unix(uri) => uri,
}
}
Expand All @@ -33,7 +31,6 @@ impl<'a> Uri<'a> {
(None, _, _) => Err(InvalidUri::MissingScheme),
(Some("http" | "ws"), _, _) => Ok(Uri::Tcp(uri)),
(Some("https" | "wss"), _, _) => Ok(Uri::Tls(uri)),
#[cfg(unix)]
(Some("unix"), _, _) => Ok(Uri::Unix(uri)),
(Some(_), _, _) => Err(InvalidUri::UnknownScheme),
}
Expand Down Expand Up @@ -62,7 +59,6 @@ mod test {
let _ = Uri::try_parse(&uri).unwrap();
}

#[cfg(unix)]
#[test]
fn uds_parse() {
let uri = uri::Uri::from_static("unix://tmp/foo.socket");
Expand Down
11 changes: 11 additions & 0 deletions postgres/src/proxy.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
//! proxy serves as a sample implementation of server side traffic forwarder
//! between a xitca-postgres Client with `quic` feature enabled and the postgres
//! database server.
use std::{
collections::HashSet,
error, fs,
Expand Down Expand Up @@ -119,10 +123,17 @@ fn cfg_from_cert(cert: impl AsRef<Path>, key: impl AsRef<Path>) -> Result<Server
async fn listen_task(conn: Incoming, addr: SocketAddr) -> Result<(), Error> {
let conn = conn.await?;

// bridge quic client connection to tcp connection to database.
let mut upstream = TcpStream::connect(addr).await?;

// the proxy does not multiplex over streams from a quic client connection but it's not a hard
// requirement.
// an alternative proxy implementation can multiplex tcp socket connection to database with
// additional bidirectional stream.
let (mut tx, mut rx) = conn.accept_bi().await?;

// loop and copy bytes between the quic stream and tcp socket.

let mut buf = [0; 4096];

loop {
Expand Down

0 comments on commit d6610b1

Please sign in to comment.