Skip to content

Commit

Permalink
don't panic when collecting row affected from pipeline item
Browse files Browse the repository at this point in the history
  • Loading branch information
fakeshadow committed Oct 1, 2024
1 parent 7c40283 commit ffb938f
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 48 deletions.
2 changes: 1 addition & 1 deletion postgres-codegen/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ proc-macro = true
proc-macro2 = "1"
syn = { version = "2", features = ["full"] }
quote = "1.0"
pg_query = { version = "5" }
sqlparser = "0.51.0"
6 changes: 5 additions & 1 deletion postgres-codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ use syn::{
token::Comma,
Expr, ExprReference, Lit, LitStr,
};
use sqlparser::{
parser::Parser,
dialect::PostgreSqlDialect};


#[proc_macro]
pub fn sql(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
Expand All @@ -30,7 +34,7 @@ impl Parse for Query {
fn parse(input: ParseStream) -> syn::Result<Self> {
let sql = input.parse::<LitStr>()?;

pg_query::parse(&sql.value()).map_err(|e| syn::Error::new(sql.span(), e.to_string()))?;
Parser::parse_sql(&PostgreSqlDialect {}, &sql.value()).map_err(|e| syn::Error::new(sql.span(), e.to_string()))?;

let mut exprs = Vec::new();
let mut types = Vec::new();
Expand Down
23 changes: 15 additions & 8 deletions postgres/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use xitca_io::bytes::BytesMut;
use super::{
column::Column,
driver::codec::{self, encode::Encode, Response},
error::Error,
error::{Completed, Error},
execute::{Execute, ExecuteMut},
iter::AsyncLendingIterator,
query::Query,
Expand Down Expand Up @@ -142,6 +142,7 @@ impl Pipeline<'_, Owned, true> {
/// start a new pipeline with given capacity.
/// capacity represent how many queries will be contained by a single pipeline. a determined cap
/// can possibly reduce memory reallocation when constructing the pipeline.
#[inline]
pub fn with_capacity(cap: usize) -> Self {
Self::_with_capacity(cap)
}
Expand All @@ -163,6 +164,7 @@ impl Pipeline<'_, Owned, false> {
/// start a new un-sync pipeline with given capacity.
/// capacity represent how many queries will be contained by a single pipeline. a determined cap
/// can possibly reduce memory reallocation when constructing the pipeline.
#[inline]
pub fn unsync_with_capacity(cap: usize) -> Self {
Self::_with_capacity(cap)
}
Expand Down Expand Up @@ -242,6 +244,7 @@ where
type ExecuteMutOutput = Ready<Self::QueryMutOutput>;
type QueryMutOutput = Result<(), Error>;

#[inline]
fn execute_mut(self, pipe: &mut Pipeline<'a, B, SYNC_MODE>) -> Self::ExecuteMutOutput {
ready(self.query_mut(pipe))
}
Expand All @@ -254,6 +257,7 @@ where
.inspect_err(|_| pipe.buf.truncate(len))
}

#[inline]
fn execute_mut_blocking(self, pipe: &mut Pipeline<'a, B, SYNC_MODE>) -> Self::QueryMutOutput {
self.query_mut(pipe)
}
Expand Down Expand Up @@ -422,12 +426,11 @@ pub struct PipelineItem<'a> {

impl PipelineItem<'_> {
/// collect rows affected by this pipelined query. [Row] information will be ignored.
///
/// # Panic
/// calling this method on an already finished PipelineItem will cause panic. PipelineItem is marked as finished
/// when its [AsyncLendingIterator::try_next] method returns [Option::None]
pub async fn row_affected(mut self) -> Result<u64, Error> {
assert!(!self.finished, "PipelineItem has already finished");
if self.finished {
return Err(Completed.into());
}

loop {
match self.res.recv().await? {
backend::Message::DataRow(_) => {}
Expand All @@ -440,8 +443,12 @@ impl PipelineItem<'_> {
}
}

fn row_affected_blocking(mut self) -> Result<u64, Error> {
assert!(!self.finished, "PipelineItem has already finished");
/// blocking version of [`PipelineItem::row_affected`]
pub fn row_affected_blocking(mut self) -> Result<u64, Error> {
if self.finished {
return Err(Completed.into());
}

loop {
match self.res.blocking_recv()? {
backend::Message::DataRow(_) => {}
Expand Down
83 changes: 45 additions & 38 deletions postgres/src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use core::{
future::Future,
mem,
pin::Pin,
task::{ready, Context, Poll},
task::{Context, Poll},
};

use std::{
Expand Down Expand Up @@ -131,26 +131,7 @@ pub struct PoolConnection<'a> {
_permit: SemaphorePermit<'a>,
}

impl<'p> PoolConnection<'p> {
#[inline(never)]
fn prepare_slow<'a>(&'a mut self, named: StatementNamed<'a>) -> BoxedFuture<'a, Result<Arc<Statement>, Error>> {
Box::pin(async move {
let stmt = Statement::named(named.stmt, named.types).execute(self).await?.leak();
Ok(self.insert_cache(named.stmt, stmt))
})
}

fn prepare_slow_blocking(&mut self, named: StatementNamed<'_>) -> Result<Arc<Statement>, Error> {
let stmt = Statement::named(named.stmt, named.types).execute_blocking(self)?.leak();
Ok(self.insert_cache(named.stmt, stmt))
}

fn insert_cache(&mut self, named: &str, stmt: Statement) -> Arc<Statement> {
let stmt = Arc::new(stmt);
self.conn_mut().statements.insert(Box::from(named), stmt.clone());
stmt
}

impl PoolConnection<'_> {
/// function the same as [`Client::transaction`]
#[inline]
pub fn transaction(&mut self) -> impl Future<Output = Result<Transaction<Self>, Error>> + Send {
Expand Down Expand Up @@ -231,6 +212,12 @@ impl<'p> PoolConnection<'p> {
self.conn().client.cancel_token()
}

fn insert_cache(&mut self, named: &str, stmt: Statement) -> Arc<Statement> {
let stmt = Arc::new(stmt);
self.conn_mut().statements.insert(Box::from(named), stmt.clone());
stmt
}

fn conn(&self) -> &PoolClient {
self.conn.as_ref().unwrap()
}
Expand Down Expand Up @@ -303,29 +290,35 @@ impl PoolClient {
}
}

impl<'c, 'p, 's> ExecuteMut<'c, PoolConnection<'p>> for StatementNamed<'s>
impl<'c, 's> ExecuteMut<'c, PoolConnection<'_>> for StatementNamed<'s>
where
's: 'c,
'p: 'c,
{
type ExecuteMutOutput = StatementCacheFuture<'c>;
type QueryMutOutput = Self::ExecuteMutOutput;

fn execute_mut(self, cli: &'c mut PoolConnection<'p>) -> Self::ExecuteMutOutput {
fn execute_mut(self, cli: &'c mut PoolConnection) -> Self::ExecuteMutOutput {
match cli.conn().statements.get(self.stmt) {
Some(stmt) => StatementCacheFuture::Cached(stmt.clone()),
None => StatementCacheFuture::Prepared(cli.prepare_slow(self)),
None => StatementCacheFuture::Prepared(Box::pin(async move {
let stmt = self.execute(cli).await?.leak();
Ok(cli.insert_cache(self.stmt, stmt))
})),
}
}

fn query_mut(self, cli: &'c mut PoolConnection<'p>) -> Self::QueryMutOutput {
#[inline]
fn query_mut(self, cli: &'c mut PoolConnection) -> Self::QueryMutOutput {
self.execute_mut(cli)
}

fn execute_mut_blocking(self, cli: &mut PoolConnection) -> <Self::ExecuteMutOutput as Future>::Output {
match cli.conn().statements.get(self.stmt) {
Some(stmt) => Ok(stmt.clone()),
None => cli.prepare_slow_blocking(self),
None => {
let stmt = self.execute_blocking(cli)?.leak();
Ok(cli.insert_cache(self.stmt, stmt))
}
}
}
}
Expand All @@ -341,19 +334,33 @@ impl Future for StatementCacheFuture<'_> {

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
match this {
Self::Cached(_) => {
let Self::Cached(stmt) = mem::replace(this, Self::Done) else {
unreachable!("")
};
Poll::Ready(Ok(stmt))
}
Self::Prepared(ref mut fut) => {
let res = ready!(fut.as_mut().poll(cx));
drop(mem::replace(this, Self::Done));
Poll::Ready(res)
match mem::replace(this, Self::Done) {
Self::Cached(stmt) => Poll::Ready(Ok(stmt)),
Self::Prepared(mut fut) => {
let res = fut.as_mut().poll(cx);
if res.is_pending() {
drop(mem::replace(this, Self::Prepared(fut)));
}
res
}
Self::Done => panic!("StatementCacheFuture polled after finish"),
}
}
}

#[cfg(test)]
mod test {
use super::*;

#[tokio::test]
async fn pool() {
let pool = Pool::builder("postgres://postgres:postgres@localhost:5432")
.build()
.unwrap();

let mut conn = pool.get().await.unwrap();

let stmt = Statement::named("SELECT 1", &[]).execute_mut(&mut conn).await.unwrap();
stmt.execute(&conn.consume()).await.unwrap();
}
}
1 change: 1 addition & 0 deletions postgres/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ impl Statement {
}
}

#[derive(Clone, Copy)]
pub struct StatementNamed<'a> {
pub(crate) stmt: &'a str,
pub(crate) types: &'a [Type],
Expand Down

0 comments on commit ffb938f

Please sign in to comment.