Skip to content

Commit

Permalink
add shared pg client that able to renew it's internal state. (#973)
Browse files Browse the repository at this point in the history
  • Loading branch information
fakeshadow authored Mar 9, 2024
1 parent bc4f0a3 commit 3da3650
Show file tree
Hide file tree
Showing 15 changed files with 375 additions and 143 deletions.
2 changes: 1 addition & 1 deletion postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ fallible-iterator = "0.2"
percent-encoding = "2"
postgres-protocol = "0.6.5"
postgres-types = "0.2"
tokio = { version = "1.30", features = ["net", "sync"] }
tokio = { version = "1.30", features = ["net", "rt", "sync", "time"] }
tracing = { version = "0.1.40", default-features = false }

# tls
Expand Down
6 changes: 3 additions & 3 deletions postgres/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,15 @@ impl Drop for Client {
};

if let Some(stmt) = type_info {
drop(stmt.into_guarded(self));
drop(stmt.into_guarded(&*self));
}

if let Some(stmt) = typeinfo_composite {
drop(stmt.into_guarded(self));
drop(stmt.into_guarded(&*self));
}

if let Some(stmt) = typeinfo_enum {
drop(stmt.into_guarded(self));
drop(stmt.into_guarded(&*self));
}
}
}
37 changes: 6 additions & 31 deletions postgres/src/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,15 @@ impl Driver {
async fn run_till_closed(self) {
#[cfg(not(feature = "quic"))]
{
let mut this = self;
while let Err(e) = match this.inner {
_Driver::Tcp(ref mut drv) => drv.run().await,
let _ = match self.inner {
_Driver::Tcp(drv) => drv.run().await,
#[cfg(feature = "tls")]
_Driver::Tls(ref mut drv) => drv.run().await,
_Driver::Tls(drv) => drv.run().await,
#[cfg(unix)]
_Driver::Unix(ref mut drv) => drv.run().await,
_Driver::Unix(drv) => drv.run().await,
#[cfg(all(unix, feature = "tls"))]
_Driver::UnixTls(ref mut drv) => drv.run().await,
} {
while this.reconnect(&e).await.is_err() {}
}
_Driver::UnixTls(drv) => drv.run().await,
};
}

#[cfg(feature = "quic")]
Expand All @@ -114,28 +111,6 @@ impl Driver {

#[cfg(not(feature = "quic"))]
impl Driver {
/// reconnect to server with a fresh connection and state. Driver's associated
/// [Client] is able to be re-used for the fresh connection.
///
/// MUST be called when `<Self as AsyncLendingIterator>::try_next` emit [Error].
/// All in flight database query and response will be lost in the process.
pub async fn reconnect(&mut self, _: &Error) -> Result<(), Error> {
let (_, Driver { inner: inner_new, .. }) = connect(&mut self.config).await?;

match (&mut self.inner, inner_new) {
(_Driver::Tcp(drv), _Driver::Tcp(drv_new)) => drv.replace(drv_new),
#[cfg(feature = "tls")]
(_Driver::Tls(drv), _Driver::Tls(drv_new)) => drv.replace(drv_new),
#[cfg(unix)]
(_Driver::Unix(drv), _Driver::Unix(drv_new)) => drv.replace(drv_new),
#[cfg(all(unix, feature = "tls"))]
(_Driver::UnixTls(drv), _Driver::UnixTls(drv_new)) => drv.replace(drv_new),
_ => unreachable!("reconnect should always yield the same type of generic driver"),
};

Ok(())
}

pub(super) fn tcp(drv: GenericDriver<TcpStream>, config: Config) -> Self {
Self {
inner: _Driver::Tcp(drv),
Expand Down
28 changes: 6 additions & 22 deletions postgres/src/driver/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,6 @@ where
)
}

#[cfg(not(feature = "quic"))]
pub(crate) fn replace(&mut self, other: Self) {
let Self {
io,
write_buf,
read_buf,
res,
state,
..
} = other;
self.io = io;
self.write_buf = write_buf;
self.read_buf = read_buf;
self.res = res;
self.state = state;
}

async fn _try_next(&mut self) -> Result<Option<backend::Message>, Error> {
loop {
if let Some(msg) = self.try_decode()? {
Expand Down Expand Up @@ -142,7 +125,7 @@ where

// TODO: remove this feature gate.
#[cfg(not(feature = "quic"))]
pub(crate) async fn run(&mut self) -> Result<(), Error> {
pub(crate) async fn run(mut self) -> Result<(), Error> {
while self._try_next().await?.is_some() {}
Ok(())
}
Expand Down Expand Up @@ -180,10 +163,11 @@ where
while let Some(res) = ResponseMessage::try_from_buf(self.read_buf.get_mut())? {
match res {
ResponseMessage::Normal { buf, complete } => {
let front = self.res.front_mut().expect("out of bound must not happen");
front.send(buf);
if front.complete(complete) {
self.res.pop_front();
if let Some(front) = self.res.front_mut() {
front.send(buf);
if front.complete(complete) {
self.res.pop_front();
}
}
}
ResponseMessage::Async(msg) => return Ok(Some(msg)),
Expand Down
4 changes: 2 additions & 2 deletions postgres/src/driver/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use xitca_io::{

use crate::{
config::{Config, Host, SslMode},
error::{unexpected_eof_err, write_zero_err, Error},
error::{unexpected_eof_err, Error},
session::prepare_session,
};

Expand Down Expand Up @@ -147,7 +147,7 @@ where
while !buf.is_empty() {
io.ready(Interest::WRITABLE).await?;
match io.write(&buf) {
Ok(0) => return Err(write_zero_err()),
Ok(0) => return Err(unexpected_eof_err()),
Ok(n) => buf.advance(n),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) => return Err(e),
Expand Down
7 changes: 2 additions & 5 deletions postgres/src/driver/raw/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ use postgres_protocol::message::backend;
use tokio::sync::mpsc::unbounded_channel;
use xitca_io::bytes::BytesMut;

use crate::{
driver::codec::ResponseReceiver,
error::{unexpected_eof_err, Error},
};
use crate::{driver::codec::ResponseReceiver, error::Error};

pub struct Response {
rx: ResponseReceiver,
Expand All @@ -35,7 +32,7 @@ impl Response {
pub(crate) fn recv(&mut self) -> impl Future<Output = Result<backend::Message, Error>> + '_ {
poll_fn(|cx| {
if self.buf.is_empty() {
self.buf = ready!(self.rx.poll_recv(cx)).ok_or_else(unexpected_eof_err)?;
self.buf = ready!(self.rx.poll_recv(cx)).ok_or_else(|| Error::DriverDown(BytesMut::new()))?;
}

let res = match backend::Message::parse(&mut self.buf)?.expect("must not parse message from empty buffer.")
Expand Down
26 changes: 14 additions & 12 deletions postgres/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use core::{convert::Infallible, fmt};
use std::{error, io};

use tokio::sync::mpsc::error::SendError;
use xitca_io::bytes::BytesMut;

use crate::driver::codec::Request;

use super::from_sql::FromSqlError;

Expand All @@ -15,6 +18,7 @@ pub enum Error {
Io(io::Error),
FromSql(FromSqlError),
InvalidColumnIndex(String),
DriverDown(BytesMut),
ToDo,
}

Expand All @@ -27,6 +31,7 @@ impl fmt::Display for Error {
Self::Io(ref e) => fmt::Display::fmt(e, f),
Self::FromSql(ref e) => fmt::Display::fmt(e, f),
Self::InvalidColumnIndex(ref name) => write!(f, "invalid column {name}"),
Self::DriverDown(_) => f.write_str("Driver is down. check Driver's async task output for reason"),
Self::ToDo => f.write_str("error informant is yet implemented"),
}
}
Expand All @@ -52,9 +57,15 @@ impl From<FromSqlError> for Error {
}
}

impl<T> From<SendError<T>> for Error {
fn from(_: SendError<T>) -> Self {
Error::from(write_zero_err())
impl From<SendError<BytesMut>> for Error {
fn from(e: SendError<BytesMut>) -> Self {
Self::DriverDown(e.0)
}
}

impl From<SendError<Request>> for Error {
fn from(e: SendError<Request>) -> Self {
Self::DriverDown(e.0.msg)
}
}

Expand Down Expand Up @@ -114,12 +125,3 @@ pub(crate) fn unexpected_eof_err() -> io::Error {
"zero byte read. remote close connection unexpectedly",
)
}

#[cold]
#[inline(never)]
pub(crate) fn write_zero_err() -> io::Error {
io::Error::new(
io::ErrorKind::WriteZero,
"zero byte written. remote close connection unexpectedly",
)
}
2 changes: 2 additions & 0 deletions postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod config;
mod driver;
mod from_sql;
mod iter;
mod pool;
mod prepare;
mod query;
mod session;
Expand All @@ -34,6 +35,7 @@ pub use self::{
error::Error,
from_sql::FromSqlExt,
iter::AsyncLendingIterator,
pool::SharedClient,
query::{RowSimpleStream, RowStream},
};

Expand Down
Loading

0 comments on commit 3da3650

Please sign in to comment.