Skip to content

Commit 4d349df

Browse files
committedJan 16, 2025··
Share the statement cache with diesel
This commit refactors diesel-async to use the same statement cache implementation as diesel. That brings in all the optimisations done to the diesel statement cache.
1 parent c7569a5 commit 4d349df

File tree

11 files changed

+264
-188
lines changed

11 files changed

+264
-188
lines changed
 

‎Cargo.toml

+23-5
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@ description = "An async extension for Diesel the safe, extensible ORM and Query
1313
rust-version = "1.78.0"
1414

1515
[dependencies]
16-
diesel = { version = "~2.2.0", default-features = false, features = [
17-
"i-implement-a-third-party-backend-and-opt-into-breaking-changes",
18-
] }
1916
async-trait = "0.1.66"
2017
futures-channel = { version = "0.3.17", default-features = false, features = [
2118
"std",
@@ -39,14 +36,35 @@ deadpool = { version = "0.12", optional = true, default-features = false, featur
3936
mobc = { version = ">=0.7,<0.10", optional = true }
4037
scoped-futures = { version = "0.1", features = ["std"] }
4138

39+
[dependencies.diesel]
40+
version = "~2.2.0"
41+
default-features = false
42+
features = [
43+
"i-implement-a-third-party-backend-and-opt-into-breaking-changes",
44+
]
45+
git = "https://github.com/diesel-rs/diesel"
46+
branch = "master"
47+
4248
[dev-dependencies]
4349
tokio = { version = "1.12.0", features = ["rt", "macros", "rt-multi-thread"] }
4450
cfg-if = "1"
4551
chrono = "0.4"
46-
diesel = { version = "2.2.0", default-features = false, features = ["chrono"] }
47-
diesel_migrations = "2.2.0"
4852
assert_matches = "1.0.1"
4953

54+
[dev-dependencies.diesel]
55+
version = "~2.2.0"
56+
default-features = false
57+
features = [
58+
"chrono"
59+
]
60+
git = "https://github.com/diesel-rs/diesel"
61+
branch = "master"
62+
63+
[dev-dependencies.diesel_migrations]
64+
version = "2.2.0"
65+
git = "https://github.com/diesel-rs/diesel"
66+
branch = "master"
67+
5068
[features]
5169
default = []
5270
mysql = [

‎examples/postgres/pooled-with-rustls/Cargo.toml

+7-1
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@ edition = "2021"
66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
77

88
[dependencies]
9-
diesel = { version = "2.2.0", default-features = false, features = ["postgres"] }
109
diesel-async = { version = "0.5.0", path = "../../../", features = ["bb8", "postgres"] }
1110
futures-util = "0.3.21"
1211
rustls = "0.23.8"
1312
rustls-native-certs = "0.7.1"
1413
tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] }
1514
tokio-postgres = "0.7.7"
1615
tokio-postgres-rustls = "0.12.0"
16+
17+
18+
[dependencies.diesel]
19+
version = "2.2.0"
20+
default-features = false
21+
git = "https://github.com/diesel-rs/diesel"
22+
branch = "master"

‎examples/postgres/run-pending-migrations-with-rustls/Cargo.toml

+11-2
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,21 @@ edition = "2021"
66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
77

88
[dependencies]
9-
diesel = { version = "2.2.0", default-features = false, features = ["postgres"] }
109
diesel-async = { version = "0.5.0", path = "../../../", features = ["bb8", "postgres", "async-connection-wrapper"] }
11-
diesel_migrations = "2.2.0"
1210
futures-util = "0.3.21"
1311
rustls = "0.23.10"
1412
rustls-native-certs = "0.7.1"
1513
tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] }
1614
tokio-postgres = "0.7.7"
1715
tokio-postgres-rustls = "0.12.0"
16+
17+
[dependencies.diesel]
18+
version = "2.2.0"
19+
default-features = false
20+
git = "https://github.com/diesel-rs/diesel"
21+
branch = "master"
22+
23+
[dependencies.diesel_migrations]
24+
version = "2.2.0"
25+
git = "https://github.com/diesel-rs/diesel"
26+
branch = "master"

