diff --git a/edb/server/conn_pool/src/python.rs b/edb/server/conn_pool/src/python.rs index 8377b1779b95..c84ff1ada154 100644 --- a/edb/server/conn_pool/src/python.rs +++ b/edb/server/conn_pool/src/python.rs @@ -1,6 +1,380 @@ -use pyo3::{pymodule, types::PyModule, PyResult, Python}; +use crate::{ + conn::{ConnError, Connector}, + pool::{Pool, PoolConfig}, + PoolHandle, +}; +use futures::TryFutureExt; +use pyo3::{ + exceptions::PyException, + prelude::*, + types::{PyDict, PyTuple}, +}; +use std::{ + cell::RefCell, + collections::HashMap, + rc::Rc, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, RwLock, + }, + time::Duration, +}; +use tokio::task::LocalSet; +use tracing::{error, trace}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +pyo3::create_exception!(_conn_pool, InternalError, PyException); + +#[derive(Debug)] +#[repr(u8)] +enum ConnectOp { + Connect, + Disconnect, + Reconnect, +} + +#[derive(Debug, Default)] +struct PythonConnectionMap { + /// Connection : [`PoolHandle`] (to keep the handle alive) + handle: HashMap>, + py_dict: Option>, + next_id: usize, +} + +impl PythonConnectionMap { + pub fn insert(&mut self, py: Python, handle: PoolHandle) { + let py_dict = self + .py_dict + .get_or_insert_with(|| PyDict::new(py).into()) + .as_ref(py); + _ = handle.with_handle(|conn| py_dict.set_item(conn, self.next_id)); + self.handle.insert(self.next_id, handle); + self.next_id += 1; + } + + pub fn remove( + &mut self, + py: Python, + conn: PyObject, + ) -> Option> { + let Some(py_dict) = &mut self.py_dict else { + return None; + }; + let py_dict = py_dict.as_ref(py); + let item = py_dict.get_item(conn.clone_ref(py)).ok()??; + _ = py_dict.del_item(conn); + let key = item.extract::().ok()?; + self.handle.remove(&key) + } +} + +/// Implementation of the [`Connector`] interface. We don't pass the pool or Python objects +/// between threads, but rather use a usize ID that allows us to keep two maps in sync on +/// both sides of this interface. +#[derive(Debug)] +struct PythonConnectionFactory { + /// The _callback method that triggers the correctly-threaded task for the + /// connection operation. + callback: PyObject, + /// RPC callbacks. + responses: Arc>>>, + /// Next RPC ID. + next_response_id: Arc, +} + +impl PythonConnectionFactory { + fn send( + &self, + op: ConnectOp, + args: impl IntoPy>, + ) -> impl futures::Future> + 'static { + let (sender, receiver) = tokio::sync::oneshot::channel::(); + let response_id = self.next_response_id.fetch_add(1, Ordering::SeqCst); + self.responses.write().unwrap().insert(response_id, sender); + let success = Python::with_gil(|py| { + let args0: Py = (op as u8, response_id).into_py(py); + let args = args.into_py(py); + + let Ok(result) = self.callback.call(py, (args0, args), None) else { + error!("Unexpected failure in _callback"); + return false; + }; + let Ok(result) = result.is_true(py) else { + error!("Unexpected return value from _callback"); + return false; + }; + if !result { + return false; + } + true + }); + async move { + if success { + let conn = receiver.await.unwrap(); + let conn = Python::with_gil(|py| conn.to_object(py)); + trace!("Thread received {response_id} {}", conn); + Ok(conn) + } else { + Err(ConnError::Shutdown) + } + } + } +} + +impl Connector for PythonConnectionFactory { + type Conn = PyObject; + + fn connect( + &self, + db: &str, + ) -> impl futures::Future> + 'static { + self.send(ConnectOp::Connect, (db,)) + } + + fn disconnect( + &self, + conn: Self::Conn, + ) -> impl futures::Future> + 'static { + self.send(ConnectOp::Disconnect, (conn,)).map_ok(|_| ()) + } + + fn reconnect( + &self, + conn: Self::Conn, + db: &str, + ) -> impl futures::Future> + 'static { + self.send(ConnectOp::Reconnect, (conn, db)) + } +} + +impl PythonConnectionFactory { + fn new(callback: PyObject) -> Self { + Self { + callback, + responses: Default::default(), + next_response_id: Default::default(), + } + } +} + +#[derive(Debug)] +enum PoolRPC { + Acquire(String, PyObject), + Release(PyObject, bool), +} + +#[pyclass] +struct ConnPool { + connector: RwLock>, + responses: Arc>>>, + rpc_tx: RwLock>>, +} + +fn internal_error(py: Python, message: &str) { + error!("{message}"); + InternalError::new_err(()).restore(py); +} + +async fn run_and_block( + connector: PythonConnectionFactory, + mut rpc_rx: tokio::sync::mpsc::UnboundedReceiver, +) { + let pool = Rc::new(Pool::::new( + PoolConfig::suggested_default_for(100), + connector, + )); + let conns = Rc::new(RefCell::new(PythonConnectionMap::default())); + + let pool_task = { + let pool = pool.clone(); + tokio::task::spawn_local(async move { + loop { + pool.run_once(); + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + }; + + loop { + let Some(rpc) = rpc_rx.recv().await else { + pool_task.abort(); + break; + }; + let pool = pool.clone(); + let conns = conns.clone(); + trace!("Received RPC: {rpc:?}"); + tokio::task::spawn_local(async move { + match rpc { + PoolRPC::Acquire(db, callback) => { + let conn = pool.acquire(&db).await.unwrap(); + trace!("Acquired a handle to return to Python!"); + Python::with_gil(|py| { + let handle = conn.handle_clone(); + conns.borrow_mut().insert(py, conn); + callback.call1(py, (handle,)).unwrap(); + }); + } + PoolRPC::Release(conn, dispose) => { + Python::with_gil(|py| { + let Some(conn) = conns.borrow_mut().remove(py, conn) else { + error!("Attempted to dispose a connection that does not exist"); + return; + }; + + if dispose { + conn.poison(); + } + + drop(conn); + }); + } + } + }); + } +} + +#[pymethods] +impl ConnPool { + #[new] + fn new(callback: PyObject) -> Self { + let connector = PythonConnectionFactory::new(callback); + let responses = connector.responses.clone(); + ConnPool { + connector: RwLock::new(Some(connector)), + responses, + rpc_tx: Default::default(), + } + } + + fn _respond(&self, py: Python, response_id: usize, object: PyObject) { + trace!("_respond({response_id}, {object})"); + let response = self.responses.write().unwrap().remove(&response_id); + if let Some(response) = response { + response.send(object).unwrap(); + } else { + internal_error(py, "Missing response sender"); + } + } + + fn halt(&self, _py: Python) { + self.rpc_tx.write().unwrap().take(); + } + + /// Asynchronously acquires a connection, returning it to the callback + fn acquire(&self, db: &str, callback: PyObject) { + self.rpc_tx + .read() + .unwrap() + .as_ref() + .unwrap() + .send(PoolRPC::Acquire(db.to_owned(), callback)) + .unwrap(); + } + + /// Releases a connection when possible, potentially discarding it + fn release(&self, conn: PyObject, discard: bool) { + self.rpc_tx + .read() + .unwrap() + .as_ref() + .unwrap() + .send(PoolRPC::Release(conn, discard)) + .unwrap(); + } + + /// Boot the connection pool on this thread. + fn run_and_block(&self, py: Python) { + let connector = self.connector.write().unwrap().take().unwrap(); + let (rpc_tx, rpc_rx) = tokio::sync::mpsc::unbounded_channel(); + *self.rpc_tx.write().unwrap() = Some(rpc_tx); + py.allow_threads(move || { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_time() + .build() + .unwrap(); + let local = LocalSet::new(); + local.block_on(&rt, run_and_block(connector, rpc_rx)); + }) + } +} #[pymodule] fn _conn_pool(py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add("InternalError", py.get_type::())?; + + let logging = py.import("logging")?; + let logger = logging + .getattr("getLogger")? + .call(("edb.server.connpool",), None)?; + let level = logger + .getattr("getEffectiveLevel")? + .call((), None)? + .extract::()?; + let logger = logger.to_object(py); + + struct PythonSubscriber { + logger: Py, + } + + impl tracing_subscriber::Layer for PythonSubscriber { + fn on_event(&self, event: &tracing::Event, _ctx: tracing_subscriber::layer::Context) { + let mut message = format!("[{}] ", event.metadata().target()); + + #[derive(Default)] + struct Visitor(String); + impl tracing::field::Visit for Visitor { + fn record_debug( + &mut self, + field: &tracing::field::Field, + value: &dyn std::fmt::Debug, + ) { + if field.name() == "message" { + self.0 += &format!("{value:?} "); + } else { + self.0 += &format!("{}={:?} ", field.name(), value) + } + } + } + + let mut visitor = Visitor::default(); + event.record(&mut visitor); + message += &visitor.0; + + Python::with_gil(|py| { + let log = match *event.metadata().level() { + tracing::Level::TRACE => self.logger.getattr(py, "debug").unwrap(), + tracing::Level::DEBUG => self.logger.getattr(py, "warning").unwrap(), + tracing::Level::INFO => self.logger.getattr(py, "info").unwrap(), + tracing::Level::WARN => self.logger.getattr(py, "warning").unwrap(), + tracing::Level::ERROR => self.logger.getattr(py, "error").unwrap(), + }; + log.call1(py, (message,)).unwrap(); + }) + } + } + + let level = if level < 10 { + tracing_subscriber::filter::LevelFilter::TRACE + } else if level <= 10 { + tracing_subscriber::filter::LevelFilter::DEBUG + } else if level <= 20 { + tracing_subscriber::filter::LevelFilter::INFO + } else if level <= 30 { + tracing_subscriber::filter::LevelFilter::WARN + } else if level <= 40 { + tracing_subscriber::filter::LevelFilter::ERROR + } else { + tracing_subscriber::filter::LevelFilter::OFF + }; + + let subscriber = PythonSubscriber { logger }; + tracing_subscriber::registry() + .with(level) + .with(subscriber) + .init(); + + tracing::info!("ConnPool initialized (level = {level})"); + Ok(()) } diff --git a/edb/server/connpool/pool.py b/edb/server/connpool/pool.py index 0f129948f5f2..821f8d11677a 100644 --- a/edb/server/connpool/pool.py +++ b/edb/server/connpool/pool.py @@ -419,6 +419,16 @@ def failed_connects(self) -> int: def failed_disconnects(self) -> int: return self._failed_disconnects + + async def __aenter__(self) -> typing.Self: + return self + + async def __aexit__(self, + exc_type: typing.Optional[type], + exc_val: typing.Optional[BaseException], + exc_tb: typing.Optional[typing.Any]) -> None: + pass + def get_pending_conns(self) -> int: return sum( block.count_pending_conns() for block in self._blocks.values() diff --git a/edb/server/connpool/pool2.py b/edb/server/connpool/pool2.py index d4f994ac4488..8b4dee5b0675 100644 --- a/edb/server/connpool/pool2.py +++ b/edb/server/connpool/pool2.py @@ -1,5 +1,275 @@ -import edb.server._conn_pool # noqa: F401 +import edb.server._conn_pool +import asyncio +import threading +import typing +import dataclasses +# Connections must be hashable because we use them to reverse-lookup +# an internal ID. +C = typing.TypeVar("C", bound=typing.Hashable) -class Pool: - pass +CP1 = typing.TypeVar('CP1', covariant=True) +CP2 = typing.TypeVar('CP2', contravariant=True) + + +class Connector(typing.Protocol[CP1]): + + def __call__(self, dbname: str) -> typing.Awaitable[CP1]: + pass + + +class Disconnector(typing.Protocol[CP2]): + + def __call__(self, conn: CP2) -> typing.Awaitable[None]: + pass + + +@dataclasses.dataclass +class Snapshot: + timestamp: float + capacity: int + + failed_connects: int + failed_disconnects: int + successful_connects: int + successful_disconnects: int + + +class StatsCollector(typing.Protocol): + + def __call__(self, stats: Snapshot) -> None: + pass + + +class ConnectionFactory(typing.Protocol[C]): + """The async interface to create and destroy database connections. + + All connections returned from successful calls to `connect` or reconnect + are guaranteed to be `disconnect`ed or `reconnect`ed.""" + + async def connect(self, db: str) -> C: + """Create a new connection asynchronously. + + This method must retry exceptions internally. If an exception is thrown + from this method, the database is considered to be failed.""" + ... + + async def disconnect(self, conn: C) -> None: + """Gracefully disconnect a connection asynchronously. + + If an exception is thrown from this method, the connection is simply + forgotten.""" + ... + + async def reconnect(self, conn: C, db: str) -> C: + """Reconnects a connection to the given database. If this is not + possible, it is permissable to return a new connection and gracefully + disconnect the other connection in parallel or in the background. + + This method must retry exceptions internally. If an exception is thrown + from this method, the database is considered to be failed.""" + ... + + +class ConnPool(typing.Generic[C]): + _connection_factory: ConnectionFactory[C] + _pool: edb.server._conn_pool.ConnPool + _loop: asyncio.AbstractEventLoop + _completion: asyncio.Future[bool] + _ready: asyncio.Future[bool] + _active_conns: set[C] + _cur_capacity: int + _cur_waiters: int + + def __init__(self, connection_factory: ConnectionFactory[C]): + self._connection_factory = connection_factory + self._loop = asyncio.get_event_loop() + self._pool = None + self._completion = self._loop.create_future() + self._ready = self._loop.create_future() + self._active_conns = set() + self._cur_capacity = 0 + self._cur_waiters = 0 + + def _callback(self, args0: typing.Any, args: typing.Any) -> bool: + """Receives the callback from the Rust connection pool. + + Required to call pool._respond on the main thread with the result of + this callback. + """ + (kind, response_id) = args0 + if self._loop.is_closed(): + return False + else: + self._loop.call_soon_threadsafe( + self._loop.create_task, + self._perform_async(kind, response_id, *args), + ) + return True + + async def _perform_async(self, kind: int, + response_id: int, + *args: typing.Any) -> None: + """Delegates the callback from Rust to the appropriate connection + factory method.""" + if kind == 0: + self._cur_capacity += 1 + response = await self._connection_factory.connect(*args) + elif kind == 1: + await self._connection_factory.disconnect(*args) + self._cur_capacity -= 1 + response = None + elif kind == 2: + response = await self._connection_factory.reconnect(*args) + if self._pool is not None: + self._pool._respond(response_id, response) + + def _thread_main(self) -> None: + self._loop.call_soon_threadsafe(self._ready.set_result, True) + self._pool.run_and_block() + if not self._loop.is_closed(): + self._loop.call_soon_threadsafe(self._completion.set_result, True) + + async def run(self) -> None: + """Creates a long-lived task that manages the connection pool. Required + before any connections may be acquired.""" + if self._pool is not None: + raise RuntimeError(f"pool already started") from None + + self._pool = edb.server._conn_pool.ConnPool(self._callback) + threading.Thread(target=self._thread_main, daemon=True).start() + try: + await self._completion + except asyncio.exceptions.CancelledError: + self._pool.halt() + self._pool = None + + async def acquire(self, db: str) -> C: + """Acquire a connection from the database. This connection must be + released.""" + await self._ready + future: asyncio.Future[C] = self._loop.create_future() + # Note that this callback is called on the internal pool's thread + self._cur_waiters += 1 + try: + self._pool.acquire( + db, + lambda res: self._loop.call_soon_threadsafe(future.set_result, + res), + ) + conn = await future + finally: + self._cur_waiters -= 1 + self._active_conns.add(conn) + return conn + + def release(self, _db: str, conn: C, discard: bool = False) -> None: + """Releases a connection back into the pool, discarding or returning it + in the background.""" + self._active_conns.remove(conn) + self._pool.release(conn, discard) + pass + + def count_waiters(self) -> int: + return self._cur_waiters + + +class FactoryAdapter(typing.Generic[C]): + _connect: Connector[C] + _disconnect: Disconnector[C] + + def __init__(self, + connect: Connector[C], + disconnect: Disconnector[C]) -> None: + self._connect = connect + self._disconnect = disconnect + + async def connect(self, db: str) -> C: + return await self._connect(db) + + async def disconnect(self, conn: C) -> None: + await self._disconnect(conn) + + async def reconnect(self, conn: C, db: str) -> C: + await self._disconnect(conn) + return await self._connect(db) + + +class Pool(typing.Generic[C]): + _pool: ConnPool[C] + _failed_connects: int + _failed_disconnects: int + _successful_connects: int + _successful_disconnects: int + _cur_capacity: int + _max_capacity: int + _task: typing.Optional[asyncio.Task[None]] + + def __init__(self, *, connect: Connector[C], + disconnect: Disconnector[C], + stats_collector: typing.Optional[StatsCollector], + max_capacity: int) -> None: + self._pool = ConnPool(connection_factory=FactoryAdapter(connect, + disconnect,)) + self._failed_connects = 0 + self._failed_disconnects = 0 + self._successful_connects = 0 + self._successful_disconnects = 0 + self._task = None + + if stats_collector: + stats_collector(Snapshot( + timestamp=0, + capacity=10, + failed_connects=0, + failed_disconnects=0, + successful_connects=0, + successful_disconnects=0)) + pass + + async def __aenter__(self) -> typing.Self: + self._task = asyncio.create_task(self._pool.run()) + return self + + async def __aexit__(self, + exc_type: typing.Optional[type], + exc_val: typing.Optional[BaseException], + exc_tb: typing.Optional[typing.Any]) -> None: + if self._task: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + print("Task has been successfully cancelled") + self._task = None + print("Exiting context") + + async def acquire(self, db: str) -> C: + """Acquire a connection from the database. This connection must be + released.""" + if self._task is None: + raise RuntimeError("Not entered") + return await self._pool.acquire(db) + + def release(self, _db: str, conn: C, discard: bool = False) -> None: + """Releases a connection back into the pool, discarding or returning it + in the background.""" + if self._task is None: + raise RuntimeError("Not entered") + self._pool.release(_db, conn, discard) + + @property + def max_capacity(self) -> int: + return self._max_capacity + + @property + def current_capacity(self) -> int: + return self._cur_capacity + + @property + def failed_connects(self) -> int: + return self._failed_connects + + @property + def failed_disconnects(self) -> int: + return self._failed_disconnects diff --git a/tests/test_server_pool.py b/tests/test_server_pool.py index f5ae94bea058..77ba66030ccd 100644 --- a/tests/test_server_pool.py +++ b/tests/test_server_pool.py @@ -511,54 +511,54 @@ def on_stats(stat): stat = dataclasses.asdict(stat) sim.stats.append(stat) - pool = pool_cls( + async with pool_cls( connect=self.make_fake_connect( sim, spec.conn_cost_base, spec.conn_cost_var), disconnect=self.make_fake_disconnect( sim, spec.disconn_cost_base, spec.disconn_cost_var), stats_collector=on_stats if collect_stats else None, max_capacity=spec.capacity, - ) - if hasattr(pool, '_gc_interval'): - pool._gc_interval = 0.1 * TIME_SCALE + ) as pool: + if hasattr(pool, '_gc_interval'): + pool._gc_interval = 0.1 * TIME_SCALE started_at = time.monotonic() async with asyncio.TaskGroup() as g: for db in spec.dbs: g.create_task(self.simulate_db(sim, pool, g, db)) - self.assertEqual(sim.failed_disconnects, 0) - self.assertEqual(sim.failed_queries, 0) + self.assertEqual(sim.failed_disconnects, 0) + self.assertEqual(sim.failed_queries, 0) - self.assertEqual(pool.failed_disconnects, 0) - self.assertEqual(pool.failed_connects, 0) + self.assertEqual(pool.failed_disconnects, 0) + self.assertEqual(pool.failed_connects, 0) - try: - for db in sim.latencies: - int(db[1:]) - except ValueError: - key_func = lambda x: x - else: - key_func = lambda x: int(x[0][1:]) - - if collect_stats: - pn = f'{type(pool).__module__}.{type(pool).__qualname__}' - score = int(round(sum(sm.calculate(sim) for sm in spec.score))) - print('weighted score:'.rjust(68), score) - js_data = { - 'test_started_at': started_at, - 'total_lats': calc_total_percentiles(sim.latencies), - "score": score, - 'scores': sim.scores, - 'lats': { - db: calc_percentiles(lats) - for db, lats in sorted(sim.latencies.items(), key=key_func) - }, - 'pool_name': pn, - 'stats': sim.stats, - } + try: + for db in sim.latencies: + int(db[1:]) + except ValueError: + key_func = lambda x: x + else: + key_func = lambda x: int(x[0][1:]) + + if collect_stats: + pn = f'{type(pool).__module__}.{type(pool).__qualname__}' + score = int(round(sum(sm.calculate(sim) for sm in spec.score))) + print('weighted score:'.rjust(68), score) + js_data = { + 'test_started_at': started_at, + 'total_lats': calc_total_percentiles(sim.latencies), + "score": score, + 'scores': sim.scores, + 'lats': { + db: calc_percentiles(lats) + for db, lats in sorted(sim.latencies.items(), key=key_func) + }, + 'pool_name': pn, + 'stats': sim.stats, + } - return js_data + return js_data async def simulate(self, testname, spec): if os.environ.get('EDGEDB_TEST_DEBUG_POOL'): @@ -578,7 +578,7 @@ async def simulate(self, testname, spec): ) async def simulate_and_collect_stats(self, testname, spec): - pools = [connpool.Pool, connpool._NaivePool] + pools = [connpool.Pool2, connpool.Pool, connpool._NaivePool] js_data = [] for pool_cls in pools: