Skip to content

Commit

Permalink
expose Driver enum and GenericDriver to public.
Browse files Browse the repository at this point in the history
  • Loading branch information
fakeshadow committed Apr 4, 2024
1 parent 30301ae commit 0eebca4
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 79 deletions.
101 changes: 29 additions & 72 deletions postgres/src/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,27 @@ where
/// tokio::spawn(drv.into_future());
/// }
/// ```
pub struct Driver {
inner: _Driver,
// TODO: use Box<dyn AsyncIterator> when life time GAT is object safe.
pub enum Driver {
Tcp(GenericDriver<TcpStream>),
Dynamic(GenericDriver<Box<dyn AsyncIoDyn + Send>>),
#[cfg(feature = "tls")]
Tls(GenericDriver<TlsStream<ClientConnection, TcpStream>>),
#[cfg(unix)]
Unix(GenericDriver<UnixStream>),
#[cfg(all(unix, feature = "tls"))]
UnixTls(GenericDriver<TlsStream<ClientConnection, UnixStream>>),
#[cfg(feature = "quic")]
Quic(GenericDriver<crate::driver::quic::QuicStream>),
}

impl Driver {
#[cfg(feature = "io-uring")]
/// downcast [Driver] to IoUringDriver if it's Tcp variant.
/// IoUringDriver can not be a new variant of Dirver as it's !Send.
pub fn try_into_io_uring_tcp(self) -> io_uring::IoUringDriver<xitca_io::net::io_uring::TcpStream> {
match self.inner {
_Driver::Tcp(drv) => {
match self {
Self::Tcp(drv) => {
let std = drv.io.into_std().unwrap();
let tcp = xitca_io::net::io_uring::TcpStream::from_std(std);
io_uring::IoUringDriver::new(
Expand All @@ -105,94 +115,40 @@ impl Driver {
}
}

pub(super) fn tcp(drv: GenericDriver<TcpStream>) -> Self {
Self {
inner: _Driver::Tcp(drv),
}
}

pub(super) fn dynamic(drv: GenericDriver<Box<dyn AsyncIoDyn + Send>>) -> Self {
Self {
inner: _Driver::Dynamic(drv),
}
}

#[cfg(feature = "tls")]
pub(super) fn tls(drv: GenericDriver<TlsStream<ClientConnection, TcpStream>>) -> Self {
Self {
inner: _Driver::Tls(drv),
}
}

#[cfg(unix)]
pub(super) fn unix(drv: GenericDriver<UnixStream>) -> Self {
Self {
inner: _Driver::Unix(drv),
}
}

#[cfg(all(unix, feature = "tls"))]
pub(super) fn unix_tls(drv: GenericDriver<TlsStream<ClientConnection, UnixStream>>) -> Self {
Self {
inner: _Driver::UnixTls(drv),
}
}

#[cfg(feature = "quic")]
pub(super) fn quic(drv: GenericDriver<crate::driver::quic::QuicStream>) -> Self {
Self {
inner: _Driver::Quic(drv),
}
}

// run till the connection is closed by Client.
async fn run_till_closed(self) {
let _ = match self.inner {
_Driver::Tcp(drv) => drv.run().await,
_Driver::Dynamic(drv) => drv.run().await,
let _ = match self {
Self::Tcp(drv) => drv.run().await,
Self::Dynamic(drv) => drv.run().await,
#[cfg(feature = "tls")]
_Driver::Tls(drv) => drv.run().await,
Self::Tls(drv) => drv.run().await,
#[cfg(unix)]
_Driver::Unix(drv) => drv.run().await,
Self::Unix(drv) => drv.run().await,
#[cfg(all(unix, feature = "tls"))]
_Driver::UnixTls(drv) => drv.run().await,
Self::UnixTls(drv) => drv.run().await,
#[cfg(feature = "quic")]
_Driver::Quic(drv) => drv.run().await,
Self::Quic(drv) => drv.run().await,
};
}
}

// TODO: use Box<dyn AsyncIterator> when life time GAT is object safe.
enum _Driver {
Tcp(GenericDriver<TcpStream>),
Dynamic(GenericDriver<Box<dyn AsyncIoDyn + Send>>),
#[cfg(feature = "tls")]
Tls(GenericDriver<TlsStream<ClientConnection, TcpStream>>),
#[cfg(unix)]
Unix(GenericDriver<UnixStream>),
#[cfg(all(unix, feature = "tls"))]
UnixTls(GenericDriver<TlsStream<ClientConnection, UnixStream>>),
#[cfg(feature = "quic")]
Quic(GenericDriver<crate::driver::quic::QuicStream>),
}

impl AsyncLendingIterator for Driver {
type Ok<'i> = backend::Message where Self: 'i;
type Err = Error;

#[inline]
async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
match self.inner {
_Driver::Tcp(ref mut drv) => drv.try_next().await,
_Driver::Dynamic(ref mut drv) => drv.try_next().await,
match self {
Self::Tcp(ref mut drv) => drv.try_next().await,
Self::Dynamic(ref mut drv) => drv.try_next().await,
#[cfg(feature = "tls")]
_Driver::Tls(ref mut drv) => drv.try_next().await,
Self::Tls(ref mut drv) => drv.try_next().await,
#[cfg(unix)]
_Driver::Unix(ref mut drv) => drv.try_next().await,
Self::Unix(ref mut drv) => drv.try_next().await,
#[cfg(all(unix, feature = "tls"))]
_Driver::UnixTls(ref mut drv) => drv.try_next().await,
Self::UnixTls(ref mut drv) => drv.try_next().await,
#[cfg(feature = "quic")]
_Driver::Quic(ref mut drv) => drv.try_next().await,
Self::Quic(ref mut drv) => drv.try_next().await,
}
}
}
Expand All @@ -206,6 +162,7 @@ impl IntoFuture for Driver {
}
}

// helper trait for interacting with io driver directly.
pub(crate) trait Drive: Send {
fn send(&mut self, msg: BytesMut) -> impl Future<Output = Result<(), Error>> + Send;

Expand Down
12 changes: 6 additions & 6 deletions postgres/src/driver/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub(super) async fn connect(host: Host, cfg: &mut Config) -> Result<(DriverTx, D
let io = super::tls::connect_tls(io, host, cfg).await?;
let (mut drv, tx) = GenericDriver::new(io);
prepare_session(&mut drv, cfg).await?;
Ok((tx, Driver::tls(drv)))
Ok((tx, Driver::Tls(drv)))
}
#[cfg(not(feature = "tls"))]
{
Expand All @@ -41,7 +41,7 @@ pub(super) async fn connect(host: Host, cfg: &mut Config) -> Result<(DriverTx, D
} else {
let (mut drv, tx) = GenericDriver::new(io);
prepare_session(&mut drv, cfg).await?;
Ok((tx, Driver::tcp(drv)))
Ok((tx, Driver::Tcp(drv)))
}
}
Host::Unix(ref _host) => {
Expand All @@ -55,7 +55,7 @@ pub(super) async fn connect(host: Host, cfg: &mut Config) -> Result<(DriverTx, D
let io = super::tls::connect_tls(io, host.as_ref(), cfg).await?;
let (mut drv, tx) = GenericDriver::new(io);
prepare_session(&mut drv, cfg).await?;
Ok((tx, Driver::unix_tls(drv)))
Ok((tx, Driver::UnixTls(drv)))
}
#[cfg(not(feature = "tls"))]
{
Expand All @@ -64,7 +64,7 @@ pub(super) async fn connect(host: Host, cfg: &mut Config) -> Result<(DriverTx, D
} else {
let (mut drv, tx) = GenericDriver::new(io);
prepare_session(&mut drv, cfg).await?;
Ok((tx, Driver::unix(drv)))
Ok((tx, Driver::Unix(drv)))
}
}

Expand All @@ -79,7 +79,7 @@ pub(super) async fn connect(host: Host, cfg: &mut Config) -> Result<(DriverTx, D
let io = super::quic::connect_quic(_host, cfg.get_ports()).await?;
let (mut drv, tx) = GenericDriver::new(io);
prepare_session(&mut drv, cfg).await?;
Ok((tx, Driver::quic(drv)))
Ok((tx, Driver::Quic(drv)))
}
#[cfg(not(feature = "quic"))]
{
Expand All @@ -97,7 +97,7 @@ where
{
let (mut drv, tx) = GenericDriver::new(Box::new(io) as _);
prepare_session(&mut drv, cfg).await?;
Ok((tx, Driver::dynamic(drv)))
Ok((tx, Driver::Dynamic(drv)))
}

async fn connect_tcp(host: &str, ports: &[u16]) -> Result<TcpStream, Error> {
Expand Down
2 changes: 1 addition & 1 deletion postgres/src/driver/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ impl DriverTx {
}
}

pub(crate) struct GenericDriver<Io> {
pub struct GenericDriver<Io> {
pub(crate) io: Io,
pub(crate) write_buf: WriteBuf,
pub(crate) read_buf: PagedBytesMut,
Expand Down

0 comments on commit 0eebca4

Please sign in to comment.