‎examples/sync-wrapper/Cargo.toml

+12-2
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,22 @@ edition = "2021"
66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
77

88
[dependencies]
9-
diesel = { version = "2.2.0", default-features = false, features = ["returning_clauses_for_sqlite_3_35"] }
109
diesel-async = { version = "0.5.0", path = "../../", features = ["sync-connection-wrapper", "async-connection-wrapper"] }
11-
diesel_migrations = "2.2.0"
1210
futures-util = "0.3.21"
1311
tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] }
1412

13+
[dependencies.diesel]
14+
version = "2.2.0"
15+
default-features = false
16+
features = ["returning_clauses_for_sqlite_3_35"]
17+
git = "https://github.com/diesel-rs/diesel"
18+
branch = "master"
19+
20+
[dependencies.diesel_migrations]
21+
version = "2.2.0"
22+
git = "https://github.com/diesel-rs/diesel"
23+
branch = "master"
24+
1525
[features]
1626
default = ["sqlite"]
1727
sqlite = ["diesel-async/sqlite"]

‎src/async_connection_wrapper.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ pub type AsyncConnectionWrapper<C, B = self::implementation::Tokio> =
100100
pub use self::implementation::AsyncConnectionWrapper;
101101

102102
mod implementation {
103-
use diesel::connection::{Instrumentation, SimpleConnection};
103+
use diesel::connection::{CacheSize, Instrumentation, SimpleConnection};
104104
use std::ops::{Deref, DerefMut};
105105

106106
use super::*;
@@ -187,6 +187,10 @@ mod implementation {
187187
fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
188188
self.inner.set_instrumentation(instrumentation);
189189
}
190+
191+
fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
192+
self.inner.set_prepared_statement_cache_size(size)
193+
}
190194
}
191195

192196
impl<C, B> diesel::connection::LoadConnection for AsyncConnectionWrapper<C, B>

‎src/lib.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
)]
7575

7676
use diesel::backend::Backend;
77-
use diesel::connection::Instrumentation;
77+
use diesel::connection::{CacheSize, Instrumentation};
7878
use diesel::query_builder::{AsQuery, QueryFragment, QueryId};
7979
use diesel::result::Error;
8080
use diesel::row::Row;
@@ -354,4 +354,7 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send {
354354

355355
/// Set a specific [`Instrumentation`] implementation for this connection
356356
fn set_instrumentation(&mut self, instrumentation: impl Instrumentation);
357+
358+
/// Set the prepared statement cache size to [`CacheSize`] for this connection
359+
fn set_prepared_statement_cache_size(&mut self, size: CacheSize);
357360
}

‎src/mysql/mod.rs

+62-53
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
use crate::stmt_cache::{PrepareCallback, StmtCache};
1+
use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
22
use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection};
3-
use diesel::connection::statement_cache::{MaybeCached, StatementCacheKey};
4-
use diesel::connection::Instrumentation;
5-
use diesel::connection::InstrumentationEvent;
3+
use diesel::connection::statement_cache::{
4+
MaybeCached, QueryFragmentForCachedStatement, StatementCache,
5+
};
66
use diesel::connection::StrQueryHelper;
7+
use diesel::connection::{CacheSize, Instrumentation};
8+
use diesel::connection::{DynInstrumentation, InstrumentationEvent};
79
use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType};
810
use diesel::query_builder::QueryBuilder;
911
use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId};
@@ -27,9 +29,9 @@ use self::serialize::ToSqlHelper;
2729
/// `mysql://[user[:password]@]host/database_name`
2830
pub struct AsyncMysqlConnection {
2931
conn: mysql_async::Conn,
30-
stmt_cache: StmtCache<Mysql, Statement>,
32+
stmt_cache: StatementCache<Mysql, Statement>,
3133
transaction_manager: AnsiTransactionManager,
32-
instrumentation: std::sync::Mutex<Option<Box<dyn Instrumentation>>>,
34+
instrumentation: DynInstrumentation,
3335
}
3436

3537
#[async_trait::async_trait]
@@ -72,7 +74,7 @@ impl AsyncConnection for AsyncMysqlConnection {
7274
type TransactionManager = AnsiTransactionManager;
7375

7476
async fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
75-
let mut instrumentation = diesel::connection::get_default_instrumentation();
77+
let mut instrumentation = DynInstrumentation::default_instrumentation();
7678
instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
7779
database_url,
7880
));
@@ -82,7 +84,7 @@ impl AsyncConnection for AsyncMysqlConnection {
8284
r.as_ref().err(),
8385
));
8486
let mut conn = r?;
85-
conn.instrumentation = std::sync::Mutex::new(instrumentation);
87+
conn.instrumentation = instrumentation;
8688
Ok(conn)
8789
}
8890

@@ -177,16 +179,15 @@ impl AsyncConnection for AsyncMysqlConnection {
177179
}
178180

179181
fn instrumentation(&mut self) -> &mut dyn Instrumentation {
180-
self.instrumentation
181-
.get_mut()
182-
.unwrap_or_else(|p| p.into_inner())
182+
&mut *self.instrumentation
183183
}
184184

185185
fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
186-
*self
187-
.instrumentation
188-
.get_mut()
189-
.unwrap_or_else(|p| p.into_inner()) = Some(Box::new(instrumentation));
186+
self.instrumentation = instrumentation.into();
187+
}
188+
189+
fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
190+
self.stmt_cache.set_cache_size(size);
190191
}
191192
}
192193

@@ -207,17 +208,24 @@ fn update_transaction_manager_status<T>(
207208
query_result
208209
}
209210

