Skip to content

Commit

Permalink
Final changes for robustness
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Aug 7, 2024
1 parent 7ce76e5 commit 8c9ee2d
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 45 deletions.
58 changes: 35 additions & 23 deletions edb/server/conn_pool/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
pool::{Pool, PoolConfig},
PoolHandle,
};
use derive_more::{Add, AddAssign};
use futures::future::poll_fn;
use pyo3::{exceptions::PyException, prelude::*};
use std::{
Expand All @@ -29,19 +30,19 @@ enum RustToPythonMessage {
PerformDisconnect(ConnHandleId),
PerformReconnect(ConnHandleId, String),

Failed(PythonConnId, String),
Failed(PythonConnId, ConnHandleId),
}

impl ToPyObject for RustToPythonMessage {
fn to_object(&self, py: Python<'_>) -> PyObject {
use RustToPythonMessage::*;
match self {
Acquired(a, b) => (0, a, b).to_object(py),
PerformConnect(conn, s) => (1, conn, s).to_object(py),
PerformDisconnect(conn) => (2, conn).to_object(py),
PerformReconnect(conn, s) => (3, conn, s).to_object(py),
Acquired(a, b) => (0, a, b.0).to_object(py),
PerformConnect(conn, s) => (1, conn.0, s).to_object(py),
PerformDisconnect(conn) => (2, conn.0).to_object(py),
PerformReconnect(conn, s) => (3, conn.0, s).to_object(py),
Pruned(conn) => (4, conn).to_object(py),
Failed(conn, error) => (5, conn, error).to_object(py),
Failed(conn, error) => (5, conn, error.0).to_object(py),
}
}
}
Expand All @@ -59,21 +60,28 @@ enum PythonToRustMessage {
/// Completed an async request made by Rust.
CompletedAsync(ConnHandleId),
/// Failed an async request made by Rust.
FailedAsync(ConnHandleId, String),
FailedAsync(ConnHandleId),
}

type PipeSender = tokio::net::unix::pipe::Sender;

type PythonConnId = u64;
type ConnHandleId = u64;
#[derive(Debug, Default, Clone, Copy, Add, AddAssign, PartialEq, Eq, Hash, PartialOrd, Ord)]
struct ConnHandleId(u64);

impl Into<Box<(dyn derive_more::Error + std::marker::Send + Sync + 'static)>> for ConnHandleId {
fn into(self) -> Box<(dyn derive_more::Error + std::marker::Send + Sync + 'static)> {
Box::new(ConnError::Underlying(format!("{self:?}")))
}
}

struct RpcPipe {
rust_to_python_notify: RefCell<PipeSender>,
rust_to_python: std::sync::mpsc::Sender<RustToPythonMessage>,
python_to_rust: RefCell<tokio::sync::mpsc::UnboundedReceiver<PythonToRustMessage>>,
handles: RefCell<BTreeMap<PythonConnId, PoolHandle<Rc<RpcPipe>>>>,
next_id: Cell<ConnHandleId>,
async_ops: RefCell<BTreeMap<ConnHandleId, tokio::sync::oneshot::Sender<String>>>,
async_ops: RefCell<BTreeMap<ConnHandleId, tokio::sync::oneshot::Sender<()>>>,
}

impl std::fmt::Debug for RpcPipe {
Expand All @@ -100,15 +108,15 @@ impl RpcPipe {

async fn call<T>(
self: Rc<Self>,
id: u64,
conn_id: ConnHandleId,
ok: T,
msg: RustToPythonMessage,
) -> ConnResult<T, String> {
) -> ConnResult<T, ConnHandleId> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.async_ops.borrow_mut().insert(id, tx);
self.write(msg).await?;
if let Ok(error) = rx.await {
Err(ConnError::Underlying(error.into()))
self.async_ops.borrow_mut().insert(conn_id, tx);
self.write(msg).await.map_err(|_| ConnError::Underlying(conn_id))?;
if let Ok(_) = rx.await {
Err(ConnError::Underlying(conn_id))
} else {
Ok(ok)
}
Expand All @@ -117,7 +125,7 @@ impl RpcPipe {

impl Connector for Rc<RpcPipe> {
type Conn = ConnHandleId;
type Error = String;
type Error = ConnHandleId;

fn connect(
&self,
Expand All @@ -126,7 +134,7 @@ impl Connector for Rc<RpcPipe> {
Output = ConnResult<<Self as Connector>::Conn, <Self as Connector>::Error>,
> + 'static {
let id = self.next_id.get();
self.next_id.set(id + 1);
self.next_id.set(id + ConnHandleId(1));
let msg = RustToPythonMessage::PerformConnect(id, db.to_owned());
self.clone().call(id, id, msg)
}
Expand Down Expand Up @@ -207,12 +215,16 @@ async fn run_and_block(max_capacity: usize, rpc_pipe: RpcPipe) {
Acquire(conn_id, db) => {
let conn = match pool.acquire(&db).await {
Ok(conn) => conn,
Err(err) => {
Err(ConnError::Underlying(err)) => {
_ = rpc_pipe
.write(RustToPythonMessage::Failed(conn_id, err.to_string()))
.write(RustToPythonMessage::Failed(conn_id, err))
.await;
return;
}
Err(_) => {
// TODO
return;
}
};
let handle = conn.handle();
rpc_pipe.handles.borrow_mut().insert(conn_id, conn);
Expand All @@ -238,13 +250,13 @@ async fn run_and_block(max_capacity: usize, rpc_pipe: RpcPipe) {
CompletedAsync(handle_id) => {
rpc_pipe.async_ops.borrow_mut().remove(&handle_id);
}
FailedAsync(handle_id, error) => {
FailedAsync(handle_id) => {
_ = rpc_pipe
.async_ops
.borrow_mut()
.remove(&handle_id)
.unwrap()
.send(error);
.send(());
}
}
});
Expand Down Expand Up @@ -322,13 +334,13 @@ impl ConnPool {

fn _completed(&self, id: u64) -> PyResult<()> {
self.python_to_rust
.send(PythonToRustMessage::CompletedAsync(id))
.send(PythonToRustMessage::CompletedAsync(ConnHandleId(id)))
.map_err(|_| internal_error("In shutdown"))
}

fn _failed(&self, id: u64, error: PyObject) -> PyResult<()> {
self.python_to_rust
.send(PythonToRustMessage::FailedAsync(id, error.to_string()))
.send(PythonToRustMessage::FailedAsync(ConnHandleId(id)))
.map_err(|_| internal_error("In shutdown"))
}

Expand Down
53 changes: 36 additions & 17 deletions edb/server/connpool/pool2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@

from . import config

guard = edb.server._conn_pool.LoggingGuard()
logger = logging.getLogger("edb.server.connpool")
guard = edb.server._conn_pool.LoggingGuard()

# Connections must be hashable because we use them to reverse-lookup
# an internal ID.
Expand Down Expand Up @@ -140,13 +140,19 @@ def __del__(self) -> None:

async def close(self) -> None:
if self._task:
# Cancel the currently-executing futures
for acq in self._acquires.values():
acq.set_exception(asyncio.CancelledError())
for prune in self._prunes.values():
prune.set_exception(asyncio.CancelledError())
logger.info("Closing connection pool...")
self._task.cancel()
task = self._task
self._task = None
task.cancel()
try:
await self._task
await task
except asyncio.exceptions.CancelledError:
pass
self._task = None
self._pool = None
logger.info("Closed connection pool")

Expand All @@ -158,7 +164,7 @@ async def _boot(self, loop: asyncio.AbstractEventLoop) -> None:
transport, _ = await loop.connect_read_pipe(lambda: reader_protocol, fd)
try:
while len(await reader.read(1)) == 1:
if not self._pool:
if not self._pool or not self._task:
break
if self._skip_reads > 0:
self._skip_reads -= 1
Expand All @@ -181,8 +187,14 @@ def _try_read(self) -> None:
self._process_message(msg)

def _process_message(self, msg: typing.Any) -> None:
# If we're closing, don't dispatch any operations
if not self._task:
return
if msg[0] == 0:
self._acquires[msg[1]].set_result(msg[2])
if f := self._acquires.pop(msg[1], None):
f.set_result(msg[2])
else:
logger.warn(f"Duplicate result for acquire {msg[1]}")
elif msg[0] == 1:
self._loop.create_task(self._perform_connect(msg[1], msg[2]))
elif msg[0] == 2:
Expand All @@ -192,11 +204,13 @@ def _process_message(self, msg: typing.Any) -> None:
elif msg[0] == 4:
self._loop.create_task(self._perform_prune(msg[1]))
elif msg[0] == 5:
error = self._errors.pop(msg[1], None)
if not error:
error = Exception(
f"Connection error was unexpectedly lost: {msg[2]}")
self._acquires[msg[1]].set_exception(error)
# Note that we might end up with duplicated messages at shutdown
error = self._errors.pop(msg[2], None)
if error:
if f := self._acquires.pop(msg[1], None):
f.set_exception(error)
else:
logger.warn(f"Duplicate exception for acquire {msg[1]}")
else:
logger.critical(f'Unexpected message: {msg}')

Expand Down Expand Up @@ -252,26 +266,29 @@ async def _perform_prune(self, id: int) -> None:
async def acquire(self, dbname: str) -> C:
"""Acquire a connection from the database. This connection must be
released."""
if not self._task:
raise asyncio.CancelledError()
for i in range(config.CONNECT_FAILURE_RETRIES + 1):
id = self._next_conn_id
self._next_conn_id += 1
self._acquires[id] = asyncio.Future()
acquire: asyncio.Future[int] = asyncio.Future()
self._acquires[id] = acquire
self._pool._acquire(id, dbname)
self._try_read()
# This may throw!
try:
conn = await self._acquires[id]
conn = await acquire
c = self._conns[conn]
self._conns_held[c] = id
return c
except Exception as e:
# Allow the final exception to escape
if i == config.CONNECT_FAILURE_RETRIES:
logger.exception('Failed to acquire connection, will not '
f'retry: {dbname}')
raise
logger.error(f'Failed to acquire connection, will retry: {
dbname}', e)
finally:
del self._acquires[id]
logger.exception('Failed to acquire connection, will retry: '
f'{dbname}')
raise AssertionError("Unreachable end of loop")

def release(self, dbname: str, conn: C, discard: bool = False) -> None:
Expand All @@ -286,6 +303,8 @@ def release(self, dbname: str, conn: C, discard: bool = False) -> None:
self._try_read()

async def prune_inactive_connections(self, dbname: str) -> None:
if not self._task:
raise asyncio.CancelledError()
id = self._next_conn_id
self._next_conn_id += 1
self._prunes[id] = asyncio.Future()
Expand Down
6 changes: 1 addition & 5 deletions edb/server/tenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,11 +519,6 @@ def stop(self) -> None:
self._cluster.stop_watching()
self._stop_watching_files()
self._server.request_frontend_stop(self)
# The pool may require some async shutdown tasks to fully resolve
if not self._task_group:
self._task_group = asyncio.TaskGroup()
self._task_group.create_task(
self._pg_pool.close(), name="pool shutdown")

def _stop_watching_files(self):
while self._file_watch_finalizers:
Expand All @@ -534,6 +529,7 @@ async def wait_stopped(self) -> None:
tg = self._task_group
self._task_group = None
await tg.__aexit__(*sys.exc_info())
await self._pg_pool.close()

def terminate_sys_pgcon(self) -> None:
if self.__sys_pgcon is not None:
Expand Down

0 comments on commit 8c9ee2d

Please sign in to comment.