Skip to content

Commit

Permalink
add reconnect api to driver. (#971)
Browse files Browse the repository at this point in the history
* add reconnect api to driver.

* clippy fix.

* fmt fix.

* fix io-uring feature build.

* fix single thread feature.
  • Loading branch information
fakeshadow authored Mar 6, 2024
1 parent f93419b commit df47136
Show file tree
Hide file tree
Showing 8 changed files with 276 additions and 205 deletions.
93 changes: 68 additions & 25 deletions postgres/src/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ use xitca_tls::rustls::{ClientConnection, TlsStream};
#[cfg(unix)]
use xitca_io::net::UnixStream;

pub(super) async fn connect(mut cfg: Config) -> Result<(Client, Driver), Error> {
pub(super) async fn connect(cfg: &mut Config) -> Result<(Client, Driver), Error> {
let mut err = None;
let hosts = cfg.get_hosts().to_vec();
for host in hosts {
match _connect(host, &mut cfg).await {
Ok(t) => return Ok(t),
match _connect(host, cfg).await {
Ok((tx, drv)) => return Ok((Client::new(tx), drv)),
Err(e) => err = Some(e),
}
}
Expand Down Expand Up @@ -80,43 +80,100 @@ pub(super) async fn connect(mut cfg: Config) -> Result<(Client, Driver), Error>
/// ```
pub struct Driver {
inner: _Driver,
#[allow(dead_code)]
config: Config,
}

impl Driver {
// run till the connection is closed by Client.
async fn run_till_closed(self) {
#[cfg(not(feature = "quic"))]
{
let mut this = self;
while let Err(e) = match this.inner {
_Driver::Tcp(ref mut drv) => drv.run().await,
#[cfg(feature = "tls")]
_Driver::Tls(ref mut drv) => drv.run().await,
#[cfg(unix)]
_Driver::Unix(ref mut drv) => drv.run().await,
#[cfg(all(unix, feature = "tls"))]
_Driver::UnixTls(ref mut drv) => drv.run().await,
} {
while this.reconnect(&e).await.is_err() {}
}
}

#[cfg(feature = "quic")]
match self.inner {
_Driver::Quic(drv) => {
let _ = drv.run().await;
}
}
}
}

#[cfg(not(feature = "quic"))]
impl Driver {
pub(super) fn tcp(drv: GenericDriver<TcpStream>) -> Self {
/// reconnect to server with a fresh connection and state. Driver's associated
/// [Client] is able to be re-used for the fresh connection.
///
/// MUST be called when `<Self as AsyncLendingIterator>::try_next` emit [Error].
/// All in flight database query and response will be lost in the process.
pub async fn reconnect(&mut self, _: &Error) -> Result<(), Error> {
let (_, Driver { inner: inner_new, .. }) = connect(&mut self.config).await?;

match (&mut self.inner, inner_new) {
(_Driver::Tcp(drv), _Driver::Tcp(drv_new)) => drv.replace(drv_new),
#[cfg(feature = "tls")]
(_Driver::Tls(drv), _Driver::Tls(drv_new)) => drv.replace(drv_new),
#[cfg(unix)]
(_Driver::Unix(drv), _Driver::Unix(drv_new)) => drv.replace(drv_new),
#[cfg(all(unix, feature = "tls"))]
(_Driver::UnixTls(drv), _Driver::UnixTls(drv_new)) => drv.replace(drv_new),
_ => unreachable!("reconnect should always yield the same type of generic driver"),
};

Ok(())
}

pub(super) fn tcp(drv: GenericDriver<TcpStream>, config: Config) -> Self {
Self {
inner: _Driver::Tcp(drv),
config,
}
}

#[cfg(feature = "tls")]
pub(super) fn tls(drv: GenericDriver<TlsStream<ClientConnection, TcpStream>>) -> Self {
pub(super) fn tls(drv: GenericDriver<TlsStream<ClientConnection, TcpStream>>, config: Config) -> Self {
Self {
inner: _Driver::Tls(drv),
config,
}
}

#[cfg(unix)]
pub(super) fn unix(drv: GenericDriver<UnixStream>) -> Self {
pub(super) fn unix(drv: GenericDriver<UnixStream>, config: Config) -> Self {
Self {
inner: _Driver::Unix(drv),
config,
}
}

#[cfg(all(unix, feature = "tls"))]
pub(super) fn unix_tls(drv: GenericDriver<TlsStream<ClientConnection, UnixStream>>) -> Self {
pub(super) fn unix_tls(drv: GenericDriver<TlsStream<ClientConnection, UnixStream>>, config: Config) -> Self {
Self {
inner: _Driver::UnixTls(drv),
config,
}
}
}

#[cfg(feature = "quic")]
impl Driver {
pub(super) fn quic(drv: QuicDriver) -> Self {
pub(super) fn quic(drv: QuicDriver, config: Config) -> Self {
Self {
inner: _Driver::Quic(drv),
config,
}
}
}
Expand Down Expand Up @@ -163,25 +220,11 @@ impl AsyncLendingIterator for Driver {
}

impl IntoFuture for Driver {
type Output = Result<(), Error>;
type Output = ();
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;

fn into_future(self) -> Self::IntoFuture {
#[cfg(not(feature = "quic"))]
match self.inner {
_Driver::Tcp(drv) => Box::pin(drv.run()),
#[cfg(feature = "tls")]
_Driver::Tls(drv) => Box::pin(drv.run()),
#[cfg(unix)]
_Driver::Unix(drv) => Box::pin(drv.run()),
#[cfg(all(unix, feature = "tls"))]
_Driver::UnixTls(drv) => Box::pin(drv.run()),
}

#[cfg(feature = "quic")]
match self.inner {
_Driver::Quic(drv) => Box::pin(drv.run()),
}
Box::pin(self.run_till_closed())
}
}

Expand Down Expand Up @@ -210,7 +253,7 @@ impl Driver {
let tcp = xitca_io::net::io_uring::TcpStream::from_std(std);
io_uring::IoUringDriver::new(
tcp,
drv.rx.unwrap(),
drv.rx,
drv.write_buf.into_inner(),
drv.read_buf.into_inner(),
drv.res,
Expand Down
79 changes: 53 additions & 26 deletions postgres/src/driver/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,14 @@ pub(crate) struct GenericDriver<Io> {
pub(crate) io: Io,
pub(crate) write_buf: WriteBuf,
pub(crate) read_buf: PagedBytesMut,
pub(crate) rx: Option<GenericDriverRx>,
pub(crate) rx: GenericDriverRx,
pub(crate) res: VecDeque<ResponseSender>,
state: DriverState,
}

enum DriverState {
Running,
Closing(Option<io::Error>),
}

impl<Io> GenericDriver<Io>
Expand All @@ -46,14 +52,32 @@ where
io,
write_buf: WriteBuf::new(),
read_buf: PagedBytesMut::new(),
rx: Some(rx),
rx,
res: VecDeque::new(),
state: DriverState::Running,
},
tx,
)
}

pub(crate) async fn try_next(&mut self) -> Result<Option<backend::Message>, Error> {
#[cfg(not(feature = "quic"))]
pub(crate) fn replace(&mut self, other: Self) {
let Self {
io,
write_buf,
read_buf,
res,
state,
..
} = other;
self.io = io;
self.write_buf = write_buf;
self.read_buf = read_buf;
self.res = res;
self.state = state;
}

async fn _try_next(&mut self) -> Result<Option<backend::Message>, Error> {
loop {
if let Some(msg) = self.try_decode()? {
return Ok(Some(msg));
Expand All @@ -65,19 +89,20 @@ where
Interest::READABLE
};

let select = match self.rx {
Some(ref mut rx) => {
let select = match self.state {
DriverState::Running => {
let ready = self.io.ready(interest);
rx.recv().select(ready).await
self.rx.recv().select(ready).await
}
None => {
DriverState::Closing(ref mut e) => {
if !interest.is_writable() && self.res.is_empty() {
// no interest to write to io and all response have been finished so
// shutdown io and exit.
// if there is a better way to exhaust potential remaining backend message
// please file an issue.
poll_fn(|cx| Pin::new(&mut self.io).poll_shutdown(cx)).await?;
return Ok(None);

return e.take().map(|e| Err(e.into())).transpose();
}
let ready = self.io.ready(interest);
SelectOutput::B(ready.await)
Expand All @@ -95,25 +120,34 @@ where
if ready.is_readable() {
self.try_read()?;
}
if ready.is_writable() && self.try_write().is_err() {
// write failed as server stopped reading.
// drop channel so all pending request in it can be notified.
self.rx = None;
if ready.is_writable() {
if let Err(e) = self.try_write() {
error!("server closed read half unexpectedly: {e}");

// when write error occur the driver would go into half close state(read only).
// clearing write_buf would drop all pending requests in it and hint the driver
// no future Interest::WRITABLE should be passed to AsyncIo::ready method.
self.write_buf.clear();

// enter closed state and no more request would be received from channel.
// requests inside it would eventually be dropped after shutdown completed.
self.state = DriverState::Closing(Some(e));
}
}
}
SelectOutput::A(None) => self.rx = None,
SelectOutput::A(None) => self.state = DriverState::Closing(None),
}
}
}

// TODO: remove this feature gate.
#[cfg(not(feature = "quic"))]
pub(crate) async fn run(mut self) -> Result<(), Error> {
while self.try_next().await?.is_some() {}
pub(crate) async fn run(&mut self) -> Result<(), Error> {
while self._try_next().await?.is_some() {}
Ok(())
}

pub(crate) async fn send(&mut self, msg: BytesMut) -> Result<(), Error> {
async fn send(&mut self, msg: BytesMut) -> Result<(), Error> {
self.write_buf_extend(&msg);
loop {
self.try_write()?;
Expand Down Expand Up @@ -151,14 +185,7 @@ where
}

fn try_write(&mut self) -> io::Result<()> {
self.write_buf.do_io(&mut self.io).map_err(|e| {
// when write error occur the driver would go into half close state(read only).
// clearing write_buf would drop all pending requests in it and hint the driver no
// future Interest::READABLE should be passed to AsyncIo::ready method.
self.write_buf.clear();
error!("server closed read half unexpectedly: {e}");
e
})
self.write_buf.do_io(&mut self.io)
}

fn try_decode(&mut self) -> Result<Option<backend::Message>, Error> {
Expand Down Expand Up @@ -186,8 +213,8 @@ where
type Err = Error;

#[inline]
async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
self.try_next().await
fn try_next(&mut self) -> impl Future<Output = Result<Option<Self::Ok<'_>>, Self::Err>> + Send {
self._try_next()
}
}

Expand Down
9 changes: 4 additions & 5 deletions postgres/src/driver/quic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ use quinn_proto::ConnectionError;
use xitca_io::bytes::{Bytes, BytesMut};

use crate::{
client::Client,
config::{Config, Host},
error::{unexpected_eof_err, Error},
iter::AsyncLendingIterator,
session::prepare_session,
};

use super::{Drive, Driver};
Expand Down Expand Up @@ -82,16 +82,15 @@ impl ClientTx {

#[cold]
#[inline(never)]
pub(super) async fn _connect(host: Host, cfg: &mut Config) -> Result<(Client, Driver), Error> {
pub(super) async fn _connect(host: Host, cfg: &Config) -> Result<(ClientTx, Driver), Error> {
match host {
Host::Udp(ref host) => {
let tx = connect_quic(host, cfg.get_ports()).await?;
let streams = tx.inner.open_bi().await.unwrap();
let mut drv = QuicDriver::new(streams);
let mut cli = Client::new(tx);
cli.prepare_session(&mut drv, cfg).await?;
prepare_session(&mut drv, cfg).await?;
drv.close_tx().await;
Ok((cli, Driver::quic(drv)))
Ok((tx, Driver::quic(drv, cfg.clone())))
}
_ => unreachable!(),
}
Expand Down
Loading

0 comments on commit df47136

Please sign in to comment.