210-
#[async_trait::async_trait]
211-
impl PrepareCallback<Statement, MysqlType> for &'_ mut mysql_async::Conn {
212-
async fn prepare(
213-
self,
214-
sql: &str,
215-
_metadata: &[MysqlType],
216-
_is_for_cache: diesel::connection::statement_cache::PrepareForCache,
217-
) -> QueryResult<(Statement, Self)> {
218-
let s = self.prep(sql).await.map_err(ErrorHelper)?;
219-
Ok((s, self))
220-
}
211+
fn prepare_statement_helper<'a, 'b>(
212+
conn: &'a mut mysql_async::Conn,
213+
sql: &'b str,
214+
_is_for_cache: diesel::connection::statement_cache::PrepareForCache,
215+
_metadata: &[MysqlType],
216+
) -> CallbackHelper<impl Future<Output = QueryResult<(Statement, &'a mut mysql_async::Conn)>> + Send>
217+
{
218+
// ideally we wouldn't clone the SQL string here
219+
// but as we usually cache statements anyway
220+
// this is a fixed one time const
221+
//
222+
// The probleme with not cloning it is that we then cannot express
223+
// the right result lifetime anymore (at least not easily)
224+
let sql = sql.to_owned();
225+
CallbackHelper(async move {
226+
let s = conn.prep(sql).await.map_err(ErrorHelper)?;
227+
Ok((s, conn))
228+
})
221229
}
222230

223231
impl AsyncMysqlConnection {
@@ -229,11 +237,9 @@ impl AsyncMysqlConnection {
229237
use crate::run_query_dsl::RunQueryDsl;
230238
let mut conn = AsyncMysqlConnection {
231239
conn,
232-
stmt_cache: StmtCache::new(),
240+
stmt_cache: StatementCache::new(),
233241
transaction_manager: AnsiTransactionManager::default(),
234-
instrumentation: std::sync::Mutex::new(
235-
diesel::connection::get_default_instrumentation(),
236-
),
242+
instrumentation: DynInstrumentation::default_instrumentation(),
237243
};
238244

239245
for stmt in CONNECTION_SETUP_QUERIES {
@@ -286,36 +292,29 @@ impl AsyncMysqlConnection {
286292
} = bind_collector?;
287293
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
288294
let sql = sql?;
295+
let helper = QueryFragmentHelper {
296+
sql,
297+
safe_to_cache: is_safe_to_cache_prepared,
298+
};
289299
let inner = async {
290-
let cache_key = if let Some(query_id) = query_id {
291-
StatementCacheKey::Type(query_id)
292-
} else {
293-
StatementCacheKey::Sql {
294-
sql: sql.clone(),
295-
bind_types: metadata.clone(),
296-
}
297-
};
298-
299300
let (stmt, conn) = stmt_cache
300-
.cached_prepared_statement(
301-
cache_key,
302-
sql.clone(),
303-
is_safe_to_cache_prepared,
301+
.cached_statement_non_generic(
302+
query_id,
303+
&helper,
304+
&Mysql,
304305
&metadata,
305306
conn,
306-
instrumentation,
307+
prepare_statement_helper,
308+
&mut **instrumentation,
307309
)
308310
.await?;
309311
callback(conn, stmt, ToSqlHelper { metadata, binds }).await
310312
};
311313
let r = update_transaction_manager_status(inner.await, transaction_manager);
312-
instrumentation
313-
.get_mut()
314-
.unwrap_or_else(|p| p.into_inner())
315-
.on_connection_event(InstrumentationEvent::finish_query(
316-
&StrQueryHelper::new(&sql),
317-
r.as_ref().err(),
318-
));
314+
instrumentation.on_connection_event(InstrumentationEvent::finish_query(
315+
&StrQueryHelper::new(&helper.sql),
316+
r.as_ref().err(),
317+
));
319318
r
320319
}
321320
.boxed()
@@ -370,9 +369,9 @@ impl AsyncMysqlConnection {
370369

371370
Ok(AsyncMysqlConnection {
372371
conn,
373-
stmt_cache: StmtCache::new(),
372+
stmt_cache: StatementCache::new(),
374373
transaction_manager: AnsiTransactionManager::default(),
375-
instrumentation: std::sync::Mutex::new(None),
374+
instrumentation: DynInstrumentation::none(),
376375
})
377376
}
378377
}
@@ -427,3 +426,13 @@ mod tests {
427426
}
428427
}
429428
}
429+
430+
impl QueryFragmentForCachedStatement<Mysql> for QueryFragmentHelper {
431+
fn construct_sql(&self, _backend: &Mysql) -> QueryResult<String> {
432+
Ok(self.sql.clone())
433+
}
434+
435+
fn is_safe_to_cache_prepared(&self, _backend: &Mysql) -> QueryResult<bool> {
436+
Ok(self.safe_to_cache)
437+
}
438+
}

‎src/pg/mod.rs

+77-44
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
use self::error_helper::ErrorHelper;
88
use self::row::PgRow;
99
use self::serialize::ToSqlHelper;
10-
use crate::stmt_cache::{PrepareCallback, StmtCache};
10+
use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
1111
use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection};
12-
use diesel::connection::statement_cache::{PrepareForCache, StatementCacheKey};
13-
use diesel::connection::Instrumentation;
14-
use diesel::connection::InstrumentationEvent;
12+
use diesel::connection::statement_cache::{
13+
PrepareForCache, QueryFragmentForCachedStatement, StatementCache,
14+
};
1515
use diesel::connection::StrQueryHelper;
16+
use diesel::connection::{CacheSize, Instrumentation};
17+
use diesel::connection::{DynInstrumentation, InstrumentationEvent};
1618
use diesel::pg::{
1719
Pg, PgMetadataCache, PgMetadataCacheKey, PgMetadataLookup, PgQueryBuilder, PgTypeMetadata,
1820
};
@@ -122,13 +124,13 @@ const FAKE_OID: u32 = 0;
122124
/// [tokio_postgres_rustls]: https://docs.rs/tokio-postgres-rustls/0.12.0/tokio_postgres_rustls/
123125
pub struct AsyncPgConnection {
124126
conn: Arc<tokio_postgres::Client>,
125-
stmt_cache: Arc<Mutex<StmtCache<diesel::pg::Pg, Statement>>>,
127+
stmt_cache: Arc<Mutex<StatementCache<diesel::pg::Pg, Statement>>>,
126128
transaction_state: Arc<Mutex<AnsiTransactionManager>>,
127129
metadata_cache: Arc<Mutex<PgMetadataCache>>,
128130
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
129131
shutdown_channel: Option<oneshot::Sender<()>>,
130132
// a sync mutex is fine here as we only hold it for a really short time
131-
instrumentation: Arc<std::sync::Mutex<Option<Box<dyn Instrumentation>>>>,
133+
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
132134
}
133135

