Skip to content

Commit

Permalink
Merge pull request #3907 from weiznich/fix/custom_test_transaction_fo…
Browse files Browse the repository at this point in the history
…r_multiconnection

Fix several issues with `#[derive(MultiConnection)]`
  • Loading branch information
weiznich authored Feb 19, 2024
2 parents 7c4ba73 + 49ac723 commit 57346fa
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 18 deletions.
1 change: 1 addition & 0 deletions diesel/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ where
&mut self,
) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData;

/// Get the instrumentation instance stored in this connection
#[diesel_derives::__diesel_public_if(
feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"
)]
Expand Down
3 changes: 3 additions & 0 deletions diesel/src/connection/statement_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,9 @@ where
doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))
)]
pub trait QueryFragmentForCachedStatement<DB> {
/// Convert the query fragment into a SQL string for the given backend
fn construct_sql(&self, backend: &DB) -> QueryResult<String>;
/// Check whether it's safe to cache the query
fn is_safe_to_cache_prepared(&self, backend: &DB) -> QueryResult<bool>;
}
impl<T, DB> QueryFragmentForCachedStatement<DB> for T
Expand All @@ -269,6 +271,7 @@ where
self.to_sql(&mut query_builder, backend)?;
Ok(query_builder.finish())
}

fn is_safe_to_cache_prepared(&self, backend: &DB) -> QueryResult<bool> {
<T as QueryFragment<DB>>::is_safe_to_cache_prepared(self, backend)
}
Expand Down
18 changes: 18 additions & 0 deletions diesel/src/query_builder/bind_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ pub trait BindCollector<'a, DB: TypeMetadata>: Sized {
where
DB: Backend + HasSqlType<T>,
U: ToSql<T, DB> + ?Sized + 'a;

/// Push a null value with the given type information onto the bind collector
///
// For backward compatibility reasons we provide a default implementation
// but custom backends that want to support `#[derive(MultiConnection)]`
// need to provide a customized implementation of this function
#[diesel_derives::__diesel_public_if(
feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"
)]
fn push_null_value(&mut self, _metadata: DB::TypeMetadata) -> QueryResult<()> {
Ok(())
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -105,6 +117,12 @@ where
self.metadata.push(metadata);
Ok(())
}

fn push_null_value(&mut self, metadata: DB::TypeMetadata) -> QueryResult<()> {
self.metadata.push(metadata);
self.binds.push(None);
Ok(())
}
}

// This is private for now as we may want to add `Into` impls for the wrapper type
Expand Down
5 changes: 5 additions & 0 deletions diesel/src/sqlite/connection/bind_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,9 @@ impl<'a> BindCollector<'a, Sqlite> for SqliteBindCollector<'a> {
));
Ok(())
}

fn push_null_value(&mut self, metadata: SqliteType) -> QueryResult<()> {
self.binds.push((InternalSqliteBindValue::Null, metadata));
Ok(())
}
}
69 changes: 51 additions & 18 deletions diesel_derives/src/multiconnection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,13 @@ fn generate_connection_impl(
}
});

let impl_begin_test_transaction = connection_types.iter().map(|c| {
let ident = c.name;
quote::quote! {
Self::#ident(conn) => conn.begin_test_transaction()
}
});

let r2d2_impl = if cfg!(feature = "r2d2") {
let impl_ping_r2d2 = connection_types.iter().map(|c| {
let ident = c.name;
Expand Down Expand Up @@ -295,6 +302,9 @@ fn generate_connection_impl(
let mut query_builder = self.query_builder.duplicate();
self.inner.to_sql(&mut query_builder, &self.backend)?;
pass.push_sql(&query_builder.finish());
if !self.inner.is_safe_to_cache_prepared(&self.backend)? {
pass.unsafe_to_cache_prepared();
}
if let Some((outer_collector, lookup)) = pass.bind_collector() {
C::handle_inner_pass(outer_collector, lookup, &self.backend, &self.inner)?;
}
Expand Down Expand Up @@ -356,6 +366,12 @@ fn generate_connection_impl(
#(#instrumentation_impl,)*
}
}

fn begin_test_transaction(&mut self) -> diesel::QueryResult<()> {
match self {
#(#impl_begin_test_transaction,)*
}
}
}

impl LoadConnection for MultiConnection
Expand Down Expand Up @@ -757,6 +773,18 @@ fn generate_bind_collector(connection_types: &[ConnectionVariant]) -> TokenStrea
}
});

