From 465cd7aab1fba67156ab21070560d2eaf72ed4d2 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Fri, 15 Mar 2024 01:49:19 +0800 Subject: [PATCH 1/4] rework crate error handling. --- postgres/src/config.rs | 24 ++--- postgres/src/driver/quic.rs | 8 +- postgres/src/driver/quic/response.rs | 2 +- postgres/src/driver/raw.rs | 2 +- postgres/src/driver/raw/response.rs | 9 +- postgres/src/driver/raw/tls.rs | 4 +- postgres/src/error.rs | 142 +++++++++++++++++++++------ postgres/src/pipeline.rs | 6 +- postgres/src/pool.rs | 70 ++++++++----- postgres/src/prepare.rs | 12 +-- postgres/src/query/base.rs | 8 +- postgres/src/query/decode.rs | 2 +- postgres/src/query/encode.rs | 6 +- postgres/src/query/simple.rs | 2 +- postgres/src/row/types.rs | 11 ++- postgres/src/session.rs | 14 +-- 16 files changed, 215 insertions(+), 107 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 3c29e6d5..60c2cf3e 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -245,7 +245,7 @@ impl Config { "disable" => SslMode::Disable, "prefer" => SslMode::Prefer, "require" => SslMode::Require, - _ => return Err(Error::ToDo), + _ => return Err(Error::todo()), }; self.ssl_mode(mode); } @@ -259,7 +259,7 @@ impl Config { let port = if port.is_empty() { 5432 } else { - port.parse().map_err(|_| Error::ToDo)? + port.parse().map_err(|_| Error::todo())? }; self.port(port); } @@ -269,13 +269,13 @@ impl Config { "any" => TargetSessionAttrs::Any, "read-write" => TargetSessionAttrs::ReadWrite, _ => { - return Err(Error::ToDo); + return Err(Error::todo()); } }; self.target_session_attrs(target_session_attrs); } _ => { - return Err(Error::ToDo); + return Err(Error::todo()); } } @@ -375,9 +375,9 @@ impl<'a> Parser<'a> { Some((_, c)) if c == target => Ok(()), Some((i, c)) => { let _m = format!("unexpected character at byte {i}: expected `{target}` but got `{c}`"); - Err(Error::ToDo) + Err(Error::todo()) } - None => Err(Error::ToDo), + None => Err(Error::todo()), } } @@ -436,7 +436,7 @@ impl<'a> Parser<'a> { } if value.is_empty() { - return Err(Error::ToDo); + return Err(Error::todo()); } Ok(value) @@ -460,7 +460,7 @@ impl<'a> Parser<'a> { } } - Err(Error::ToDo) + Err(Error::todo()) } fn parameter(&mut self) -> Result, Error> { @@ -566,7 +566,7 @@ impl<'a> UrlParser<'a> { let (host, port) = if chunk.starts_with('[') { let idx = match chunk.find(']') { Some(idx) => idx, - None => return Err(Error::ToDo), + None => return Err(Error::todo()), }; let host = &chunk[1..idx]; @@ -576,7 +576,7 @@ impl<'a> UrlParser<'a> { } else if remaining.is_empty() { None } else { - return Err(Error::ToDo); + return Err(Error::todo()); }; (host, port) @@ -620,7 +620,7 @@ impl<'a> UrlParser<'a> { while !self.s.is_empty() { let key = match self.take_until(&['=']) { Some(key) => self.decode(key)?, - None => return Err(Error::ToDo), + None => return Err(Error::todo()), }; self.eat_byte(); @@ -651,6 +651,6 @@ impl<'a> UrlParser<'a> { fn decode(&self, s: &'a str) -> Result, Error> { percent_encoding::percent_decode(s.as_bytes()) .decode_utf8() - .map_err(|_| Error::ToDo) + .map_err(|_| Error::todo()) } } diff --git a/postgres/src/driver/quic.rs b/postgres/src/driver/quic.rs index 2bfab816..f5f923f7 100644 --- a/postgres/src/driver/quic.rs +++ b/postgres/src/driver/quic.rs @@ -150,9 +150,9 @@ async fn connect_quic(host: &str, ports: &[u16]) -> Result { match endpoint.connect(addr, host) { Ok(conn) => match conn.await { Ok(inner) => return Ok(ClientTx::new(inner)), - Err(_) => err = Some(Error::ToDo), + Err(_) => err = Some(Error::todo()), }, - Err(_) => err = Some(Error::ToDo), + Err(_) => err = Some(Error::todo()), } } @@ -186,7 +186,7 @@ impl QuicDriver { .read_chunk(4096, true) .await .map(|c| c.map(|c| c.bytes)) - .map_err(|_| Error::ToDo) + .map_err(|_| Error::todo()) .transpose() } @@ -201,7 +201,7 @@ impl QuicDriver { Ok(None) | Err(ReadError::ConnectionLost(ConnectionError::ApplicationClosed(_))) | Err(ReadError::ConnectionLost(ConnectionError::LocallyClosed)) => return Ok(None), - Err(_) => return Err(Error::ToDo), + Err(_) => return Err(Error::todo()), } } } diff --git a/postgres/src/driver/quic/response.rs b/postgres/src/driver/quic/response.rs index e9f6db69..d5acf3e6 100644 --- a/postgres/src/driver/quic/response.rs +++ b/postgres/src/driver/quic/response.rs @@ -22,7 +22,7 @@ impl Response { loop { match backend::Message::parse(&mut self.buf)? { // TODO: error response. - Some(backend::Message::ErrorResponse(_body)) => return Err(Error::ToDo), + Some(backend::Message::ErrorResponse(_body)) => return Err(Error::todo()), Some(msg) => return Ok(msg), None => { let chunk = self diff --git a/postgres/src/driver/raw.rs b/postgres/src/driver/raw.rs index c118d8c2..1917d42e 100644 --- a/postgres/src/driver/raw.rs +++ b/postgres/src/driver/raw.rs @@ -131,7 +131,7 @@ where match cfg.get_ssl_mode() { SslMode::Disable => Ok(false), mode => match (_should_connect_tls(io).await?, mode) { - (false, SslMode::Require) => Err(Error::ToDo), + (false, SslMode::Require) => Err(Error::todo()), (bool, _) => Ok(bool), }, } diff --git a/postgres/src/driver/raw/response.rs b/postgres/src/driver/raw/response.rs index 050e5fd1..b39eb9d7 100644 --- a/postgres/src/driver/raw/response.rs +++ b/postgres/src/driver/raw/response.rs @@ -6,7 +6,10 @@ use core::{ use postgres_protocol::message::backend; use xitca_io::bytes::BytesMut; -use crate::{driver::codec::ResponseReceiver, error::Error}; +use crate::{ + driver::codec::ResponseReceiver, + error::{DriverDown, Error}, +}; pub struct Response { rx: ResponseReceiver, @@ -24,13 +27,13 @@ impl Response { pub(crate) fn recv(&mut self) -> impl Future> + '_ { poll_fn(|cx| { if self.buf.is_empty() { - self.buf = ready!(self.rx.poll_recv(cx)).ok_or_else(|| Error::DriverDown(BytesMut::new()))?; + self.buf = ready!(self.rx.poll_recv(cx)).ok_or_else(|| DriverDown(BytesMut::new()))?; } let res = match backend::Message::parse(&mut self.buf)?.expect("must not parse message from empty buffer.") { // TODO: error response. - backend::Message::ErrorResponse(_body) => Err(Error::ToDo), + backend::Message::ErrorResponse(_body) => Err(Error::todo()), msg => Ok(msg), }; diff --git a/postgres/src/driver/raw/tls.rs b/postgres/src/driver/raw/tls.rs index a774de98..82cc26b5 100644 --- a/postgres/src/driver/raw/tls.rs +++ b/postgres/src/driver/raw/tls.rs @@ -8,9 +8,9 @@ pub(super) async fn connect(io: Io, host: &str, cfg: &mut Config) -> Result< where Io: AsyncIo, { - let name = ServerName::try_from(host).map_err(|_| Error::ToDo)?.to_owned(); + let name = ServerName::try_from(host).map_err(|_| Error::todo())?.to_owned(); let config = dangerous_config(Vec::new()); - let session = ClientConnection::new(config, name).map_err(|_| Error::ToDo)?; + let session = ClientConnection::new(config, name).map_err(|_| Error::todo())?; let stream = TlsStream::handshake(io, session).await?; diff --git a/postgres/src/error.rs b/postgres/src/error.rs index 59a9b520..3d330406 100644 --- a/postgres/src/error.rs +++ b/postgres/src/error.rs @@ -1,4 +1,8 @@ -use core::{convert::Infallible, fmt}; +use core::{ + convert::Infallible, + fmt, mem, + ops::{Deref, DerefMut}, +}; use std::{error, io}; @@ -9,35 +13,100 @@ use crate::driver::codec::Request; use super::from_sql::FromSqlError; -#[non_exhaustive] -#[derive(Debug)] -pub enum Error { - Feature(FeatureError), - Authentication(AuthenticationError), - UnexpectedMessage, - Io(io::Error), - FromSql(FromSqlError), - InvalidColumnIndex(String), - DriverDown(BytesMut), - ToDo, +pub struct Error(Box); + +impl Error { + pub(crate) fn todo() -> Self { + Self("WIP error type placeholder".to_string().into()) + } + + pub(crate) fn unexpected() -> Self { + Self(Box::new(UnexpectedMessage)) + } + + #[cold] + #[inline(never)] + pub(crate) fn if_driver_down(&mut self) -> Option { + self.0.downcast_mut().map(mem::take) + } +} + +impl Deref for Error { + type Target = dyn error::Error + Send + Sync; + + fn deref(&self) -> &Self::Target { + &*self.0 + } +} + +impl DerefMut for Error { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self.0 + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.0, f) + } } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - Self::Feature(ref e) => fmt::Display::fmt(e, f), - Self::Authentication(ref e) => fmt::Display::fmt(e, f), - Self::UnexpectedMessage => f.write_str("unexpected message from server"), - Self::Io(ref e) => fmt::Display::fmt(e, f), - Self::FromSql(ref e) => fmt::Display::fmt(e, f), - Self::InvalidColumnIndex(ref name) => write!(f, "invalid column {name}"), - Self::DriverDown(_) => f.write_str("Driver is down. check Driver's async task output for reason"), - Self::ToDo => f.write_str("error informant is yet implemented"), - } + fmt::Display::fmt(&self.0, f) + } +} + +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + self.0.source() + } +} + +#[derive(Default)] +pub struct DriverDown(pub BytesMut); + +impl fmt::Debug for DriverDown { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DriverDown").finish() } } -impl error::Error for Error {} +impl fmt::Display for DriverDown { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Driver is down") + } +} + +impl error::Error for DriverDown {} + +impl From for Error { + fn from(e: DriverDown) -> Self { + Self(Box::new(e)) + } +} + +pub struct InvalidColumnIndex(pub String); + +impl fmt::Debug for InvalidColumnIndex { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("InvalidColumnIndex").finish() + } +} + +impl fmt::Display for InvalidColumnIndex { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "invalid column index: {}", self.0) + } +} + +impl error::Error for InvalidColumnIndex {} + +impl From for Error { + fn from(e: InvalidColumnIndex) -> Self { + Self(Box::new(e)) + } +} impl From for Error { fn from(e: Infallible) -> Self { @@ -47,25 +116,25 @@ impl From for Error { impl From for Error { fn from(e: io::Error) -> Self { - Self::Io(e) + Self(Box::new(e)) } } impl From for Error { fn from(e: FromSqlError) -> Self { - Self::FromSql(e) + Self(e) } } impl From> for Error { fn from(e: SendError) -> Self { - Self::DriverDown(e.0) + Self(Box::new(DriverDown(e.0))) } } impl From> for Error { fn from(e: SendError) -> Self { - Self::DriverDown(e.0.msg) + Self(Box::new(DriverDown(e.0.msg))) } } @@ -88,9 +157,11 @@ impl fmt::Display for AuthenticationError { } } +impl error::Error for AuthenticationError {} + impl From for Error { fn from(e: AuthenticationError) -> Self { - Self::Authentication(e) + Self(Box::new(e)) } } @@ -111,12 +182,25 @@ impl fmt::Display for FeatureError { } } +impl error::Error for FeatureError {} + impl From for Error { fn from(e: FeatureError) -> Self { - Self::Feature(e) + Self(Box::new(e)) } } +#[derive(Debug)] +pub struct UnexpectedMessage; + +impl fmt::Display for UnexpectedMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("unexpected message from database") + } +} + +impl error::Error for UnexpectedMessage {} + #[cold] #[inline(never)] pub(crate) fn unexpected_eof_err() -> io::Error { diff --git a/postgres/src/pipeline.rs b/postgres/src/pipeline.rs index 44b5fa6f..f67410f2 100644 --- a/postgres/src/pipeline.rs +++ b/postgres/src/pipeline.rs @@ -194,7 +194,7 @@ impl<'a> AsyncLendingIterator for PipelineStream<'a> { // item arrives. } backend::Message::ReadyForQuery(_) => {} - _ => return Err(Error::UnexpectedMessage), + _ => return Err(Error::unexpected()), } } @@ -230,7 +230,7 @@ impl PipelineItem<'_, '_> { self.finished = true; return crate::query::decode::body_to_affected_rows(&body); } - _ => return Err(Error::UnexpectedMessage), + _ => return Err(Error::unexpected()), } } } @@ -252,7 +252,7 @@ impl AsyncLendingIterator for PipelineItem<'_, '_> { return Row::try_new(columns, body, &mut self.stream.ranges).map(Some); } backend::Message::CommandComplete(_) => self.finished = true, - _ => return Err(Error::UnexpectedMessage), + _ => return Err(Error::unexpected()), } } diff --git a/postgres/src/pool.rs b/postgres/src/pool.rs index c9ddd609..193b169d 100644 --- a/postgres/src/pool.rs +++ b/postgres/src/pool.rs @@ -10,7 +10,7 @@ use crate::{ client::Client, config::Config, driver::connect, - error::Error, + error::{DriverDown, Error}, iter::slice_iter, statement::{Statement, StatementGuarded}, util::lock::Lock, @@ -134,11 +134,14 @@ impl SharedClient { { let cli = self.read().await; match cli.query_raw(stmt, params).await { - Err(Error::DriverDown(buf)) => { - drop(cli); - Box::pin(self.query_raw_slow(stmt, buf)).await - } - res => res, + Ok(res) => Ok(res), + Err(mut e) => match e.if_driver_down() { + Some(DriverDown(buf)) => { + drop(cli); + Box::pin(self.query_raw_slow(stmt, buf)).await + } + None => Err(e), + }, } } @@ -148,8 +151,11 @@ impl SharedClient { loop { self.reconnect().await; match self.read().await.query_buf(stmt, buf).await { - Err(Error::DriverDown(b)) => buf = b, - res => return res, + Ok(res) => return Ok(res), + Err(mut e) => match e.if_driver_down() { + Some(DriverDown(b)) => buf = b, + None => return Err(e), + }, } } } @@ -159,11 +165,14 @@ impl SharedClient { pub async fn query_simple(&self, stmt: &str) -> Result { let cli = self.read().await; match cli.query_simple(stmt).await { - Err(Error::DriverDown(buf)) => { - drop(cli); - Box::pin(self.query_simple_slow(buf)).await - } - res => res, + Ok(res) => Ok(res), + Err(mut e) => match e.if_driver_down() { + Some(DriverDown(buf)) => { + drop(cli); + Box::pin(self.query_simple_slow(buf)).await + } + None => Err(e), + }, } } @@ -171,8 +180,11 @@ impl SharedClient { loop { self.reconnect().await; match self.read().await.query_buf_simple(buf).await { - Err(Error::DriverDown(b)) => buf = b, - res => return res, + Ok(res) => return Ok(res), + Err(mut e) => match e.if_driver_down() { + Some(DriverDown(b)) => buf = b, + None => return Err(e), + }, } } } @@ -186,11 +198,13 @@ impl SharedClient { let cli = self.read().await; match cli._prepare(query, types).await { Ok(stmt) => return Ok(stmt.into_guarded(cli)), - Err(Error::DriverDown(_)) => { + Err(mut e) => { + if e.if_driver_down().is_none() { + return Err(e); + } drop(cli); Box::pin(self.reconnect()).await; } - Err(e) => return Err(e), } } } @@ -254,12 +268,14 @@ impl SharedClient { columns: pipe.columns, ranges: Vec::new(), }), - Err(Error::DriverDown(buf)) => { - drop(cli); - pipe.buf = buf; - Box::pin(self.pipeline_slow::(pipe)).await - } - Err(e) => Err(e), + Err(mut e) => match e.if_driver_down() { + Some(DriverDown(b)) => { + drop(cli); + pipe.buf = b; + Box::pin(self.pipeline_slow::(pipe)).await + } + None => Err(e), + }, } } @@ -282,10 +298,10 @@ impl SharedClient { ranges: Vec::new(), }) } - Err(Error::DriverDown(buf)) => { - pipe.buf = buf; - } - Err(e) => return Err(e), + Err(mut e) => match e.if_driver_down() { + Some(DriverDown(b)) => pipe.buf = b, + None => return Err(e), + }, } } } diff --git a/postgres/src/prepare.rs b/postgres/src/prepare.rs index 00ff5433..d710d21e 100644 --- a/postgres/src/prepare.rs +++ b/postgres/src/prepare.rs @@ -46,23 +46,23 @@ impl Client { match res.recv().await? { backend::Message::ParseComplete => {} - _ => return Err(Error::UnexpectedMessage), + _ => return Err(Error::unexpected()), } let parameter_description = match res.recv().await? { backend::Message::ParameterDescription(body) => body, - _ => return Err(Error::UnexpectedMessage), + _ => return Err(Error::unexpected()), }; let row_description = match res.recv().await? { backend::Message::RowDescription(body) => Some(body), backend::Message::NoData => None, - _ => return Err(Error::UnexpectedMessage), + _ => return Err(Error::unexpected()), }; let mut parameters = Vec::new(); let mut it = parameter_description.parameters(); - while let Some(oid) = it.next().map_err(|_| Error::ToDo)? { + while let Some(oid) = it.next().map_err(|_| Error::todo())? { let ty = self.get_type(oid).await?; parameters.push(ty); } @@ -70,7 +70,7 @@ impl Client { let mut columns = Vec::new(); if let Some(row_description) = row_description { let mut it = row_description.fields(); - while let Some(field) = it.next().map_err(|_| Error::ToDo)? { + while let Some(field) = it.next().map_err(|_| Error::todo())? { let type_ = self.get_type(field.type_oid()).await?; let column = Column::new(field.name(), type_); columns.push(column); @@ -91,7 +91,7 @@ impl Client { let stmt = self.typeinfo_statement().await?; let mut rows = self.query_raw(&stmt, &[&oid]).await?; - let row = rows.try_next().await?.ok_or(Error::UnexpectedMessage)?; + let row = rows.try_next().await?.ok_or_else(Error::unexpected)?; let name = row.try_get::(0)?; let type_ = row.try_get::(1)?; diff --git a/postgres/src/query/base.rs b/postgres/src/query/base.rs index 068af3be..afc17cb7 100644 --- a/postgres/src/query/base.rs +++ b/postgres/src/query/base.rs @@ -89,7 +89,7 @@ impl Client { let mut res = self.send(buf).await?; match res.recv().await? { backend::Message::BindComplete => Ok(res), - _ => Err(Error::UnexpectedMessage), + _ => Err(Error::unexpected()), } } @@ -116,7 +116,7 @@ impl Response { } backend::Message::EmptyQueryResponse => rows = 0, backend::Message::ReadyForQuery(_) => return Ok(rows), - _ => return Err(Error::UnexpectedMessage), + _ => return Err(Error::unexpected()), } } } @@ -129,7 +129,7 @@ impl Response { | backend::Message::DataRow(_) | backend::Message::EmptyQueryResponse => {} backend::Message::ReadyForQuery(_) => return Ok(()), - _ => return Err(Error::UnexpectedMessage), + _ => return Err(Error::unexpected()), } } } @@ -150,7 +150,7 @@ impl<'a> AsyncLendingIterator for RowStream<'a> { | backend::Message::CommandComplete(_) | backend::Message::PortalSuspended => {} backend::Message::ReadyForQuery(_) => return Ok(None), - _ => return Err(Error::UnexpectedMessage), + _ => return Err(Error::unexpected()), } } } diff --git a/postgres/src/query/decode.rs b/postgres/src/query/decode.rs index 7b6c6b23..cc33f85c 100644 --- a/postgres/src/query/decode.rs +++ b/postgres/src/query/decode.rs @@ -4,6 +4,6 @@ use postgres_protocol::message::backend; // Extract the number of rows affected. pub(crate) fn body_to_affected_rows(body: &backend::CommandCompleteBody) -> Result { body.tag() - .map_err(|_| Error::ToDo) + .map_err(|_| Error::todo()) .map(|r| r.rsplit(' ').next().unwrap().parse().unwrap_or(0)) } diff --git a/postgres/src/query/encode.rs b/postgres/src/query/encode.rs index 6cb1493c..eb1aea69 100644 --- a/postgres/src/query/encode.rs +++ b/postgres/src/query/encode.rs @@ -22,7 +22,7 @@ where I::Item: BorrowToSql, { encode_bind(stmt, params, "", buf)?; - frontend::execute("", 0, buf).map_err(|_| Error::ToDo)?; + frontend::execute("", 0, buf).map_err(|_| Error::todo())?; if SYNC_MODE { frontend::sync(buf); } @@ -54,7 +54,7 @@ where match r { Ok(()) => Ok(()), - Err(frontend::BindError::Conversion(_)) => Err(Error::ToDo), - Err(frontend::BindError::Serialization(_)) => Err(Error::ToDo), + Err(frontend::BindError::Conversion(_)) => Err(Error::todo()), + Err(frontend::BindError::Serialization(_)) => Err(Error::todo()), } } diff --git a/postgres/src/query/simple.rs b/postgres/src/query/simple.rs index e3e54f9d..fa9d4502 100644 --- a/postgres/src/query/simple.rs +++ b/postgres/src/query/simple.rs @@ -59,7 +59,7 @@ impl AsyncLendingIterator for RowSimpleStream { backend::Message::CommandComplete(_) | backend::Message::EmptyQueryResponse | backend::Message::ReadyForQuery(_) => return Ok(None), - _ => return Err(Error::UnexpectedMessage), + _ => return Err(Error::unexpected()), } } } diff --git a/postgres/src/row/types.rs b/postgres/src/row/types.rs index bde6d85d..e2fd15da 100644 --- a/postgres/src/row/types.rs +++ b/postgres/src/row/types.rs @@ -5,7 +5,12 @@ use postgres_protocol::message::backend::DataRowBody; use postgres_types::FromSql; use xitca_io::bytes::Bytes; -use crate::{column::Column, error::Error, from_sql::FromSqlExt, Type}; +use crate::{ + column::Column, + error::{Error, InvalidColumnIndex}, + from_sql::FromSqlExt, + Type, +}; use super::traits::RowIndexAndType; @@ -84,10 +89,10 @@ impl<'a, C> GenericRow<'a, C> { ) -> Result<(usize, &Type), Error> { let (idx, ty) = idx ._from_columns(self.columns()) - .ok_or_else(|| Error::InvalidColumnIndex(format!("{idx}")))?; + .ok_or_else(|| InvalidColumnIndex(idx.to_string()))?; if !ty_check(ty) { - return Err(Error::ToDo); + return Err(Error::todo()); // return Err(Error::from_sql(Box::new(WrongType::new::(ty.clone())), idx)); } diff --git a/postgres/src/session.rs b/postgres/src/session.rs index 94824e91..c1533692 100644 --- a/postgres/src/session.rs +++ b/postgres/src/session.rs @@ -57,15 +57,15 @@ where loop { match drv.recv().await? { backend::Message::DataRow(body) => { - let range = body.ranges().next()?.flatten().ok_or(Error::ToDo)?; + let range = body.ranges().next()?.flatten().ok_or(Error::todo())?; let slice = &body.buffer()[range.start..range.end]; if slice == b"on" { - return Err(Error::ToDo); + return Err(Error::todo()); } } backend::Message::RowDescription(_) | backend::Message::CommandComplete(_) => {} backend::Message::EmptyQueryResponse | backend::Message::ReadyForQuery(_) => break, - _ => return Err(Error::UnexpectedMessage), + _ => return Err(Error::unexpected()), } } } @@ -137,12 +137,12 @@ where } else { // server ask for channel binding but no tls_server_end_point can be // found. - return Err(Error::ToDo); + return Err(Error::todo()); } } (false, true) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256), // TODO: return "unsupported SASL mechanism" error. - (false, false) => return Err(Error::ToDo), + (false, false) => return Err(Error::todo()), }; let mut scram = sasl::ScramSha256::new(pass, channel_binding); @@ -158,12 +158,12 @@ where let msg = buf.split(); drv.send(msg).await?; } - _ => return Err(Error::ToDo), + _ => return Err(Error::todo()), } match drv.recv().await? { backend::Message::AuthenticationSaslFinal(body) => scram.finish(body.data())?, - _ => return Err(Error::ToDo), + _ => return Err(Error::todo()), } } backend::Message::ErrorResponse(_) => return Err(Error::from(AuthenticationError::WrongPassWord)), From 4ae00f609d11a5ad178491de04a0d19122c21f25 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Fri, 15 Mar 2024 10:48:46 +0800 Subject: [PATCH 2/4] fix pool error handling. --- postgres/src/pool.rs | 124 ++++++++++++++++++++++++------------------- 1 file changed, 68 insertions(+), 56 deletions(-) diff --git a/postgres/src/pool.rs b/postgres/src/pool.rs index 193b169d..c745ac5a 100644 --- a/postgres/src/pool.rs +++ b/postgres/src/pool.rs @@ -135,27 +135,27 @@ impl SharedClient { let cli = self.read().await; match cli.query_raw(stmt, params).await { Ok(res) => Ok(res), - Err(mut e) => match e.if_driver_down() { - Some(DriverDown(buf)) => { - drop(cli); - Box::pin(self.query_raw_slow(stmt, buf)).await - } - None => Err(e), - }, + Err(err) => { + drop(cli); + Box::pin(self.query_raw_slow(stmt, err)).await + } } } - #[cold] - #[inline(never)] - async fn query_raw_slow<'a>(&self, stmt: &'a Statement, mut buf: BytesMut) -> Result, Error> { + async fn query_raw_slow<'a>(&self, stmt: &'a Statement, mut err: Error) -> Result, Error> { + let mut buf; + loop { + match err.if_driver_down() { + Some(DriverDown(b)) => buf = b, + None => return Err(err), + } + self.reconnect().await; + match self.read().await.query_buf(stmt, buf).await { Ok(res) => return Ok(res), - Err(mut e) => match e.if_driver_down() { - Some(DriverDown(b)) => buf = b, - None => return Err(e), - }, + Err(e) => err = e, } } } @@ -256,53 +256,65 @@ impl SharedClient { } #[cfg(not(feature = "quic"))] -impl SharedClient { - pub async fn pipeline<'a, const SYNC_MODE: bool>( - &self, - mut pipe: crate::pipeline::Pipeline<'a, SYNC_MODE>, - ) -> Result, Error> { - let cli = self.read().await; - match cli._pipeline::(&pipe.columns, pipe.buf).await { - Ok(res) => Ok(crate::pipeline::PipelineStream { - res, - columns: pipe.columns, - ranges: Vec::new(), - }), - Err(mut e) => match e.if_driver_down() { - Some(DriverDown(b)) => { +const _: () = { + use std::collections::VecDeque; + + use crate::{ + column::Column, + pipeline::{Pipeline, PipelineStream}, + }; + + impl SharedClient { + pub async fn pipeline<'a, const SYNC_MODE: bool>( + &self, + pipe: Pipeline<'a, SYNC_MODE>, + ) -> Result, Error> { + let Pipeline { columns, buf } = pipe; + let cli = self.read().await; + match cli._pipeline::(&columns, buf).await { + Ok(res) => Ok(PipelineStream { + res, + columns, + ranges: Vec::new(), + }), + Err(err) => { drop(cli); - pipe.buf = b; - Box::pin(self.pipeline_slow::(pipe)).await + Box::pin(self.pipeline_slow::(columns, err)).await } - None => Err(e), - }, + } } - } - async fn pipeline_slow<'a, const SYNC_MODE: bool>( - &self, - mut pipe: crate::pipeline::Pipeline<'a, SYNC_MODE>, - ) -> Result, Error> { - loop { - self.reconnect().await; - match self - .read() - .await - ._pipeline_no_additive_sync::(&pipe.columns, pipe.buf) - .await - { - Ok(res) => { - return Ok(crate::pipeline::PipelineStream { - res, - columns: pipe.columns, - ranges: Vec::new(), - }) + async fn pipeline_slow<'a, const SYNC_MODE: bool>( + &self, + columns: VecDeque<&'a [Column]>, + mut err: Error, + ) -> Result, Error> { + let mut buf; + + loop { + match err.if_driver_down() { + Some(DriverDown(b)) => buf = b, + None => return Err(err), + } + + self.reconnect().await; + + match self + .read() + .await + ._pipeline_no_additive_sync::(&columns, buf) + .await + { + Ok(res) => { + return Ok(crate::pipeline::PipelineStream { + res, + columns, + ranges: Vec::new(), + }) + } + Err(e) => err = e, } - Err(mut e) => match e.if_driver_down() { - Some(DriverDown(b)) => pipe.buf = b, - None => return Err(e), - }, } } } -} +}; From 60ede0b53b7524680ebe822cb055a227aa7868d0 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Fri, 15 Mar 2024 18:27:55 +0800 Subject: [PATCH 3/4] doc fix. --- postgres/src/driver.rs | 2 +- postgres/src/error.rs | 24 +++++++++++++++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/postgres/src/driver.rs b/postgres/src/driver.rs index 8edf4427..5798bb97 100644 --- a/postgres/src/driver.rs +++ b/postgres/src/driver.rs @@ -55,7 +55,7 @@ pub(super) async fn connect(cfg: &mut Config) -> Result<(Client, Driver), Error> } /// async driver of [Client](crate::Client). -/// it handles IO and emit server sent message that do not belong to any query with [AsyncIterator] +/// it handles IO and emit server sent message that do not belong to any query with [AsyncLendingIterator] /// trait impl. /// /// # Examples: diff --git a/postgres/src/error.rs b/postgres/src/error.rs index 3d330406..7d72e4ff 100644 --- a/postgres/src/error.rs +++ b/postgres/src/error.rs @@ -13,6 +13,19 @@ use crate::driver::codec::Request; use super::from_sql::FromSqlError; +/// public facing error type. providing basic format and display based error handling. +/// for typed based error handling runtime type cast is needed with the help of other +/// public error types offered by this module. +/// +/// # Example +/// ```rust +/// use xitca_postgres::error::{DriverDown, Error}; +/// +/// fn is_driver_down(e: Error) -> bool { +/// // downcast error to DriverDown error type to check if client driver is gone. +/// e.downcast_ref::().is_some() +/// } +/// ``` pub struct Error(Box); impl Error { @@ -63,6 +76,14 @@ impl error::Error for Error { } } +/// error indicate [Client]'s [Driver] is dropped and can't be accessed anymore. +/// +/// the field inside error contains the raw bytes buffer of query message that are ready to be +/// sent to the [Driver] for transporting. It's possible to construct a new [Client] and [Driver] +/// pair where the bytes buffer could be reused for graceful retry of previous failed queries. +/// +/// [Client]: crate::client::Client +/// [Driver]: crate::driver::Driver #[derive(Default)] pub struct DriverDown(pub BytesMut); @@ -74,7 +95,7 @@ impl fmt::Debug for DriverDown { impl fmt::Display for DriverDown { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("Driver is down") + f.write_str("Driver is dropped and unaccessible.") } } @@ -138,6 +159,7 @@ impl From> for Error { } } +/// error happens when library user failed to provide valid authentication info to database server. #[derive(Debug)] pub enum AuthenticationError { MissingUserName, From d413dbda318f4b25b4c99a92687b4c00ca462a82 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Fri, 15 Mar 2024 23:41:55 +0800 Subject: [PATCH 4/4] remove dead code from driver. --- postgres/src/driver.rs | 17 +++++------------ postgres/src/driver/quic.rs | 2 +- postgres/src/driver/raw.rs | 8 ++++---- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/postgres/src/driver.rs b/postgres/src/driver.rs index 5798bb97..ee9bfd2d 100644 --- a/postgres/src/driver.rs +++ b/postgres/src/driver.rs @@ -80,8 +80,6 @@ pub(super) async fn connect(cfg: &mut Config) -> Result<(Client, Driver), Error> /// ``` pub struct Driver { inner: _Driver, - #[allow(dead_code)] - config: Config, } impl Driver { @@ -111,44 +109,39 @@ impl Driver { #[cfg(not(feature = "quic"))] impl Driver { - pub(super) fn tcp(drv: GenericDriver, config: Config) -> Self { + pub(super) fn tcp(drv: GenericDriver) -> Self { Self { inner: _Driver::Tcp(drv), - config, } } #[cfg(feature = "tls")] - pub(super) fn tls(drv: GenericDriver>, config: Config) -> Self { + pub(super) fn tls(drv: GenericDriver>) -> Self { Self { inner: _Driver::Tls(drv), - config, } } #[cfg(unix)] - pub(super) fn unix(drv: GenericDriver, config: Config) -> Self { + pub(super) fn unix(drv: GenericDriver) -> Self { Self { inner: _Driver::Unix(drv), - config, } } #[cfg(all(unix, feature = "tls"))] - pub(super) fn unix_tls(drv: GenericDriver>, config: Config) -> Self { + pub(super) fn unix_tls(drv: GenericDriver>) -> Self { Self { inner: _Driver::UnixTls(drv), - config, } } } #[cfg(feature = "quic")] impl Driver { - pub(super) fn quic(drv: QuicDriver, config: Config) -> Self { + pub(super) fn quic(drv: QuicDriver) -> Self { Self { inner: _Driver::Quic(drv), - config, } } } diff --git a/postgres/src/driver/quic.rs b/postgres/src/driver/quic.rs index f5f923f7..81e450dc 100644 --- a/postgres/src/driver/quic.rs +++ b/postgres/src/driver/quic.rs @@ -90,7 +90,7 @@ pub(super) async fn _connect(host: Host, cfg: &Config) -> Result<(ClientTx, Driv let mut drv = QuicDriver::new(streams); prepare_session(&mut drv, cfg).await?; drv.close_tx().await; - Ok((tx, Driver::quic(drv, cfg.clone()))) + Ok((tx, Driver::quic(drv))) } _ => unreachable!(), } diff --git a/postgres/src/driver/raw.rs b/postgres/src/driver/raw.rs index 1917d42e..2b8f6292 100644 --- a/postgres/src/driver/raw.rs +++ b/postgres/src/driver/raw.rs @@ -68,7 +68,7 @@ pub(super) async fn _connect(host: Host, cfg: &mut Config) -> Result<(ClientTx, let io = tls::connect(io, host, cfg).await?; let (mut drv, tx) = GenericDriver::new(io); prepare_session(&mut drv, cfg).await?; - Ok((ClientTx(tx), Driver::tls(drv, cfg.clone()))) + Ok((ClientTx(tx), Driver::tls(drv))) } #[cfg(not(feature = "tls"))] { @@ -77,7 +77,7 @@ pub(super) async fn _connect(host: Host, cfg: &mut Config) -> Result<(ClientTx, } else { let (mut drv, tx) = GenericDriver::new(io); prepare_session(&mut drv, cfg).await?; - Ok((ClientTx(tx), Driver::tcp(drv, cfg.clone()))) + Ok((ClientTx(tx), Driver::tcp(drv))) } } #[cfg(unix)] @@ -90,7 +90,7 @@ pub(super) async fn _connect(host: Host, cfg: &mut Config) -> Result<(ClientTx, let io = tls::connect(io, host.as_ref(), cfg).await?; let (mut drv, tx) = GenericDriver::new(io); prepare_session(&mut drv, cfg).await?; - Ok((ClientTx(tx), Driver::unix_tls(drv, cfg.clone()))) + Ok((ClientTx(tx), Driver::unix_tls(drv))) } #[cfg(not(feature = "tls"))] { @@ -99,7 +99,7 @@ pub(super) async fn _connect(host: Host, cfg: &mut Config) -> Result<(ClientTx, } else { let (mut drv, tx) = GenericDriver::new(io); prepare_session(&mut drv, cfg).await?; - Ok((ClientTx(tx), Driver::unix(drv, cfg.clone()))) + Ok((ClientTx(tx), Driver::unix(drv))) } } _ => unreachable!(),