134136
#[async_trait::async_trait]
@@ -162,7 +164,7 @@ impl AsyncConnection for AsyncPgConnection {
162164
type TransactionManager = AnsiTransactionManager;
163165

164166
async fn establish(database_url: &str) -> ConnectionResult<Self> {
165-
let mut instrumentation = diesel::connection::get_default_instrumentation();
167+
let mut instrumentation = DynInstrumentation::default_instrumentation();
166168
instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
167169
database_url,
168170
));
@@ -229,14 +231,25 @@ impl AsyncConnection for AsyncPgConnection {
229231
// that means there is only one instance of this arc and
230232
// we can simply access the inner data
231233
if let Some(instrumentation) = Arc::get_mut(&mut self.instrumentation) {
232-
instrumentation.get_mut().unwrap_or_else(|p| p.into_inner())
234+
&mut **(instrumentation.get_mut().unwrap_or_else(|p| p.into_inner()))
233235
} else {
234236
panic!("Cannot access shared instrumentation")
235237
}
236238
}
237239

238240
fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
239-
self.instrumentation = Arc::new(std::sync::Mutex::new(Some(Box::new(instrumentation))));
241+
self.instrumentation = Arc::new(std::sync::Mutex::new(instrumentation.into()));
242+
}
243+
244+
fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
245+
// there should be no other pending future when this is called
246+
// that means there is only one instance of this arc and
247+
// we can simply access the inner data
248+
if let Some(cache) = Arc::get_mut(&mut self.stmt_cache) {
249+
cache.get_mut().set_cache_size(size)
250+
} else {
251+
panic!("Cannot access shared statement cache")
252+
}
240253
}
241254
}
242255

@@ -293,25 +306,33 @@ fn update_transaction_manager_status<T>(
293306
query_result
294307
}
295308

296-
#[async_trait::async_trait]
297-
impl PrepareCallback<Statement, PgTypeMetadata> for Arc<tokio_postgres::Client> {
298-
async fn prepare(
299-
self,
300-
sql: &str,
301-
metadata: &[PgTypeMetadata],
302-
_is_for_cache: PrepareForCache,
303-
) -> QueryResult<(Statement, Self)> {
304-
let bind_types = metadata
305-
.iter()
306-
.map(type_from_oid)
307-
.collect::<QueryResult<Vec<_>>>()?;
308-
309-
let stmt = self
310-
.prepare_typed(sql, &bind_types)
309+
fn prepare_statement_helper<'a>(
310+
conn: Arc<tokio_postgres::Client>,
311+
sql: &'a str,
312+
_is_for_cache: PrepareForCache,
313+
metadata: &[PgTypeMetadata],
314+
) -> CallbackHelper<
315+
impl Future<Output = QueryResult<(Statement, Arc<tokio_postgres::Client>)>> + Send,
316+
> {
317+
let bind_types = metadata
318+
.iter()
319+
.map(type_from_oid)
320+
.collect::<QueryResult<Vec<_>>>();
321+
// ideally we wouldn't clone the SQL string here
322+
// but as we usually cache statements anyway
323+
// this is a fixed one time const
324+
//
325+
// The probleme with not cloning it is that we then cannot express
326+
// the right result lifetime anymore (at least not easily)
327+
let sql = sql.to_string();
328+
CallbackHelper(async move {
329+
let bind_types = bind_types?;
330+
let stmt = conn
331+
.prepare_typed(&sql, &bind_types)
311332
.await
312333
.map_err(ErrorHelper);
313-
Ok((stmt?, self))
314-
}
334+
Ok((stmt?, conn))
335+
})
315336
}
316337

