From 27d0c4f4ca577f5f0c82518a09e53e2b9f04731a Mon Sep 17 00:00:00 2001 From: Edwin Kys Date: Thu, 1 Aug 2024 17:42:29 -0500 Subject: [PATCH] feat: add tokio for psql connection --- Cargo.lock | 93 ++++++++++++++++++++++++++++++++++++++ Cargo.toml | 3 +- examples/measure_recall.rs | 5 +- src/db/database.rs | 22 ++++++--- 4 files changed, 113 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d975a0a2..6f757ba0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "addr2line" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" +dependencies = [ + "gimli", +] + [[package]] name = "adler" version = "1.0.2" @@ -42,6 +51,21 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +[[package]] +name = "backtrace" +version = "0.3.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + [[package]] name = "base64" version = "0.21.7" @@ -470,6 +494,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "gimli" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" + [[package]] name = "half" version = "2.4.1" @@ -509,6 +539,12 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + [[package]] name = "hex" version = "0.4.3" @@ -665,6 +701,18 @@ dependencies = [ "adler", ] +[[package]] +name = "mio" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" +dependencies = [ + "hermit-abi", + "libc", + "wasi", + "windows-sys 0.52.0", +] + [[package]] name = "nom" version = "7.1.3" @@ -739,10 +787,20 @@ dependencies = [ "simsimd", "sqlx", "tar", + "tokio", "url", "uuid", ] +[[package]] +name = "object" +version = "0.36.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "081b846d1d56ddfc18fdf1a922e4f6e07a11768ea1b92dec44e42b72712ccfce" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -962,6 +1020,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + [[package]] name = "rustix" version = "0.38.34" @@ -1168,6 +1232,8 @@ dependencies = [ "smallvec", "sqlformat", "thiserror", + "tokio", + "tokio-stream", "tracing", "url", ] @@ -1207,6 +1273,7 @@ dependencies = [ "sqlx-sqlite", "syn 1.0.109", "tempfile", + "tokio", "url", ] @@ -1410,6 +1477,32 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokio" +version = "1.39.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "pin-project-lite", + "socket2", + "windows-sys 0.52.0", +] + +[[package]] +name = "tokio-stream" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tracing" version = "0.1.40" diff --git a/Cargo.toml b/Cargo.toml index 91ba5820..43c9b6b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ categories = ["database", "algorithms", "embedded"] [dependencies] uuid = { version = "1.9.1", features = ["v4", "fast-rng", "serde"] } half = { version = "2.4.1", features = ["serde"] } +tokio = { version = "1.39.2", features = ["rt-multi-thread"] } url = "2.5.2" futures = "0.3.30" rand = "0.8.5" @@ -34,7 +35,7 @@ serde_json = "1.0.120" [dependencies.sqlx] version = "0.7.4" default-features = false -features = ["all-databases"] +features = ["all-databases", "runtime-tokio"] [dev-dependencies] byteorder = "1.5.0" diff --git a/examples/measure_recall.rs b/examples/measure_recall.rs index ae5cc042..faa45a80 100644 --- a/examples/measure_recall.rs +++ b/examples/measure_recall.rs @@ -1,7 +1,7 @@ use common::Dataset; -use futures::executor; use oasysdb::prelude::*; use std::error::Error; +use tokio::runtime::Runtime; mod common; @@ -10,7 +10,8 @@ fn main() -> Result<(), Box> { let db_url = dataset.database_url(); let config = SourceConfig::new(dataset.name(), "id", "vector"); - executor::block_on(dataset.populate_database())?; + let rt = Runtime::new()?; + rt.block_on(dataset.populate_database())?; let db = Database::open("odb_example", Some(db_url))?; create_index_flat(&db, &config)?; diff --git a/src/db/database.rs b/src/db/database.rs index d8f53edc..bcadea43 100644 --- a/src/db/database.rs +++ b/src/db/database.rs @@ -1,9 +1,9 @@ use super::*; -use futures::executor; use futures::stream::StreamExt; use sqlx::any::install_default_drivers; use sqlx::Acquire; use std::sync::{Arc, Mutex}; +use tokio::runtime::Runtime; use url::Url; use uuid::Uuid; @@ -159,7 +159,8 @@ impl Database { algorithm: IndexAlgorithm, config: SourceConfig, ) -> Result<(), Error> { - executor::block_on(self.async_create_index(name, algorithm, config)) + let rt = Runtime::new()?; + rt.block_on(self.async_create_index(name, algorithm, config)) } /// Returns an index reference. @@ -263,7 +264,8 @@ impl Database { /// Updates the index with new records from the source synchronously. /// - `name`: Index name. pub fn refresh_index(&self, name: impl AsRef) -> Result<(), Error> { - executor::block_on(self.async_refresh_index(name)) + let rt = Runtime::new()?; + rt.block_on(self.async_refresh_index(name)) } /// Searches the index for nearest neighbors. @@ -417,7 +419,8 @@ impl DatabaseState { /// Connects to the source SQL database. pub fn connect(&self) -> Result { - executor::block_on(self.async_connect()) + let rt = Runtime::new()?; + rt.block_on(self.async_connect()) } /// Disconnects from the source SQL database asynchronously. @@ -429,7 +432,8 @@ impl DatabaseState { /// Disconnects from the source SQL database. /// - `conn`: Database connection. pub fn disconnect(conn: SourceConnection) -> Result<(), Error> { - executor::block_on(Self::async_disconnect(conn)) + let rt = Runtime::new()?; + rt.block_on(Self::async_disconnect(conn)) } /// Validates the connection to the source database. @@ -539,7 +543,9 @@ mod tests { fn test_database_refresh_index() -> Result<(), Error> { let db = create_test_database()?; let query = generate_insert_query(100, 10); - executor::block_on(db.async_execute_sql(query))?; + + let rt = Runtime::new()?; + rt.block_on(db.async_execute_sql(query))?; db.refresh_index(TEST_INDEX).unwrap(); @@ -616,7 +622,9 @@ mod tests { let state = db.state()?; assert_eq!(state.source_type(), SourceType::SQLITE); - executor::block_on(setup_test_source(&db_url))?; + let rt = Runtime::new()?; + rt.block_on(setup_test_source(&db_url))?; + create_test_index(&mut db)?; Ok(db) }