diff --git a/diesel/src/connection/mod.rs b/diesel/src/connection/mod.rs index a0fe95c6963b..2ad7df2ba1b8 100644 --- a/diesel/src/connection/mod.rs +++ b/diesel/src/connection/mod.rs @@ -392,6 +392,7 @@ where &mut self, ) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData; + /// Get the instrumentation instance stored in this connection #[diesel_derives::__diesel_public_if( feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes" )] diff --git a/diesel/src/connection/statement_cache.rs b/diesel/src/connection/statement_cache.rs index 2d4cd7ef4a18..af6d2968d5b9 100644 --- a/diesel/src/connection/statement_cache.rs +++ b/diesel/src/connection/statement_cache.rs @@ -255,7 +255,9 @@ where doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")) )] pub trait QueryFragmentForCachedStatement<DB> { + /// Convert the query fragment into a SQL string for the given backend fn construct_sql(&self, backend: &DB) -> QueryResult<String>; + /// Check whether it's safe to cache the query fn is_safe_to_cache_prepared(&self, backend: &DB) -> QueryResult<bool>; } impl<T, DB> QueryFragmentForCachedStatement<DB> for T @@ -269,6 +271,7 @@ where self.to_sql(&mut query_builder, backend)?; Ok(query_builder.finish()) } + fn is_safe_to_cache_prepared(&self, backend: &DB) -> QueryResult<bool> { <T as QueryFragment<DB>>::is_safe_to_cache_prepared(self, backend) } diff --git a/diesel/src/query_builder/bind_collector.rs b/diesel/src/query_builder/bind_collector.rs index bfe470715b1b..7db821a75a98 100644 --- a/diesel/src/query_builder/bind_collector.rs +++ b/diesel/src/query_builder/bind_collector.rs @@ -32,6 +32,18 @@ pub trait BindCollector<'a, DB: TypeMetadata>: Sized { where DB: Backend + HasSqlType<T>, U: ToSql<T, DB> + ?Sized + 'a; + + /// Push a null value with the given type information onto the bind collector + /// + // For backward compatibility reasons we provide a default implementation + // but custom backends that want to support `#[derive(MultiConnection)]` + // need to provide a customized implementation of this function + #[diesel_derives::__diesel_public_if( + feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes" + )] + fn push_null_value(&mut self, _metadata: DB::TypeMetadata) -> QueryResult<()> { + Ok(()) + } } #[derive(Debug)] @@ -105,6 +117,12 @@ where self.metadata.push(metadata); Ok(()) } + + fn push_null_value(&mut self, metadata: DB::TypeMetadata) -> QueryResult<()> { + self.metadata.push(metadata); + self.binds.push(None); + Ok(()) + } } // This is private for now as we may want to add `Into` impls for the wrapper type diff --git a/diesel/src/sqlite/connection/bind_collector.rs b/diesel/src/sqlite/connection/bind_collector.rs index 8cce52c35ed3..e276c1b5b0c1 100644 --- a/diesel/src/sqlite/connection/bind_collector.rs +++ b/diesel/src/sqlite/connection/bind_collector.rs @@ -194,4 +194,9 @@ impl<'a> BindCollector<'a, Sqlite> for SqliteBindCollector<'a> { )); Ok(()) } + + fn push_null_value(&mut self, metadata: SqliteType) -> QueryResult<()> { + self.binds.push((InternalSqliteBindValue::Null, metadata)); + Ok(()) + } } diff --git a/diesel_derives/src/multiconnection.rs b/diesel_derives/src/multiconnection.rs index b4ba74214bda..cd2bde892e24 100644 --- a/diesel_derives/src/multiconnection.rs +++ b/diesel_derives/src/multiconnection.rs @@ -212,6 +212,13 @@ fn generate_connection_impl( } }); + let impl_begin_test_transaction = connection_types.iter().map(|c| { + let ident = c.name; + quote::quote! { + Self::#ident(conn) => conn.begin_test_transaction() + } + }); + let r2d2_impl = if cfg!(feature = "r2d2") { let impl_ping_r2d2 = connection_types.iter().map(|c| { let ident = c.name; @@ -295,6 +302,9 @@ fn generate_connection_impl( let mut query_builder = self.query_builder.duplicate(); self.inner.to_sql(&mut query_builder, &self.backend)?; pass.push_sql(&query_builder.finish()); + if !self.inner.is_safe_to_cache_prepared(&self.backend)? { + pass.unsafe_to_cache_prepared(); + } if let Some((outer_collector, lookup)) = pass.bind_collector() { C::handle_inner_pass(outer_collector, lookup, &self.backend, &self.inner)?; } @@ -356,6 +366,12 @@ fn generate_connection_impl( #(#instrumentation_impl,)* } } + + fn begin_test_transaction(&mut self) -> diesel::QueryResult<()> { + match self { + #(#impl_begin_test_transaction,)* + } + } } impl LoadConnection for MultiConnection @@ -757,6 +773,18 @@ fn generate_bind_collector(connection_types: &[ConnectionVariant]) -> TokenStrea } }); + let push_null_to_inner_collector = connection_types + .iter() + .map(|c| { + let ident = c.name; + quote::quote! { + (Self::#ident(ref mut bc), super::backend::MultiTypeMetadata{ #ident: Some(metadata), .. }) => { + bc.push_null_value(metadata)?; + } + } + }) + .collect::<Vec<_>>(); + let push_bound_value_super_traits = connection_types .iter() .map(|c| { @@ -948,20 +976,14 @@ fn generate_bind_collector(connection_types: &[ConnectionVariant]) -> TokenStrea // set the `inner` field of `BindValue` to something for the `None` // case. Therefore we need to handle that explicitly here. // - // We just use a specific sql + rust type here to workaround - // the fact that rustc is not able to see that the underlying DBMS - // must support that sql + rust type combination. All tested DBMS - // (postgres, sqlite, mysql, oracle) seems to not care about the - // actual type here and coerce null values to the "right" type - // anyway - BindValue { - inner: Some(InnerBindValue { - value: InnerBindValueKind::Null, - push_bound_value_to_collector: &PushBoundValueToCollectorImpl { - p: std::marker::PhantomData::<(diesel::sql_types::Integer, i32)> - } - }) + let metadata = <MultiBackend as diesel::sql_types::HasSqlType<T>>::metadata(metadata_lookup); + match (self, metadata) { + #(#push_null_to_inner_collector)* + _ => { + unreachable!("We have matching metadata") + }, } + return Ok(()); } else { out.into_inner() } @@ -972,6 +994,14 @@ fn generate_bind_collector(connection_types: &[ConnectionVariant]) -> TokenStrea Ok(()) } + + fn push_null_value(&mut self, metadata: super::backend::MultiTypeMetadata) -> diesel::QueryResult<()> { + match (self, metadata) { + #(#push_null_to_inner_collector)* + _ => unreachable!("We have matching metadata"), + } + Ok(()) + } } #(#to_sql_impls)* @@ -1368,8 +1398,8 @@ fn generate_backend(connection_types: &[ConnectionVariant]) -> TokenStream { let type_metadata_variants = connection_types.iter().map(|c| { let ident = c.name; let ty = c.ty; - quote::quote!{ - #ident(<<#ty as diesel::Connection>::Backend as diesel::sql_types::TypeMetadata>::TypeMetadata) + quote::quote! { + pub(super) #ident: Option<<<#ty as diesel::Connection>::Backend as diesel::sql_types::TypeMetadata>::TypeMetadata> } }); @@ -1456,7 +1486,7 @@ fn generate_backend(connection_types: &[ConnectionVariant]) -> TokenStream { quote::quote!{ if let Some(lookup) = <#ty as diesel::internal::derives::multiconnection::MultiConnectionHelper>::from_any(lookup) { - return MultiTypeMetadata::#name(<<#ty as diesel::Connection>::Backend as diesel::sql_types::HasSqlType<ST>>::metadata(lookup)); + ret.#name = Some(<<#ty as diesel::Connection>::Backend as diesel::sql_types::HasSqlType<ST>>::metadata(lookup)); } } @@ -1480,8 +1510,9 @@ fn generate_backend(connection_types: &[ConnectionVariant]) -> TokenStream { pub fn lookup_sql_type<ST>(lookup: &mut dyn std::any::Any) -> MultiTypeMetadata where #(#lookup_sql_type_bounds,)* { + let mut ret = MultiTypeMetadata::default(); #(#lookup_impl)* - unreachable!() + ret } } @@ -1519,7 +1550,9 @@ fn generate_backend(connection_types: &[ConnectionVariant]) -> TokenStream { type BindCollector<'a> = super::bind_collector::MultiBindCollector<'a>; } - pub enum MultiTypeMetadata { + #[derive(Default)] + #[allow(non_snake_case)] + pub struct MultiTypeMetadata { #(#type_metadata_variants,)* }