317338
fn type_from_oid(t: &PgTypeMetadata) -> QueryResult<Type> {
@@ -369,7 +390,7 @@ impl AsyncPgConnection {
369390
None,
370391
None,
371392
Arc::new(std::sync::Mutex::new(
372-
diesel::connection::get_default_instrumentation(),
393+
DynInstrumentation::default_instrumentation(),
373394
)),
374395
)
375396
.await
@@ -390,9 +411,7 @@ impl AsyncPgConnection {
390411
client,
391412
Some(error_rx),
392413
Some(shutdown_tx),
393-
Arc::new(std::sync::Mutex::new(
394-
diesel::connection::get_default_instrumentation(),
395-
)),
414+
Arc::new(std::sync::Mutex::new(DynInstrumentation::none())),
396415
)
397416
.await
398417
}
@@ -401,11 +420,11 @@ impl AsyncPgConnection {
401420
conn: tokio_postgres::Client,
402421
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
403422
shutdown_channel: Option<oneshot::Sender<()>>,
404-
instrumentation: Arc<std::sync::Mutex<Option<Box<dyn Instrumentation>>>>,
423+
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
405424
) -> ConnectionResult<Self> {
406425
let mut conn = Self {
407426
conn: Arc::new(conn),
408-
stmt_cache: Arc::new(Mutex::new(StmtCache::new())),
427+
stmt_cache: Arc::new(Mutex::new(StatementCache::new())),
409428
transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
410429
metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
411430
connection_future,
@@ -559,23 +578,27 @@ impl AsyncPgConnection {
559578
})?;
560579
}
561580
}
562-
let key = match query_id {
563-
Some(id) => StatementCacheKey::Type(id),
564-
None => StatementCacheKey::Sql {
565-
sql: sql.clone(),
566-
bind_types: bind_collector.metadata.clone(),
567-
},
568-
};
569581
let stmt = {
570582
let mut stmt_cache = stmt_cache.lock().await;
583+
let helper = QueryFragmentHelper {
584+
sql: sql.clone(),
585+
safe_to_cache: is_safe_to_cache_prepared,
586+
};
587+
let instrumentation = Arc::clone(&instrumentation);
571588
stmt_cache
572-
.cached_prepared_statement(
573-
key,
574-
sql.clone(),
575-
is_safe_to_cache_prepared,
589+
.cached_statement_non_generic(
590+
query_id,
591+
&helper,
592+
&Pg,
576593
&bind_collector.metadata,
577594
raw_connection.clone(),
578-
&instrumentation
595+
prepare_statement_helper,
596+
&mut move |event: InstrumentationEvent<'_>| {
597+
// we wrap this lock into another callback to prevent locking
598+
// the instrumentation longer than necessary
599+
instrumentation.lock().unwrap_or_else(|e| e.into_inner())
600+
.on_connection_event(event);
601+
},
579602
)
580603
.await?
581604
.0
@@ -894,6 +917,16 @@ impl crate::pooled_connection::PoolableConnection for AsyncPgConnection {
894917
}
895918
}
896919