let push_null_to_inner_collector = connection_types
.iter()
.map(|c| {
let ident = c.name;
quote::quote! {
(Self::#ident(ref mut bc), super::backend::MultiTypeMetadata{ #ident: Some(metadata), .. }) => {
bc.push_null_value(metadata)?;
}
}
})
.collect::<Vec<_>>();

let push_bound_value_super_traits = connection_types
.iter()
.map(|c| {
Expand Down Expand Up @@ -948,20 +976,14 @@ fn generate_bind_collector(connection_types: &[ConnectionVariant]) -> TokenStrea
// set the `inner` field of `BindValue` to something for the `None`
// case. Therefore we need to handle that explicitly here.
//
// We just use a specific sql + rust type here to workaround
// the fact that rustc is not able to see that the underlying DBMS
// must support that sql + rust type combination. All tested DBMS
// (postgres, sqlite, mysql, oracle) seems to not care about the
// actual type here and coerce null values to the "right" type
// anyway
BindValue {
inner: Some(InnerBindValue {
value: InnerBindValueKind::Null,
push_bound_value_to_collector: &PushBoundValueToCollectorImpl {
p: std::marker::PhantomData::<(diesel::sql_types::Integer, i32)>
}
})
let metadata = <MultiBackend as diesel::sql_types::HasSqlType<T>>::metadata(metadata_lookup);
match (self, metadata) {
#(#push_null_to_inner_collector)*
_ => {
unreachable!("We have matching metadata")
},
}
return Ok(());
} else {
out.into_inner()
}
Expand All @@ -972,6 +994,14 @@ fn generate_bind_collector(connection_types: &[ConnectionVariant]) -> TokenStrea

Ok(())
}

fn push_null_value(&mut self, metadata: super::backend::MultiTypeMetadata) -> diesel::QueryResult<()> {
match (self, metadata) {
#(#push_null_to_inner_collector)*
_ => unreachable!("We have matching metadata"),
}
Ok(())
}
}

#(#to_sql_impls)*
Expand Down Expand Up @@ -1368,8 +1398,8 @@ fn generate_backend(connection_types: &[ConnectionVariant]) -> TokenStream {
let type_metadata_variants = connection_types.iter().map(|c| {
let ident = c.name;
let ty = c.ty;
quote::quote!{
#ident(<<#ty as diesel::Connection>::Backend as diesel::sql_types::TypeMetadata>::TypeMetadata)
quote::quote! {
pub(super) #ident: Option<<<#ty as diesel::Connection>::Backend as diesel::sql_types::TypeMetadata>::TypeMetadata>
}
});

Expand Down Expand Up @@ -1456,7 +1486,7 @@ fn generate_backend(connection_types: &[ConnectionVariant]) -> TokenStream {

quote::quote!{
if let Some(lookup) = <#ty as diesel::internal::derives::multiconnection::MultiConnectionHelper>::from_any(lookup) {
return MultiTypeMetadata::#name(<<#ty as diesel::Connection>::Backend as diesel::sql_types::HasSqlType<ST>>::metadata(lookup));
ret.#name = Some(<<#ty as diesel::Connection>::Backend as diesel::sql_types::HasSqlType<ST>>::metadata(lookup));
}
}

Expand All @@ -1480,8 +1510,9 @@ fn generate_backend(connection_types: &[ConnectionVariant]) -> TokenStream {
pub fn lookup_sql_type<ST>(lookup: &mut dyn std::any::Any) -> MultiTypeMetadata
where #(#lookup_sql_type_bounds,)*
{
let mut ret = MultiTypeMetadata::default();
#(#lookup_impl)*
unreachable!()
ret
}
}

Expand Down Expand Up @@ -1519,7 +1550,9 @@ fn generate_backend(connection_types: &[ConnectionVariant]) -> TokenStream {
type BindCollector<'a> = super::bind_collector::MultiBindCollector<'a>;
}

pub enum MultiTypeMetadata {
#[derive(Default)]
#[allow(non_snake_case)]
pub struct MultiTypeMetadata {
#(#type_metadata_variants,)*
}

Expand Down

0 comments on commit 57346fa

Please sign in to comment.