Skip to content

Commit

Permalink
Migrate some async connection stuff to reusable bits
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Feb 19, 2025
1 parent 70c2baf commit 8b3bb66
Show file tree
Hide file tree
Showing 12 changed files with 310 additions and 225 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion edb/server/connpool/pool2.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __init__(

self._loop = asyncio.get_running_loop()
self._channel = rust_async_channel.RustAsyncChannel(
self._pool,
self._pool._channel,
self._process_message,
)

Expand Down
2 changes: 1 addition & 1 deletion edb/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ async def _boot(self, client) -> None:
logger.info(f"HTTP client initialized, user_agent={self._user_agent}")
try:
channel = rust_async_channel.RustAsyncChannel(
client, self._process_message
client._channel, self._process_message
)
try:
await channel.run()
Expand Down
2 changes: 1 addition & 1 deletion edb/server/rust_async_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
pipe: RustPipeProtocol,
callback: Callable[[Tuple[Any, ...]], None],
) -> None:
self._closed = asyncio.Event()
fd = pipe._fd
self._buffered_reader = io.BufferedReader(
io.FileIO(fd), buffer_size=MAX_BATCH_SIZE
Expand All @@ -56,7 +57,6 @@ def __init__(
self._pipe = pipe
self._callback = callback
self._skip_reads = 0
self._closed = asyncio.Event()

def __del__(self):
if not self._closed.is_set():
Expand Down
155 changes: 56 additions & 99 deletions rust/conn_pool/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,26 @@ use crate::{
PoolHandle,
};
use derive_more::{Add, AddAssign};
use futures::future::poll_fn;
use pyo3::{exceptions::PyException, prelude::*, types::PyByteArray};
use pyo3_util::logging::{get_python_logger_level, initialize_logging_in_thread};
use pyo3::{
exceptions::{PyException, PyValueError},
prelude::*,
types::PyByteArray,
};
use pyo3_util::{
channel::{new_python_channel, PythonChannel, PythonChannelImpl, RustChannel},
logging::{get_python_logger_level, initialize_logging_in_thread},
};
use serde_pickle::SerOptions;
use std::{
cell::{Cell, RefCell},
collections::BTreeMap,
os::fd::IntoRawFd,
pin::Pin,
rc::Rc,
sync::Mutex,
sync::Arc,
thread,
time::{Duration, Instant},
};
use strum::IntoEnumIterator;
use tokio::{io::AsyncWrite, task::LocalSet};
use tokio::task::LocalSet;
use tracing::{error, info, trace};

pyo3::create_exception!(_conn_pool, InternalError, PyException);
Expand All @@ -38,10 +42,14 @@ enum RustToPythonMessage {
Metrics(Vec<u8>),
}

impl RustToPythonMessage {
fn to_object(&self, py: Python<'_>) -> PyResult<PyObject> {
impl<'py> IntoPyObject<'py> for RustToPythonMessage {
type Target = PyAny;
type Output = Bound<'py, PyAny>;
type Error = PyErr;

fn into_pyobject(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
use RustToPythonMessage::*;
match self {
let res = match self {
Acquired(a, b) => (0, a, b.0).into_pyobject(py),
PerformConnect(conn, s) => (1, conn.0, s).into_pyobject(py),
PerformDisconnect(conn) => (2, conn.0).into_pyobject(py),
Expand All @@ -50,10 +58,10 @@ impl RustToPythonMessage {
Failed(conn, error) => (5, conn, error.0).into_pyobject(py),
Metrics(metrics) => {
// This is not really fast but it should not be happening very often
(6, PyByteArray::new(py, metrics)).into_pyobject(py)
(6, PyByteArray::new(py, &metrics)).into_pyobject(py)
}
}
.map(|e| e.into())
}?;
Ok(res.into_any())
}
}

Expand All @@ -73,7 +81,12 @@ enum PythonToRustMessage {
FailedAsync(ConnHandleId),
}

type PipeSender = tokio::net::unix::pipe::Sender;
impl<'py> FromPyObject<'py> for PythonToRustMessage {
fn extract_bound(_: &Bound<'py, PyAny>) -> PyResult<Self> {
// Unused for this class
Err(PyValueError::new_err("Not implemented"))
}
}

type PythonConnId = u64;
#[derive(Debug, Default, Clone, Copy, Add, AddAssign, PartialEq, Eq, Hash, PartialOrd, Ord)]
Expand All @@ -86,9 +99,7 @@ impl From<ConnHandleId> for Box<(dyn std::error::Error + std::marker::Send + Syn
}

struct RpcPipe {
rust_to_python_notify: RefCell<PipeSender>,
rust_to_python: std::sync::mpsc::Sender<RustToPythonMessage>,
python_to_rust: RefCell<tokio::sync::mpsc::UnboundedReceiver<PythonToRustMessage>>,
channel: RustChannel<PythonToRustMessage, RustToPythonMessage>,
handles: RefCell<BTreeMap<PythonConnId, PoolHandle<Rc<RpcPipe>>>>,
next_id: Cell<ConnHandleId>,
async_ops: RefCell<BTreeMap<ConnHandleId, tokio::sync::oneshot::Sender<()>>>,
Expand All @@ -101,21 +112,6 @@ impl std::fmt::Debug for RpcPipe {
}

impl RpcPipe {
async fn write(&self, msg: RustToPythonMessage) -> ConnResult<(), String> {
self.rust_to_python
.send(msg)
.map_err(|_| ConnError::Shutdown)?;
// If we're shutting down, this may fail (but that's OK)
poll_fn(|cx| {
let pipe = &mut *self.rust_to_python_notify.borrow_mut();
let this = Pin::new(pipe);
this.poll_write(cx, &[0])
})
.await
.map_err(|_| ConnError::Shutdown)?;
Ok(())
}

async fn call<T>(
self: Rc<Self>,
conn_id: ConnHandleId,
Expand All @@ -124,7 +120,8 @@ impl RpcPipe {
) -> ConnResult<T, ConnHandleId> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.async_ops.borrow_mut().insert(conn_id, tx);
self.write(msg)
self.channel
.write(msg)
.await
.map_err(|_| ConnError::Underlying(conn_id))?;
if rx.await.is_ok() {
Expand Down Expand Up @@ -176,9 +173,7 @@ impl Connector for Rc<RpcPipe> {

#[pyclass]
struct ConnPool {
python_to_rust: tokio::sync::mpsc::UnboundedSender<PythonToRustMessage>,
rust_to_python: Mutex<std::sync::mpsc::Receiver<RustToPythonMessage>>,
notify_fd: u64,
channel: Arc<PythonChannelImpl<PythonToRustMessage, RustToPythonMessage>>,
}

impl Drop for ConnPool {
Expand Down Expand Up @@ -209,6 +204,7 @@ async fn run_and_block(config: PoolConfig, rpc_pipe: RpcPipe, stats_interval: f6
if last_stats.elapsed() > stats_interval {
last_stats = Instant::now();
if rpc_pipe
.channel
.write(RustToPythonMessage::Metrics(
serde_pickle::to_vec(&pool.metrics(), SerOptions::new())
.unwrap_or_default(),
Expand All @@ -224,8 +220,7 @@ async fn run_and_block(config: PoolConfig, rpc_pipe: RpcPipe, stats_interval: f6
};

loop {
let Some(rpc) = poll_fn(|cx| rpc_pipe.python_to_rust.borrow_mut().poll_recv(cx)).await
else {
let Some(rpc) = rpc_pipe.channel.recv().await else {
info!("ConnPool shutting down");
pool_task.abort();
pool.shutdown().await;
Expand All @@ -242,6 +237,7 @@ async fn run_and_block(config: PoolConfig, rpc_pipe: RpcPipe, stats_interval: f6
Ok(conn) => conn,
Err(ConnError::Underlying(err)) => {
_ = rpc_pipe
.channel
.write(RustToPythonMessage::Failed(conn_id, err))
.await;
return;
Expand All @@ -254,6 +250,7 @@ async fn run_and_block(config: PoolConfig, rpc_pipe: RpcPipe, stats_interval: f6
let handle = conn.handle();
rpc_pipe.handles.borrow_mut().insert(conn_id, conn);
_ = rpc_pipe
.channel
.write(RustToPythonMessage::Acquired(conn_id, handle))
.await;
}
Expand All @@ -270,7 +267,10 @@ async fn run_and_block(config: PoolConfig, rpc_pipe: RpcPipe, stats_interval: f6
}
Prune(conn_id, db) => {
pool.drain_idle(&db).await;
_ = rpc_pipe.write(RustToPythonMessage::Pruned(conn_id)).await;
_ = rpc_pipe
.channel
.write(RustToPythonMessage::Pruned(conn_id))
.await;
}
CompletedAsync(handle_id) => {
rpc_pipe.async_ops.borrow_mut().remove(&handle_id);
Expand Down Expand Up @@ -302,8 +302,6 @@ impl ConnPool {
let level = get_python_logger_level(py, "edb.server.conn_pool")?;
let min_idle_time_before_gc = min_idle_time_before_gc as usize;
let new = py.allow_threads(|| {
let (txrp, rxrp) = std::sync::mpsc::channel();
let (txpr, rxpr) = tokio::sync::mpsc::unbounded_channel();
let (txfd, rxfd) = std::sync::mpsc::channel();
thread::spawn(move || {
initialize_logging_in_thread("edb.server.conn_pool", level);
Expand All @@ -315,15 +313,13 @@ impl ConnPool {
.build()
.unwrap();
let _guard = rt.enter();
let (txn, rxn) = tokio::net::unix::pipe::pipe().unwrap();
let fd = rxn.into_nonblocking_fd().unwrap().into_raw_fd() as u64;
txfd.send(fd).unwrap();

let (rust, python) = new_python_channel();
txfd.send(python).unwrap();
let local = LocalSet::new();

let rpc_pipe = RpcPipe {
python_to_rust: rxpr.into(),
rust_to_python: txrp,
rust_to_python_notify: txn.into(),
channel: rust,
next_id: Default::default(),
handles: Default::default(),
async_ops: Default::default(),
Expand All @@ -334,85 +330,46 @@ impl ConnPool {
local.block_on(&rt, run_and_block(config, rpc_pipe, stats_interval));
});

let notify_fd = rxfd.recv().unwrap();
let channel = rxfd.recv().unwrap().into();
ConnPool {
python_to_rust: txpr,
rust_to_python: Mutex::new(rxrp),
notify_fd,
channel,
}
});
Ok(new)
}

#[getter]
fn _fd(&self) -> u64 {
self.notify_fd
fn _channel(&self) -> PyResult<PythonChannel> {
Ok(PythonChannel::new(self.channel.clone()))
}

fn _acquire(&self, id: u64, db: &str) -> PyResult<()> {
self.python_to_rust
self.channel
.send(PythonToRustMessage::Acquire(id, db.to_owned()))
.map_err(|_| internal_error("In shutdown"))
}

fn _release(&self, id: u64) -> PyResult<()> {
self.python_to_rust
.send(PythonToRustMessage::Release(id))
.map_err(|_| internal_error("In shutdown"))
self.channel.send_err(PythonToRustMessage::Release(id))
}

fn _discard(&self, id: u64) -> PyResult<()> {
self.python_to_rust
.send(PythonToRustMessage::Discard(id))
.map_err(|_| internal_error("In shutdown"))
self.channel.send_err(PythonToRustMessage::Discard(id))
}

fn _completed(&self, id: u64) -> PyResult<()> {
self.python_to_rust
.send(PythonToRustMessage::CompletedAsync(ConnHandleId(id)))
.map_err(|_| internal_error("In shutdown"))
self.channel
.send_err(PythonToRustMessage::CompletedAsync(ConnHandleId(id)))
}

fn _failed(&self, id: u64, _error: PyObject) -> PyResult<()> {
self.python_to_rust
.send(PythonToRustMessage::FailedAsync(ConnHandleId(id)))
.map_err(|_| internal_error("In shutdown"))
self.channel
.send_err(PythonToRustMessage::FailedAsync(ConnHandleId(id)))
}

fn _prune(&self, id: u64, db: &str) -> PyResult<()> {
self.python_to_rust
.send(PythonToRustMessage::Prune(id, db.to_owned()))
.map_err(|_| internal_error("In shutdown"))
}

fn _read(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
let Ok(msg) = self
.rust_to_python
.try_lock()
.expect("Unsafe thread access")
.try_recv()
else {
return Ok(py.None());
};
msg.to_object(py)
}

fn _try_read(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
let Ok(msg) = self
.rust_to_python
.try_lock()
.expect("Unsafe thread access")
.try_recv()
else {
return Ok(py.None());
};
msg.to_object(py)
}

fn _close_pipe(&mut self) {
// Replace the channel with a dummy, closed one which will also
// signal the other side to exit.
self.rust_to_python = Mutex::new(std::sync::mpsc::channel().1);
self.channel
.send_err(PythonToRustMessage::Prune(id, db.to_owned()))
}
}

Expand Down
Loading

0 comments on commit 8b3bb66

Please sign in to comment.