920+
impl QueryFragmentForCachedStatement<Pg> for QueryFragmentHelper {
921+
fn construct_sql(&self, _backend: &Pg) -> QueryResult<String> {
922+
Ok(self.sql.clone())
923+
}
924+
925+
fn is_safe_to_cache_prepared(&self, _backend: &Pg) -> QueryResult<bool> {
926+
Ok(self.safe_to_cache)
927+
}
928+
}
929+
897930
#[cfg(test)]
898931
mod tests {
899932
use super::*;

‎src/pooled_connection/mod.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
use crate::{AsyncConnection, SimpleAsyncConnection};
99
use crate::{TransactionManager, UpdateAndFetchResults};
1010
use diesel::associations::HasTable;
11-
use diesel::connection::Instrumentation;
11+
use diesel::connection::{CacheSize, Instrumentation};
1212
use diesel::QueryResult;
1313
use futures_util::{future, FutureExt};
1414
use std::borrow::Cow;
@@ -241,6 +241,10 @@ where
241241
fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
242242
self.deref_mut().set_instrumentation(instrumentation);
243243
}
244+
245+
fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
246+
self.deref_mut().set_prepared_statement_cache_size(size);
247+
}
244248
}
245249

246250
#[doc(hidden)]

‎src/stmt_cache.rs

+43-77
Original file line numberDiff line numberDiff line change
@@ -1,91 +1,57 @@
1-
use std::collections::HashMap;
2-
use std::hash::Hash;
3-
4-
use diesel::backend::Backend;
5-
use diesel::connection::statement_cache::{MaybeCached, PrepareForCache, StatementCacheKey};
6-
use diesel::connection::Instrumentation;
7-
use diesel::connection::InstrumentationEvent;
1+
use diesel::connection::statement_cache::{MaybeCached, StatementCallbackReturnType};
82
use diesel::QueryResult;
9-
use futures_util::{future, FutureExt};
3+
use futures_util::{future, FutureExt, TryFutureExt};
4+
use std::future::Future;
105

11-
#[derive(Default)]
12-
pub struct StmtCache<DB: Backend, S> {
13-
cache: HashMap<StatementCacheKey<DB>, S>,
14-
}
6+
pub(crate) struct CallbackHelper<F>(pub(crate) F);
157

16-
type PrepareFuture<'a, F, S> = future::Either<
17-
future::Ready<QueryResult<(MaybeCached<'a, S>, F)>>,
18-
future::BoxFuture<'a, QueryResult<(MaybeCached<'a, S>, F)>>,
8+
type PrepareFuture<'a, C, S> = future::Either<
9+
future::Ready<QueryResult<(MaybeCached<'a, S>, C)>>,
10+
future::BoxFuture<'a, QueryResult<(MaybeCached<'a, S>, C)>>,
1911
>;
2012

