diff --git a/postgres/src/driver.rs b/postgres/src/driver.rs index e33ebc6d..b3cf96cf 100644 --- a/postgres/src/driver.rs +++ b/postgres/src/driver.rs @@ -41,12 +41,12 @@ use xitca_tls::rustls::{ClientConnection, TlsStream}; #[cfg(unix)] use xitca_io::net::UnixStream; -pub(super) async fn connect(mut cfg: Config) -> Result<(Client, Driver), Error> { +pub(super) async fn connect(cfg: &mut Config) -> Result<(Client, Driver), Error> { let mut err = None; let hosts = cfg.get_hosts().to_vec(); for host in hosts { - match _connect(host, &mut cfg).await { - Ok(t) => return Ok(t), + match _connect(host, cfg).await { + Ok((tx, drv)) => return Ok((Client::new(tx), drv)), Err(e) => err = Some(e), } } @@ -80,43 +80,100 @@ pub(super) async fn connect(mut cfg: Config) -> Result<(Client, Driver), Error> /// ``` pub struct Driver { inner: _Driver, + #[allow(dead_code)] + config: Config, +} + +impl Driver { + // run till the connection is closed by Client. + async fn run_till_closed(self) { + #[cfg(not(feature = "quic"))] + { + let mut this = self; + while let Err(e) = match this.inner { + _Driver::Tcp(ref mut drv) => drv.run().await, + #[cfg(feature = "tls")] + _Driver::Tls(ref mut drv) => drv.run().await, + #[cfg(unix)] + _Driver::Unix(ref mut drv) => drv.run().await, + #[cfg(all(unix, feature = "tls"))] + _Driver::UnixTls(ref mut drv) => drv.run().await, + } { + while this.reconnect(&e).await.is_err() {} + } + } + + #[cfg(feature = "quic")] + match self.inner { + _Driver::Quic(drv) => { + let _ = drv.run().await; + } + } + } } #[cfg(not(feature = "quic"))] impl Driver { - pub(super) fn tcp(drv: GenericDriver) -> Self { + /// reconnect to server with a fresh connection and state. Driver's associated + /// [Client] is able to be re-used for the fresh connection. + /// + /// MUST be called when `::try_next` emit [Error]. + /// All in flight database query and response will be lost in the process. + pub async fn reconnect(&mut self, _: &Error) -> Result<(), Error> { + let (_, Driver { inner: inner_new, .. }) = connect(&mut self.config).await?; + + match (&mut self.inner, inner_new) { + (_Driver::Tcp(drv), _Driver::Tcp(drv_new)) => drv.replace(drv_new), + #[cfg(feature = "tls")] + (_Driver::Tls(drv), _Driver::Tls(drv_new)) => drv.replace(drv_new), + #[cfg(unix)] + (_Driver::Unix(drv), _Driver::Unix(drv_new)) => drv.replace(drv_new), + #[cfg(all(unix, feature = "tls"))] + (_Driver::UnixTls(drv), _Driver::UnixTls(drv_new)) => drv.replace(drv_new), + _ => unreachable!("reconnect should always yield the same type of generic driver"), + }; + + Ok(()) + } + + pub(super) fn tcp(drv: GenericDriver, config: Config) -> Self { Self { inner: _Driver::Tcp(drv), + config, } } #[cfg(feature = "tls")] - pub(super) fn tls(drv: GenericDriver>) -> Self { + pub(super) fn tls(drv: GenericDriver>, config: Config) -> Self { Self { inner: _Driver::Tls(drv), + config, } } #[cfg(unix)] - pub(super) fn unix(drv: GenericDriver) -> Self { + pub(super) fn unix(drv: GenericDriver, config: Config) -> Self { Self { inner: _Driver::Unix(drv), + config, } } #[cfg(all(unix, feature = "tls"))] - pub(super) fn unix_tls(drv: GenericDriver>) -> Self { + pub(super) fn unix_tls(drv: GenericDriver>, config: Config) -> Self { Self { inner: _Driver::UnixTls(drv), + config, } } } #[cfg(feature = "quic")] impl Driver { - pub(super) fn quic(drv: QuicDriver) -> Self { + pub(super) fn quic(drv: QuicDriver, config: Config) -> Self { Self { inner: _Driver::Quic(drv), + config, } } } @@ -163,25 +220,11 @@ impl AsyncLendingIterator for Driver { } impl IntoFuture for Driver { - type Output = Result<(), Error>; + type Output = (); type IntoFuture = Pin + Send>>; fn into_future(self) -> Self::IntoFuture { - #[cfg(not(feature = "quic"))] - match self.inner { - _Driver::Tcp(drv) => Box::pin(drv.run()), - #[cfg(feature = "tls")] - _Driver::Tls(drv) => Box::pin(drv.run()), - #[cfg(unix)] - _Driver::Unix(drv) => Box::pin(drv.run()), - #[cfg(all(unix, feature = "tls"))] - _Driver::UnixTls(drv) => Box::pin(drv.run()), - } - - #[cfg(feature = "quic")] - match self.inner { - _Driver::Quic(drv) => Box::pin(drv.run()), - } + Box::pin(self.run_till_closed()) } } @@ -210,7 +253,7 @@ impl Driver { let tcp = xitca_io::net::io_uring::TcpStream::from_std(std); io_uring::IoUringDriver::new( tcp, - drv.rx.unwrap(), + drv.rx, drv.write_buf.into_inner(), drv.read_buf.into_inner(), drv.res, diff --git a/postgres/src/driver/generic.rs b/postgres/src/driver/generic.rs index 0f613d0b..a95d57ab 100644 --- a/postgres/src/driver/generic.rs +++ b/postgres/src/driver/generic.rs @@ -31,8 +31,14 @@ pub(crate) struct GenericDriver { pub(crate) io: Io, pub(crate) write_buf: WriteBuf, pub(crate) read_buf: PagedBytesMut, - pub(crate) rx: Option, + pub(crate) rx: GenericDriverRx, pub(crate) res: VecDeque, + state: DriverState, +} + +enum DriverState { + Running, + Closing(Option), } impl GenericDriver @@ -46,14 +52,32 @@ where io, write_buf: WriteBuf::new(), read_buf: PagedBytesMut::new(), - rx: Some(rx), + rx, res: VecDeque::new(), + state: DriverState::Running, }, tx, ) } - pub(crate) async fn try_next(&mut self) -> Result, Error> { + #[cfg(not(feature = "quic"))] + pub(crate) fn replace(&mut self, other: Self) { + let Self { + io, + write_buf, + read_buf, + res, + state, + .. + } = other; + self.io = io; + self.write_buf = write_buf; + self.read_buf = read_buf; + self.res = res; + self.state = state; + } + + async fn _try_next(&mut self) -> Result, Error> { loop { if let Some(msg) = self.try_decode()? { return Ok(Some(msg)); @@ -65,19 +89,20 @@ where Interest::READABLE }; - let select = match self.rx { - Some(ref mut rx) => { + let select = match self.state { + DriverState::Running => { let ready = self.io.ready(interest); - rx.recv().select(ready).await + self.rx.recv().select(ready).await } - None => { + DriverState::Closing(ref mut e) => { if !interest.is_writable() && self.res.is_empty() { // no interest to write to io and all response have been finished so // shutdown io and exit. // if there is a better way to exhaust potential remaining backend message // please file an issue. poll_fn(|cx| Pin::new(&mut self.io).poll_shutdown(cx)).await?; - return Ok(None); + + return e.take().map(|e| Err(e.into())).transpose(); } let ready = self.io.ready(interest); SelectOutput::B(ready.await) @@ -95,25 +120,34 @@ where if ready.is_readable() { self.try_read()?; } - if ready.is_writable() && self.try_write().is_err() { - // write failed as server stopped reading. - // drop channel so all pending request in it can be notified. - self.rx = None; + if ready.is_writable() { + if let Err(e) = self.try_write() { + error!("server closed read half unexpectedly: {e}"); + + // when write error occur the driver would go into half close state(read only). + // clearing write_buf would drop all pending requests in it and hint the driver + // no future Interest::WRITABLE should be passed to AsyncIo::ready method. + self.write_buf.clear(); + + // enter closed state and no more request would be received from channel. + // requests inside it would eventually be dropped after shutdown completed. + self.state = DriverState::Closing(Some(e)); + } } } - SelectOutput::A(None) => self.rx = None, + SelectOutput::A(None) => self.state = DriverState::Closing(None), } } } // TODO: remove this feature gate. #[cfg(not(feature = "quic"))] - pub(crate) async fn run(mut self) -> Result<(), Error> { - while self.try_next().await?.is_some() {} + pub(crate) async fn run(&mut self) -> Result<(), Error> { + while self._try_next().await?.is_some() {} Ok(()) } - pub(crate) async fn send(&mut self, msg: BytesMut) -> Result<(), Error> { + async fn send(&mut self, msg: BytesMut) -> Result<(), Error> { self.write_buf_extend(&msg); loop { self.try_write()?; @@ -151,14 +185,7 @@ where } fn try_write(&mut self) -> io::Result<()> { - self.write_buf.do_io(&mut self.io).map_err(|e| { - // when write error occur the driver would go into half close state(read only). - // clearing write_buf would drop all pending requests in it and hint the driver no - // future Interest::READABLE should be passed to AsyncIo::ready method. - self.write_buf.clear(); - error!("server closed read half unexpectedly: {e}"); - e - }) + self.write_buf.do_io(&mut self.io) } fn try_decode(&mut self) -> Result, Error> { @@ -186,8 +213,8 @@ where type Err = Error; #[inline] - async fn try_next(&mut self) -> Result>, Self::Err> { - self.try_next().await + fn try_next(&mut self) -> impl Future>, Self::Err>> + Send { + self._try_next() } } diff --git a/postgres/src/driver/quic.rs b/postgres/src/driver/quic.rs index a0a09ab4..2bfab816 100644 --- a/postgres/src/driver/quic.rs +++ b/postgres/src/driver/quic.rs @@ -17,10 +17,10 @@ use quinn_proto::ConnectionError; use xitca_io::bytes::{Bytes, BytesMut}; use crate::{ - client::Client, config::{Config, Host}, error::{unexpected_eof_err, Error}, iter::AsyncLendingIterator, + session::prepare_session, }; use super::{Drive, Driver}; @@ -82,16 +82,15 @@ impl ClientTx { #[cold] #[inline(never)] -pub(super) async fn _connect(host: Host, cfg: &mut Config) -> Result<(Client, Driver), Error> { +pub(super) async fn _connect(host: Host, cfg: &Config) -> Result<(ClientTx, Driver), Error> { match host { Host::Udp(ref host) => { let tx = connect_quic(host, cfg.get_ports()).await?; let streams = tx.inner.open_bi().await.unwrap(); let mut drv = QuicDriver::new(streams); - let mut cli = Client::new(tx); - cli.prepare_session(&mut drv, cfg).await?; + prepare_session(&mut drv, cfg).await?; drv.close_tx().await; - Ok((cli, Driver::quic(drv))) + Ok((tx, Driver::quic(drv, cfg.clone()))) } _ => unreachable!(), } diff --git a/postgres/src/driver/raw.rs b/postgres/src/driver/raw.rs index 55cc0f14..ccd98b09 100644 --- a/postgres/src/driver/raw.rs +++ b/postgres/src/driver/raw.rs @@ -19,9 +19,9 @@ use xitca_io::{ }; use crate::{ - client::Client, config::{Config, Host, SslMode}, error::{unexpected_eof_err, write_zero_err, Error}, + session::prepare_session, }; use super::{ @@ -56,7 +56,7 @@ impl ClientTx { #[cold] #[inline(never)] -pub(super) async fn _connect(host: Host, cfg: &mut Config) -> Result<(Client, Driver), Error> { +pub(super) async fn _connect(host: Host, cfg: &mut Config) -> Result<(ClientTx, Driver), Error> { // this block have repeated code due to HRTB limitation. // namely for <'_> AsyncIo::Future<'_>: Send bound can not be expressed correctly. match host { @@ -67,9 +67,8 @@ pub(super) async fn _connect(host: Host, cfg: &mut Config) -> Result<(Client, Dr { let io = tls::connect(io, host, cfg).await?; let (mut drv, tx) = GenericDriver::new(io); - let mut cli = Client::new(ClientTx(tx)); - cli.prepare_session(&mut drv, cfg).await?; - Ok((cli, Driver::tls(drv))) + prepare_session(&mut drv, cfg).await?; + Ok((ClientTx(tx), Driver::tls(drv, cfg.clone()))) } #[cfg(not(feature = "tls"))] { @@ -77,9 +76,8 @@ pub(super) async fn _connect(host: Host, cfg: &mut Config) -> Result<(Client, Dr } } else { let (mut drv, tx) = GenericDriver::new(io); - let mut cli = Client::new(ClientTx(tx)); - cli.prepare_session(&mut drv, cfg).await?; - Ok((cli, Driver::tcp(drv))) + prepare_session(&mut drv, cfg).await?; + Ok((ClientTx(tx), Driver::tcp(drv, cfg.clone()))) } } #[cfg(unix)] @@ -91,9 +89,8 @@ pub(super) async fn _connect(host: Host, cfg: &mut Config) -> Result<(Client, Dr let host = host.to_string_lossy(); let io = tls::connect(io, host.as_ref(), cfg).await?; let (mut drv, tx) = GenericDriver::new(io); - let mut cli = Client::new(ClientTx(tx)); - cli.prepare_session(&mut drv, cfg).await?; - Ok((cli, Driver::unix_tls(drv))) + prepare_session(&mut drv, cfg).await?; + Ok((ClientTx(tx), Driver::unix_tls(drv, cfg.clone()))) } #[cfg(not(feature = "tls"))] { @@ -101,9 +98,8 @@ pub(super) async fn _connect(host: Host, cfg: &mut Config) -> Result<(Client, Dr } } else { let (mut drv, tx) = GenericDriver::new(io); - let mut cli = Client::new(ClientTx(tx)); - cli.prepare_session(&mut drv, cfg).await?; - Ok((cli, Driver::unix(drv))) + prepare_session(&mut drv, cfg).await?; + Ok((ClientTx(tx), Driver::unix(drv, cfg.clone()))) } } _ => unreachable!(), diff --git a/postgres/src/driver/raw/tls.rs b/postgres/src/driver/raw/tls.rs index 94adc755..5809d691 100644 --- a/postgres/src/driver/raw/tls.rs +++ b/postgres/src/driver/raw/tls.rs @@ -18,7 +18,7 @@ where if let Some(sha256) = stream .session() .peer_certificates() - .and_then(|certs| certs.get(0)) + .and_then(|certs| certs.first()) .map(|cert| Sha256::digest(cert.as_ref()).to_vec()) { cfg.tls_server_end_point(sha256); diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs index f0ea72bb..78c2c89a 100644 --- a/postgres/src/lib.rs +++ b/postgres/src/lib.rs @@ -92,8 +92,8 @@ where /// /// ``` pub async fn connect(self) -> Result<(Client, Driver), Error> { - let cfg = Config::try_from(self.cfg)?; - driver::connect(cfg).await + let mut cfg = Config::try_from(self.cfg)?; + driver::connect(&mut cfg).await } } @@ -157,6 +157,6 @@ mod test { drop(cli); - handle.await.unwrap().unwrap(); + handle.await.unwrap(); } } diff --git a/postgres/src/proxy.rs b/postgres/src/proxy.rs index b816b975..5656fcf0 100644 --- a/postgres/src/proxy.rs +++ b/postgres/src/proxy.rs @@ -7,6 +7,8 @@ use tracing::error; use xitca_io::{bytes::BytesMut, net::TcpStream}; use xitca_unsafe_collection::futures::{Select, SelectOutput}; +use crate::iter::AsyncLendingIterator; + use super::driver::{ codec::Request, generic::{GenericDriver, GenericDriverTx}, diff --git a/postgres/src/session.rs b/postgres/src/session.rs index cfcf0a1d..94824e91 100644 --- a/postgres/src/session.rs +++ b/postgres/src/session.rs @@ -5,9 +5,9 @@ use postgres_protocol::{ authentication::{self, sasl}, message::{backend, frontend}, }; +use xitca_io::bytes::BytesMut; use super::{ - client::Client, config::Config, driver::Drive, error::{AuthenticationError, Error}, @@ -23,156 +23,160 @@ pub enum TargetSessionAttrs { ReadWrite, } -impl Client { - #[allow(clippy::needless_pass_by_ref_mut)] // dumb clippy - #[cold] - #[inline(never)] - pub(super) async fn prepare_session(&mut self, drv: &mut D, cfg: &mut Config) -> Result<(), Error> - where - D: Drive, - { - self.auth(drv, cfg).await?; - - loop { - match drv.recv().await? { - backend::Message::ReadyForQuery(_) => break, - backend::Message::BackendKeyData(_) => { - // TODO: handle process id and secret key. - } - backend::Message::ParameterStatus(_) => { - // TODO: handle parameters - } - _ => { - // TODO: other session message handling? - } +#[allow(clippy::needless_pass_by_ref_mut)] // dumb clippy +#[cold] +#[inline(never)] +pub(super) async fn prepare_session(drv: &mut D, cfg: &Config) -> Result<(), Error> +where + D: Drive, +{ + let mut buf = BytesMut::new(); + + auth(drv, cfg, &mut buf).await?; + + loop { + match drv.recv().await? { + backend::Message::ReadyForQuery(_) => break, + backend::Message::BackendKeyData(_) => { + // TODO: handle process id and secret key. + } + backend::Message::ParameterStatus(_) => { + // TODO: handle parameters + } + _ => { + // TODO: other session message handling? } } + } - if matches!(cfg.get_target_session_attrs(), TargetSessionAttrs::ReadWrite) { - let buf = self.try_buf_and_split(|buf| frontend::query("SHOW transaction_read_only", buf))?; - drv.send(buf).await?; - // TODO: use RowSimple for parsing? - loop { - match drv.recv().await? { - backend::Message::DataRow(body) => { - let range = body.ranges().next()?.flatten().ok_or(Error::ToDo)?; - let slice = &body.buffer()[range.start..range.end]; - if slice == b"on" { - return Err(Error::ToDo); - } + if matches!(cfg.get_target_session_attrs(), TargetSessionAttrs::ReadWrite) { + frontend::query("SHOW transaction_read_only", &mut buf)?; + let msg = buf.split(); + drv.send(msg).await?; + // TODO: use RowSimple for parsing? + loop { + match drv.recv().await? { + backend::Message::DataRow(body) => { + let range = body.ranges().next()?.flatten().ok_or(Error::ToDo)?; + let slice = &body.buffer()[range.start..range.end]; + if slice == b"on" { + return Err(Error::ToDo); } - backend::Message::RowDescription(_) | backend::Message::CommandComplete(_) => {} - backend::Message::EmptyQueryResponse | backend::Message::ReadyForQuery(_) => break, - _ => return Err(Error::UnexpectedMessage), } + backend::Message::RowDescription(_) | backend::Message::CommandComplete(_) => {} + backend::Message::EmptyQueryResponse | backend::Message::ReadyForQuery(_) => break, + _ => return Err(Error::UnexpectedMessage), } } - Ok(()) } + Ok(()) +} - #[cold] - #[inline(never)] - async fn auth(&mut self, drv: &mut D, cfg: &Config) -> Result<(), Error> - where - D: Drive, - { - let mut params = vec![("client_encoding", "UTF8")]; - if let Some(user) = &cfg.user { - params.push(("user", &**user)); - } - if let Some(dbname) = &cfg.dbname { - params.push(("database", &**dbname)); - } - if let Some(options) = &cfg.options { - params.push(("options", &**options)); - } - if let Some(application_name) = &cfg.application_name { - params.push(("application_name", &**application_name)); - } +#[cold] +#[inline(never)] +async fn auth(drv: &mut D, cfg: &Config, buf: &mut BytesMut) -> Result<(), Error> +where + D: Drive, +{ + let mut params = vec![("client_encoding", "UTF8")]; + if let Some(user) = &cfg.user { + params.push(("user", &**user)); + } + if let Some(dbname) = &cfg.dbname { + params.push(("database", &**dbname)); + } + if let Some(options) = &cfg.options { + params.push(("options", &**options)); + } + if let Some(application_name) = &cfg.application_name { + params.push(("application_name", &**application_name)); + } - let msg = self.try_buf_and_split(|buf| frontend::startup_message(params, buf))?; - drv.send(msg).await?; + frontend::startup_message(params, buf)?; + let msg = buf.split(); + drv.send(msg).await?; - loop { - match drv.recv().await? { - backend::Message::AuthenticationOk => return Ok(()), - backend::Message::AuthenticationCleartextPassword => { - let pass = cfg.get_password().ok_or(AuthenticationError::MissingPassWord)?; - self.send_pass(drv, pass).await?; - } - backend::Message::AuthenticationMd5Password(body) => { - let pass = cfg.get_password().ok_or(AuthenticationError::MissingPassWord)?; - let user = cfg.get_user().ok_or(AuthenticationError::MissingUserName)?.as_bytes(); - let pass = authentication::md5_hash(user, pass, body.salt()); - self.send_pass(drv, pass).await?; - } - backend::Message::AuthenticationSasl(body) => { - let pass = cfg.get_password().ok_or(AuthenticationError::MissingPassWord)?; - - let mut is_scram = false; - let mut is_scram_plus = false; - let mut mechanisms = body.mechanisms(); - - while let Some(mechanism) = mechanisms.next()? { - match mechanism { - sasl::SCRAM_SHA_256 => is_scram = true, - sasl::SCRAM_SHA_256_PLUS => is_scram_plus = true, - _ => {} - } + loop { + match drv.recv().await? { + backend::Message::AuthenticationOk => return Ok(()), + backend::Message::AuthenticationCleartextPassword => { + let pass = cfg.get_password().ok_or(AuthenticationError::MissingPassWord)?; + send_pass(drv, pass, buf).await?; + } + backend::Message::AuthenticationMd5Password(body) => { + let pass = cfg.get_password().ok_or(AuthenticationError::MissingPassWord)?; + let user = cfg.get_user().ok_or(AuthenticationError::MissingUserName)?.as_bytes(); + let pass = authentication::md5_hash(user, pass, body.salt()); + send_pass(drv, pass, buf).await?; + } + backend::Message::AuthenticationSasl(body) => { + let pass = cfg.get_password().ok_or(AuthenticationError::MissingPassWord)?; + + let mut is_scram = false; + let mut is_scram_plus = false; + let mut mechanisms = body.mechanisms(); + + while let Some(mechanism) = mechanisms.next()? { + match mechanism { + sasl::SCRAM_SHA_256 => is_scram = true, + sasl::SCRAM_SHA_256_PLUS => is_scram_plus = true, + _ => {} } + } - let (channel_binding, mechanism) = match (is_scram_plus, is_scram) { - (true, is_scram) => { - let buf = cfg.get_tls_server_end_point(); - if !buf.is_empty() { - ( - sasl::ChannelBinding::tls_server_end_point(buf), - sasl::SCRAM_SHA_256_PLUS, - ) - } else if is_scram { - (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256) - } else { - // server ask for channel binding but no tls_server_end_point can be - // found. - return Err(Error::ToDo); - } - } - (false, true) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256), - // TODO: return "unsupported SASL mechanism" error. - (false, false) => return Err(Error::ToDo), - }; - - let mut scram = sasl::ScramSha256::new(pass, channel_binding); - - let msg = - self.try_buf_and_split(|buf| frontend::sasl_initial_response(mechanism, scram.message(), buf))?; - drv.send(msg).await?; - - match drv.recv().await? { - backend::Message::AuthenticationSaslContinue(body) => { - scram.update(body.data())?; - let msg = self.try_buf_and_split(|buf| frontend::sasl_response(scram.message(), buf))?; - drv.send(msg).await?; + let (channel_binding, mechanism) = match (is_scram_plus, is_scram) { + (true, is_scram) => { + let buf = cfg.get_tls_server_end_point(); + if !buf.is_empty() { + ( + sasl::ChannelBinding::tls_server_end_point(buf), + sasl::SCRAM_SHA_256_PLUS, + ) + } else if is_scram { + (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256) + } else { + // server ask for channel binding but no tls_server_end_point can be + // found. + return Err(Error::ToDo); } - _ => return Err(Error::ToDo), } + (false, true) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256), + // TODO: return "unsupported SASL mechanism" error. + (false, false) => return Err(Error::ToDo), + }; + + let mut scram = sasl::ScramSha256::new(pass, channel_binding); - match drv.recv().await? { - backend::Message::AuthenticationSaslFinal(body) => scram.finish(body.data())?, - _ => return Err(Error::ToDo), + frontend::sasl_initial_response(mechanism, scram.message(), buf)?; + let msg = buf.split(); + drv.send(msg).await?; + + match drv.recv().await? { + backend::Message::AuthenticationSaslContinue(body) => { + scram.update(body.data())?; + frontend::sasl_response(scram.message(), buf)?; + let msg = buf.split(); + drv.send(msg).await?; } + _ => return Err(Error::ToDo), + } + + match drv.recv().await? { + backend::Message::AuthenticationSaslFinal(body) => scram.finish(body.data())?, + _ => return Err(Error::ToDo), } - backend::Message::ErrorResponse(_) => return Err(Error::from(AuthenticationError::WrongPassWord)), - _ => {} } + backend::Message::ErrorResponse(_) => return Err(Error::from(AuthenticationError::WrongPassWord)), + _ => {} } } +} - async fn send_pass(&self, drv: &mut D, pass: impl AsRef<[u8]>) -> Result<(), Error> - where - D: Drive, - { - let msg = self.try_buf_and_split(|buf| frontend::password_message(pass.as_ref(), buf))?; - drv.send(msg).await - } +async fn send_pass(drv: &mut D, pass: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), Error> +where + D: Drive, +{ + frontend::password_message(pass.as_ref(), buf)?; + let msg = buf.split(); + drv.send(msg).await }