diff --git a/src/webserver/database/connect.rs b/src/webserver/database/connect.rs index 39f2035d..df150779 100644 --- a/src/webserver/database/connect.rs +++ b/src/webserver/database/connect.rs @@ -1,11 +1,12 @@ -use std::time::Duration; +use std::{sync::Arc, time::Duration}; use super::Database; use crate::{app_config::AppConfig, ON_CONNECT_FILE}; +use futures_util::future::BoxFuture; use sqlx::{ any::{Any, AnyConnectOptions, AnyKind}, pool::PoolOptions, - ConnectOptions, Executor, + sqlite, AnyConnection, ConnectOptions, Executor, }; impl Database { @@ -94,31 +95,50 @@ impl Database { } fn add_on_connection_handler(pool_options: PoolOptions) -> PoolOptions { + let on_connect_sql = on_connection_handler_file(); + pool_options.after_connect(move |conn, _metadata| { + let on_connect_sql = on_connect_sql.clone(); + Box::pin(async move { + if let Some(sql) = (&on_connect_sql).as_ref() { + log::debug!("Running {ON_CONNECT_FILE:?} on new connection"); + let sql = std::sync::Arc::clone(&sql); + let r = conn.execute(sql.as_str()).await?; + log::debug!("Finished running connection handler on new connection: {r:?}"); + } + if let sqlx::any::AnyConnectionKind::Sqlite(sqlite_conn) = conn.private_get_mut() { + sqlite_on_connection_handler(sqlite_conn).await?; + } + Ok(()) + }) + }) +} + +async fn sqlite_on_connection_handler( + sqlite_conn: &mut sqlx::sqlite::SqliteConnection, +) -> sqlx::Result<()> { + let handle = sqlite_conn.lock_handle().await?; + log::warn!("TODO: bind sqlite functions"); + Ok(()) +} + +fn on_connection_handler_file() -> Option> { let on_connect_file = std::env::current_dir() .unwrap_or_default() .join(ON_CONNECT_FILE); if !on_connect_file.exists() { log::debug!("Not creating a custom SQL database connection handler because {on_connect_file:?} does not exist"); - return pool_options; + return None; } log::info!("Creating a custom SQL database connection handler from {on_connect_file:?}"); let sql = match std::fs::read_to_string(&on_connect_file) { Ok(sql) => std::sync::Arc::new(sql), Err(e) => { log::error!("Unable to read the file {on_connect_file:?}: {e}"); - return pool_options; + return None; } }; log::trace!("The custom SQL database connection handler is:\n{sql}"); - pool_options.after_connect(move |conn, _metadata| { - log::debug!("Running {on_connect_file:?} on new connection"); - let sql = std::sync::Arc::clone(&sql); - Box::pin(async move { - let r = conn.execute(sql.as_str()).await?; - log::debug!("Finished running connection handler on new connection: {r:?}"); - Ok(()) - }) - }) + Some(sql) } fn set_custom_connect_options(options: &mut AnyConnectOptions, config: &AppConfig) { @@ -127,5 +147,12 @@ fn set_custom_connect_options(options: &mut AnyConnectOptions, config: &AppConfi log::info!("Loading SQLite extension: {}", extension_name); *sqlite_options = std::mem::take(sqlite_options).extension(extension_name.clone()); } + *sqlite_options = std::mem::take(sqlite_options).collation("NOCASE", sqlite_collate_nocase); + *sqlite_options = + std::mem::take(sqlite_options).thread_name(|i| format!("sqlpage_sqlite_{}", i)); } } + +fn sqlite_collate_nocase(a: &str, b: &str) -> std::cmp::Ordering { + a.to_lowercase().cmp(&b.to_lowercase()) +}