21-
#[async_trait::async_trait]
22-
pub trait PrepareCallback<S, M>: Sized {
23-
async fn prepare(
24-
self,
25-
sql: &str,
26-
metadata: &[M],
27-
is_for_cache: PrepareForCache,
28-
) -> QueryResult<(S, Self)>;
29-
}
13+
impl<'b, S, F, C> StatementCallbackReturnType<S, C> for CallbackHelper<F>
14+
where
15+
F: Future<Output = QueryResult<(S, C)>> + Send + 'b,
16+
S: 'static,
17+
{
18+
type Return<'a> = PrepareFuture<'a, C, S>;
3019

31-
impl<S, DB: Backend> StmtCache<DB, S> {
32-
pub fn new() -> Self {
33-
Self {
34-
cache: HashMap::new(),
35-
}
20+
fn from_error<'a>(e: diesel::result::Error) -> Self::Return<'a> {
21+
future::Either::Left(future::ready(Err(e)))
3622
}
3723

38-
pub fn cached_prepared_statement<'a, F>(
39-
&'a mut self,
40-
cache_key: StatementCacheKey<DB>,
41-
sql: String,
42-
is_query_safe_to_cache: bool,
43-
metadata: &[DB::TypeMetadata],
44-
prepare_fn: F,
45-
instrumentation: &std::sync::Mutex<Option<Box<dyn Instrumentation>>>,
46-
) -> PrepareFuture<'a, F, S>
24+
fn map_to_no_cache<'a>(self) -> Self::Return<'a>
4725
where
48-
S: Send,
49-
DB::QueryBuilder: Default,
50-
DB::TypeMetadata: Clone + Send + Sync,
51-
F: PrepareCallback<S, DB::TypeMetadata> + Send + 'a,
52-
StatementCacheKey<DB>: Hash + Eq,
26+
Self: 'a,
5327
{
54-
use std::collections::hash_map::Entry::{Occupied, Vacant};
55-
56-
if !is_query_safe_to_cache {
57-
let metadata = metadata.to_vec();
58-
let f = async move {
59-
let stmt = prepare_fn
60-
.prepare(&sql, &metadata, PrepareForCache::No)
61-
.await?;
62-
Ok((MaybeCached::CannotCache(stmt.0), stmt.1))
63-
}
64-
.boxed();
65-
return future::Either::Right(f);
66-
}
28+
future::Either::Right(
29+
self.0
30+
.map_ok(|(stmt, conn)| (MaybeCached::CannotCache(stmt), conn))
31+
.boxed(),
32+
)
33+
}
6734

68-
match self.cache.entry(cache_key) {
69-
Occupied(entry) => future::Either::Left(future::ready(Ok((
70-
MaybeCached::Cached(entry.into_mut()),
71-
prepare_fn,
72-
)))),
73-
Vacant(entry) => {
74-
let metadata = metadata.to_vec();
75-
instrumentation
76-
.lock()
77-
.unwrap_or_else(|p| p.into_inner())
78-
.on_connection_event(InstrumentationEvent::cache_query(&sql));
79-
let f = async move {
80-
let statement = prepare_fn
81-
.prepare(&sql, &metadata, PrepareForCache::Yes)
82-
.await?;
35+
fn map_to_cache<'a>(stmt: &'a mut S, conn: C) -> Self::Return<'a> {
36+
future::Either::Left(future::ready(Ok((MaybeCached::Cached(stmt), conn))))
37+
}
8338

84-
Ok((MaybeCached::Cached(entry.insert(statement.0)), statement.1))
85-
}
86-
.boxed();
87-
future::Either::Right(f)
88-
}
89-
}
39+
fn register_cache<'a>(
40+
self,
41+
callback: impl FnOnce(S) -> &'a mut S + Send + 'a,
42+
) -> Self::Return<'a>
43+
where
44+
Self: 'a,
45+
{
46+
future::Either::Right(
47+
self.0
48+
.map_ok(|(stmt, conn)| (MaybeCached::Cached(callback(stmt)), conn))
49+
.boxed(),
50+
)
9051
}
9152
}
53+
54+
pub(crate) struct QueryFragmentHelper {
55+
pub(crate) sql: String,
56+
pub(crate) safe_to_cache: bool,
57+
}

‎src/sync_connection_wrapper/mod.rs

+15-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
1010
use crate::{AsyncConnection, SimpleAsyncConnection, TransactionManager};
1111
use diesel::backend::{Backend, DieselReserveSpecialization};
12-
use diesel::connection::Instrumentation;
12+
use diesel::connection::{CacheSize, Instrumentation};
1313
use diesel::connection::{
1414
Connection, LoadConnection, TransactionManagerStatus, WithMetadataLookup,
1515
};
@@ -188,6 +188,20 @@ where
188188
panic!("Cannot access shared instrumentation")
189189
}
190190
}
191+
192+
fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
193+
// there should be no other pending future when this is called
194+
// that means there is only one instance of this arc and
195+
// we can simply access the inner data
196+
if let Some(inner) = Arc::get_mut(&mut self.inner) {
197+
inner
198+
.get_mut()
199+
.unwrap_or_else(|p| p.into_inner())
200+
.set_prepared_statement_cache_size(size)
201+
} else {
202+
panic!("Cannot access shared cache")
203+
}
204+
}
191205
}
192206

193207
/// A wrapper of a diesel transaction manager usable in async context.

0 commit comments

Comments
 (0)
Please sign in to comment.