7
7
use self :: error_helper:: ErrorHelper ;
8
8
use self :: row:: PgRow ;
9
9
use self :: serialize:: ToSqlHelper ;
10
- use crate :: stmt_cache:: { PrepareCallback , StmtCache } ;
10
+ use crate :: stmt_cache:: { CallbackHelper , QueryFragmentHelper } ;
11
11
use crate :: { AnsiTransactionManager , AsyncConnection , SimpleAsyncConnection } ;
12
- use diesel:: connection:: statement_cache:: { PrepareForCache , StatementCacheKey } ;
13
- use diesel :: connection :: Instrumentation ;
14
- use diesel :: connection :: InstrumentationEvent ;
12
+ use diesel:: connection:: statement_cache:: {
13
+ PrepareForCache , QueryFragmentForCachedStatement , StatementCache ,
14
+ } ;
15
15
use diesel:: connection:: StrQueryHelper ;
16
+ use diesel:: connection:: { CacheSize , Instrumentation } ;
17
+ use diesel:: connection:: { DynInstrumentation , InstrumentationEvent } ;
16
18
use diesel:: pg:: {
17
19
Pg , PgMetadataCache , PgMetadataCacheKey , PgMetadataLookup , PgQueryBuilder , PgTypeMetadata ,
18
20
} ;
@@ -122,13 +124,13 @@ const FAKE_OID: u32 = 0;
122
124
/// [tokio_postgres_rustls]: https://docs.rs/tokio-postgres-rustls/0.12.0/tokio_postgres_rustls/
123
125
pub struct AsyncPgConnection {
124
126
conn : Arc < tokio_postgres:: Client > ,
125
- stmt_cache : Arc < Mutex < StmtCache < diesel:: pg:: Pg , Statement > > > ,
127
+ stmt_cache : Arc < Mutex < StatementCache < diesel:: pg:: Pg , Statement > > > ,
126
128
transaction_state : Arc < Mutex < AnsiTransactionManager > > ,
127
129
metadata_cache : Arc < Mutex < PgMetadataCache > > ,
128
130
connection_future : Option < broadcast:: Receiver < Arc < tokio_postgres:: Error > > > ,
129
131
shutdown_channel : Option < oneshot:: Sender < ( ) > > ,
130
132
// a sync mutex is fine here as we only hold it for a really short time
131
- instrumentation : Arc < std:: sync:: Mutex < Option < Box < dyn Instrumentation > > > > ,
133
+ instrumentation : Arc < std:: sync:: Mutex < DynInstrumentation > > ,
132
134
}
133
135
134
136
#[ async_trait:: async_trait]
@@ -162,7 +164,7 @@ impl AsyncConnection for AsyncPgConnection {
162
164
type TransactionManager = AnsiTransactionManager ;
163
165
164
166
async fn establish ( database_url : & str ) -> ConnectionResult < Self > {
165
- let mut instrumentation = diesel :: connection :: get_default_instrumentation ( ) ;
167
+ let mut instrumentation = DynInstrumentation :: default_instrumentation ( ) ;
166
168
instrumentation. on_connection_event ( InstrumentationEvent :: start_establish_connection (
167
169
database_url,
168
170
) ) ;
@@ -229,14 +231,25 @@ impl AsyncConnection for AsyncPgConnection {
229
231
// that means there is only one instance of this arc and
230
232
// we can simply access the inner data
231
233
if let Some ( instrumentation) = Arc :: get_mut ( & mut self . instrumentation ) {
232
- instrumentation. get_mut ( ) . unwrap_or_else ( |p| p. into_inner ( ) )
234
+ & mut * * ( instrumentation. get_mut ( ) . unwrap_or_else ( |p| p. into_inner ( ) ) )
233
235
} else {
234
236
panic ! ( "Cannot access shared instrumentation" )
235
237
}
236
238
}
237
239
238
240
fn set_instrumentation ( & mut self , instrumentation : impl Instrumentation ) {
239
- self . instrumentation = Arc :: new ( std:: sync:: Mutex :: new ( Some ( Box :: new ( instrumentation) ) ) ) ;
241
+ self . instrumentation = Arc :: new ( std:: sync:: Mutex :: new ( instrumentation. into ( ) ) ) ;
242
+ }
243
+
244
+ fn set_prepared_statement_cache_size ( & mut self , size : CacheSize ) {
245
+ // there should be no other pending future when this is called
246
+ // that means there is only one instance of this arc and
247
+ // we can simply access the inner data
248
+ if let Some ( cache) = Arc :: get_mut ( & mut self . stmt_cache ) {
249
+ cache. get_mut ( ) . set_cache_size ( size)
250
+ } else {
251
+ panic ! ( "Cannot access shared statement cache" )
252
+ }
240
253
}
241
254
}
242
255
@@ -293,25 +306,33 @@ fn update_transaction_manager_status<T>(
293
306
query_result
294
307
}
295
308
296
- #[ async_trait:: async_trait]
297
- impl PrepareCallback < Statement , PgTypeMetadata > for Arc < tokio_postgres:: Client > {
298
- async fn prepare (
299
- self ,
300
- sql : & str ,
301
- metadata : & [ PgTypeMetadata ] ,
302
- _is_for_cache : PrepareForCache ,
303
- ) -> QueryResult < ( Statement , Self ) > {
304
- let bind_types = metadata
305
- . iter ( )
306
- . map ( type_from_oid)
307
- . collect :: < QueryResult < Vec < _ > > > ( ) ?;
308
-
309
- let stmt = self
310
- . prepare_typed ( sql, & bind_types)
309
+ fn prepare_statement_helper < ' a > (
310
+ conn : Arc < tokio_postgres:: Client > ,
311
+ sql : & ' a str ,
312
+ _is_for_cache : PrepareForCache ,
313
+ metadata : & [ PgTypeMetadata ] ,
314
+ ) -> CallbackHelper <
315
+ impl Future < Output = QueryResult < ( Statement , Arc < tokio_postgres:: Client > ) > > + Send ,
316
+ > {
317
+ let bind_types = metadata
318
+ . iter ( )
319
+ . map ( type_from_oid)
320
+ . collect :: < QueryResult < Vec < _ > > > ( ) ;
321
+ // ideally we wouldn't clone the SQL string here
322
+ // but as we usually cache statements anyway
323
+ // this is a fixed one time const
324
+ //
325
+ // The probleme with not cloning it is that we then cannot express
326
+ // the right result lifetime anymore (at least not easily)
327
+ let sql = sql. to_string ( ) ;
328
+ CallbackHelper ( async move {
329
+ let bind_types = bind_types?;
330
+ let stmt = conn
331
+ . prepare_typed ( & sql, & bind_types)
311
332
. await
312
333
. map_err ( ErrorHelper ) ;
313
- Ok ( ( stmt?, self ) )
314
- }
334
+ Ok ( ( stmt?, conn ) )
335
+ } )
315
336
}
316
337
317
338
fn type_from_oid ( t : & PgTypeMetadata ) -> QueryResult < Type > {
@@ -369,7 +390,7 @@ impl AsyncPgConnection {
369
390
None ,
370
391
None ,
371
392
Arc :: new ( std:: sync:: Mutex :: new (
372
- diesel :: connection :: get_default_instrumentation ( ) ,
393
+ DynInstrumentation :: default_instrumentation ( ) ,
373
394
) ) ,
374
395
)
375
396
. await
@@ -390,9 +411,7 @@ impl AsyncPgConnection {
390
411
client,
391
412
Some ( error_rx) ,
392
413
Some ( shutdown_tx) ,
393
- Arc :: new ( std:: sync:: Mutex :: new (
394
- diesel:: connection:: get_default_instrumentation ( ) ,
395
- ) ) ,
414
+ Arc :: new ( std:: sync:: Mutex :: new ( DynInstrumentation :: none ( ) ) ) ,
396
415
)
397
416
. await
398
417
}
@@ -401,11 +420,11 @@ impl AsyncPgConnection {
401
420
conn : tokio_postgres:: Client ,
402
421
connection_future : Option < broadcast:: Receiver < Arc < tokio_postgres:: Error > > > ,
403
422
shutdown_channel : Option < oneshot:: Sender < ( ) > > ,
404
- instrumentation : Arc < std:: sync:: Mutex < Option < Box < dyn Instrumentation > > > > ,
423
+ instrumentation : Arc < std:: sync:: Mutex < DynInstrumentation > > ,
405
424
) -> ConnectionResult < Self > {
406
425
let mut conn = Self {
407
426
conn : Arc :: new ( conn) ,
408
- stmt_cache : Arc :: new ( Mutex :: new ( StmtCache :: new ( ) ) ) ,
427
+ stmt_cache : Arc :: new ( Mutex :: new ( StatementCache :: new ( ) ) ) ,
409
428
transaction_state : Arc :: new ( Mutex :: new ( AnsiTransactionManager :: default ( ) ) ) ,
410
429
metadata_cache : Arc :: new ( Mutex :: new ( PgMetadataCache :: new ( ) ) ) ,
411
430
connection_future,
@@ -559,23 +578,27 @@ impl AsyncPgConnection {
559
578
} ) ?;
560
579
}
561
580
}
562
- let key = match query_id {
563
- Some ( id) => StatementCacheKey :: Type ( id) ,
564
- None => StatementCacheKey :: Sql {
565
- sql : sql. clone ( ) ,
566
- bind_types : bind_collector. metadata . clone ( ) ,
567
- } ,
568
- } ;
569
581
let stmt = {
570
582
let mut stmt_cache = stmt_cache. lock ( ) . await ;
583
+ let helper = QueryFragmentHelper {
584
+ sql : sql. clone ( ) ,
585
+ safe_to_cache : is_safe_to_cache_prepared,
586
+ } ;
587
+ let instrumentation = Arc :: clone ( & instrumentation) ;
571
588
stmt_cache
572
- . cached_prepared_statement (
573
- key ,
574
- sql . clone ( ) ,
575
- is_safe_to_cache_prepared ,
589
+ . cached_statement_non_generic (
590
+ query_id ,
591
+ & helper ,
592
+ & Pg ,
576
593
& bind_collector. metadata ,
577
594
raw_connection. clone ( ) ,
578
- & instrumentation
595
+ prepare_statement_helper,
596
+ & mut move |event : InstrumentationEvent < ' _ > | {
597
+ // we wrap this lock into another callback to prevent locking
598
+ // the instrumentation longer than necessary
599
+ instrumentation. lock ( ) . unwrap_or_else ( |e| e. into_inner ( ) )
600
+ . on_connection_event ( event) ;
601
+ } ,
579
602
)
580
603
. await ?
581
604
. 0
@@ -894,6 +917,16 @@ impl crate::pooled_connection::PoolableConnection for AsyncPgConnection {
894
917
}
895
918
}
896
919
920
+ impl QueryFragmentForCachedStatement < Pg > for QueryFragmentHelper {
921
+ fn construct_sql ( & self , _backend : & Pg ) -> QueryResult < String > {
922
+ Ok ( self . sql . clone ( ) )
923
+ }
924
+
925
+ fn is_safe_to_cache_prepared ( & self , _backend : & Pg ) -> QueryResult < bool > {
926
+ Ok ( self . safe_to_cache )
927
+ }
928
+ }
929
+
897
930
#[ cfg( test) ]
898
931
mod tests {
899
932
use super :: * ;
0 commit comments