Skip to content

Commit

Permalink
fix test.
Browse files Browse the repository at this point in the history
  • Loading branch information
fakeshadow committed Apr 2, 2024
1 parent fa74a69 commit eb4aba3
Show file tree
Hide file tree
Showing 12 changed files with 255 additions and 108 deletions.
2 changes: 1 addition & 1 deletion postgres/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use xitca_io::bytes::BytesMut;
use xitca_unsafe_collection::no_hash::NoHashBuilder;

use super::{
driver::{DriverTx, Response},
driver::{codec::Response, DriverTx},
error::Error,
statement::Statement,
util::lock::Lock,
Expand Down
27 changes: 22 additions & 5 deletions postgres/src/driver/mod.rs → postgres/src/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ pub(crate) mod codec;
pub(crate) mod generic;

mod connect;
mod response;

pub use response::Response;

pub(crate) use generic::DriverTx;

Expand All @@ -25,7 +22,10 @@ use core::{
use std::net::SocketAddr;

use postgres_protocol::message::backend;
use xitca_io::bytes::BytesMut;
use xitca_io::{
bytes::BytesMut,
io::{AsyncIo, AsyncIoDyn},
};

use super::{client::Client, config::Config, error::Error, iter::AsyncLendingIterator};

Expand All @@ -41,7 +41,7 @@ pub(super) async fn connect(cfg: &mut Config) -> Result<(Client, Driver), Error>
let mut err = None;
let hosts = cfg.get_hosts().to_vec();
for host in hosts {
match self::connect::_connect(host, cfg).await {
match self::connect::connect(host, cfg).await {
Ok((tx, drv)) => return Ok((Client::new(tx), drv)),
Err(e) => err = Some(e),
}
Expand All @@ -50,6 +50,14 @@ pub(super) async fn connect(cfg: &mut Config) -> Result<(Client, Driver), Error>
Err(err.unwrap())
}

pub(super) async fn connect_io<Io>(io: Io, cfg: &mut Config) -> Result<(Client, Driver), Error>
where
Io: AsyncIo + Send + 'static,
{
let (tx, drv) = self::connect::connect_io(io, cfg).await?;
Ok((Client::new(tx), drv))
}

/// async driver of [Client](crate::Client).
/// it handles IO and emit server sent message that do not belong to any query with [AsyncLendingIterator]
/// trait impl.
Expand Down Expand Up @@ -83,6 +91,7 @@ impl Driver {
async fn run_till_closed(self) {
let _ = match self.inner {
_Driver::Tcp(drv) => drv.run().await,
_Driver::Dynamic(drv) => drv.run().await,
#[cfg(feature = "tls")]
_Driver::Tls(drv) => drv.run().await,
#[cfg(unix)]
Expand All @@ -102,6 +111,12 @@ impl Driver {
}
}

pub(super) fn dynamic(drv: GenericDriver<Box<dyn AsyncIoDyn + Send>>) -> Self {
Self {
inner: _Driver::Dynamic(drv),
}
}

#[cfg(feature = "tls")]
pub(super) fn tls(drv: GenericDriver<TlsStream<ClientConnection, TcpStream>>) -> Self {
Self {
Expand Down Expand Up @@ -133,6 +148,7 @@ impl Driver {
// TODO: use Box<dyn AsyncIterator> when life time GAT is object safe.
enum _Driver {
Tcp(GenericDriver<TcpStream>),
Dynamic(GenericDriver<Box<dyn AsyncIoDyn + Send>>),
#[cfg(feature = "tls")]
Tls(GenericDriver<TlsStream<ClientConnection, TcpStream>>),
#[cfg(unix)]
Expand All @@ -151,6 +167,7 @@ impl AsyncLendingIterator for Driver {
async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
match self.inner {
_Driver::Tcp(ref mut drv) => drv.try_next().await,
_Driver::Dynamic(ref mut drv) => drv.try_next().await,
#[cfg(feature = "tls")]
_Driver::Tls(ref mut drv) => drv.try_next().await,
#[cfg(unix)]
Expand Down
38 changes: 37 additions & 1 deletion postgres/src/driver/codec.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,44 @@
use core::{
future::{poll_fn, Future},
task::{ready, Poll},
};

use postgres_protocol::message::backend;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use xitca_io::bytes::BytesMut;

use crate::error::Error;
use crate::error::{DriverDown, Error};

pub struct Response {
rx: ResponseReceiver,
buf: BytesMut,
}

impl Response {
pub(crate) fn new(rx: ResponseReceiver) -> Self {
Self {
rx,
buf: BytesMut::new(),
}
}

pub(crate) fn recv(&mut self) -> impl Future<Output = Result<backend::Message, Error>> + '_ {
poll_fn(|cx| {
if self.buf.is_empty() {
self.buf = ready!(self.rx.poll_recv(cx)).ok_or_else(|| DriverDown(BytesMut::new()))?;
}

let res = match backend::Message::parse(&mut self.buf)?.expect("must not parse message from empty buffer.")
{
// TODO: error response.
backend::Message::ErrorResponse(_body) => Err(Error::todo()),
msg => Ok(msg),
};

Poll::Ready(res)
})
}
}

#[derive(Debug)]
pub(crate) struct ResponseSender {
Expand Down
13 changes: 12 additions & 1 deletion postgres/src/driver/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use super::{

#[cold]
#[inline(never)]
pub(super) async fn _connect(host: Host, cfg: &mut Config) -> Result<(DriverTx, Driver), Error> {
pub(super) async fn connect(host: Host, cfg: &mut Config) -> Result<(DriverTx, Driver), Error> {
// this block have repeated code due to HRTB limitation.
// namely for <'_> AsyncIo::Future<'_>: Send bound can not be expressed correctly.
match host {
Expand Down Expand Up @@ -84,6 +84,17 @@ pub(super) async fn _connect(host: Host, cfg: &mut Config) -> Result<(DriverTx,
}
}

#[cold]
#[inline(never)]
pub(super) async fn connect_io<Io>(io: Io, cfg: &mut Config) -> Result<(DriverTx, Driver), Error>
where
Io: AsyncIo + Send + 'static,
{
let (mut drv, tx) = GenericDriver::new(Box::new(io) as _);
prepare_session(&mut drv, cfg).await?;
Ok((tx, Driver::dynamic(drv)))
}

async fn connect_tcp(host: &str, ports: &[u16]) -> Result<TcpStream, Error> {
let addrs = super::resolve(host, ports).await?;

Expand Down
3 changes: 1 addition & 2 deletions postgres/src/driver/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ use xitca_unsafe_collection::futures::{Select as _, SelectOutput};
use crate::{error::Error, iter::AsyncLendingIterator};

use super::{
codec::{Request, ResponseMessage, ResponseSender},
response::Response,
codec::{Request, Response, ResponseMessage, ResponseSender},
Drive,
};

Expand Down
Loading

0 comments on commit eb4aba3

Please sign in to comment.