diff --git a/Cargo.lock b/Cargo.lock index 22fc4aadafe..153b2ede701 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2509,6 +2509,7 @@ version = "0.1.0" dependencies = [ "pyo3", "scopeguard", + "tokio", "tracing", "tracing-subscriber", ] diff --git a/edb/server/connpool/pool2.py b/edb/server/connpool/pool2.py index 1b9f749f86d..0d49b1cbc3a 100644 --- a/edb/server/connpool/pool2.py +++ b/edb/server/connpool/pool2.py @@ -134,7 +134,7 @@ def __init__( self._loop = asyncio.get_running_loop() self._channel = rust_async_channel.RustAsyncChannel( - self._pool, + self._pool._channel, self._process_message, ) diff --git a/edb/server/http.py b/edb/server/http.py index 07bb0f0056f..ab28eb49dbd 100644 --- a/edb/server/http.py +++ b/edb/server/http.py @@ -309,7 +309,7 @@ async def _boot(self, client) -> None: logger.info(f"HTTP client initialized, user_agent={self._user_agent}") try: channel = rust_async_channel.RustAsyncChannel( - client, self._process_message + client._channel, self._process_message ) try: await channel.run() diff --git a/edb/server/rust_async_channel.py b/edb/server/rust_async_channel.py index 9de2644a863..3c645b668db 100644 --- a/edb/server/rust_async_channel.py +++ b/edb/server/rust_async_channel.py @@ -48,6 +48,7 @@ def __init__( pipe: RustPipeProtocol, callback: Callable[[Tuple[Any, ...]], None], ) -> None: + self._closed = asyncio.Event() fd = pipe._fd self._buffered_reader = io.BufferedReader( io.FileIO(fd), buffer_size=MAX_BATCH_SIZE @@ -56,7 +57,6 @@ def __init__( self._pipe = pipe self._callback = callback self._skip_reads = 0 - self._closed = asyncio.Event() def __del__(self): if not self._closed.is_set(): diff --git a/rust/conn_pool/src/python.rs b/rust/conn_pool/src/python.rs index dcd01f17376..8573a8a33c2 100644 --- a/rust/conn_pool/src/python.rs +++ b/rust/conn_pool/src/python.rs @@ -5,22 +5,26 @@ use crate::{ PoolHandle, }; use derive_more::{Add, AddAssign}; -use futures::future::poll_fn; -use pyo3::{exceptions::PyException, prelude::*, types::PyByteArray}; -use pyo3_util::logging::{get_python_logger_level, initialize_logging_in_thread}; +use pyo3::{ + exceptions::{PyException, PyValueError}, + prelude::*, + types::PyByteArray, +}; +use pyo3_util::{ + channel::{new_python_channel, PythonChannel, PythonChannelImpl, RustChannel}, + logging::{get_python_logger_level, initialize_logging_in_thread}, +}; use serde_pickle::SerOptions; use std::{ cell::{Cell, RefCell}, collections::BTreeMap, - os::fd::IntoRawFd, - pin::Pin, rc::Rc, - sync::Mutex, + sync::Arc, thread, time::{Duration, Instant}, }; use strum::IntoEnumIterator; -use tokio::{io::AsyncWrite, task::LocalSet}; +use tokio::task::LocalSet; use tracing::{error, info, trace}; pyo3::create_exception!(_conn_pool, InternalError, PyException); @@ -38,10 +42,14 @@ enum RustToPythonMessage { Metrics(Vec), } -impl RustToPythonMessage { - fn to_object(&self, py: Python<'_>) -> PyResult { +impl<'py> IntoPyObject<'py> for RustToPythonMessage { + type Target = PyAny; + type Output = Bound<'py, PyAny>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> PyResult> { use RustToPythonMessage::*; - match self { + let res = match self { Acquired(a, b) => (0, a, b.0).into_pyobject(py), PerformConnect(conn, s) => (1, conn.0, s).into_pyobject(py), PerformDisconnect(conn) => (2, conn.0).into_pyobject(py), @@ -50,10 +58,10 @@ impl RustToPythonMessage { Failed(conn, error) => (5, conn, error.0).into_pyobject(py), Metrics(metrics) => { // This is not really fast but it should not be happening very often - (6, PyByteArray::new(py, metrics)).into_pyobject(py) + (6, PyByteArray::new(py, &metrics)).into_pyobject(py) } - } - .map(|e| e.into()) + }?; + Ok(res.into_any()) } } @@ -73,7 +81,12 @@ enum PythonToRustMessage { FailedAsync(ConnHandleId), } -type PipeSender = tokio::net::unix::pipe::Sender; +impl<'py> FromPyObject<'py> for PythonToRustMessage { + fn extract_bound(_: &Bound<'py, PyAny>) -> PyResult { + // Unused for this class + Err(PyValueError::new_err("Not implemented")) + } +} type PythonConnId = u64; #[derive(Debug, Default, Clone, Copy, Add, AddAssign, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -86,9 +99,7 @@ impl From for Box<(dyn std::error::Error + std::marker::Send + Syn } struct RpcPipe { - rust_to_python_notify: RefCell, - rust_to_python: std::sync::mpsc::Sender, - python_to_rust: RefCell>, + channel: RustChannel, handles: RefCell>>>, next_id: Cell, async_ops: RefCell>>, @@ -101,21 +112,6 @@ impl std::fmt::Debug for RpcPipe { } impl RpcPipe { - async fn write(&self, msg: RustToPythonMessage) -> ConnResult<(), String> { - self.rust_to_python - .send(msg) - .map_err(|_| ConnError::Shutdown)?; - // If we're shutting down, this may fail (but that's OK) - poll_fn(|cx| { - let pipe = &mut *self.rust_to_python_notify.borrow_mut(); - let this = Pin::new(pipe); - this.poll_write(cx, &[0]) - }) - .await - .map_err(|_| ConnError::Shutdown)?; - Ok(()) - } - async fn call( self: Rc, conn_id: ConnHandleId, @@ -124,7 +120,8 @@ impl RpcPipe { ) -> ConnResult { let (tx, rx) = tokio::sync::oneshot::channel(); self.async_ops.borrow_mut().insert(conn_id, tx); - self.write(msg) + self.channel + .write(msg) .await .map_err(|_| ConnError::Underlying(conn_id))?; if rx.await.is_ok() { @@ -176,9 +173,7 @@ impl Connector for Rc { #[pyclass] struct ConnPool { - python_to_rust: tokio::sync::mpsc::UnboundedSender, - rust_to_python: Mutex>, - notify_fd: u64, + channel: Arc>, } impl Drop for ConnPool { @@ -209,6 +204,7 @@ async fn run_and_block(config: PoolConfig, rpc_pipe: RpcPipe, stats_interval: f6 if last_stats.elapsed() > stats_interval { last_stats = Instant::now(); if rpc_pipe + .channel .write(RustToPythonMessage::Metrics( serde_pickle::to_vec(&pool.metrics(), SerOptions::new()) .unwrap_or_default(), @@ -224,8 +220,7 @@ async fn run_and_block(config: PoolConfig, rpc_pipe: RpcPipe, stats_interval: f6 }; loop { - let Some(rpc) = poll_fn(|cx| rpc_pipe.python_to_rust.borrow_mut().poll_recv(cx)).await - else { + let Some(rpc) = rpc_pipe.channel.recv().await else { info!("ConnPool shutting down"); pool_task.abort(); pool.shutdown().await; @@ -242,6 +237,7 @@ async fn run_and_block(config: PoolConfig, rpc_pipe: RpcPipe, stats_interval: f6 Ok(conn) => conn, Err(ConnError::Underlying(err)) => { _ = rpc_pipe + .channel .write(RustToPythonMessage::Failed(conn_id, err)) .await; return; @@ -254,6 +250,7 @@ async fn run_and_block(config: PoolConfig, rpc_pipe: RpcPipe, stats_interval: f6 let handle = conn.handle(); rpc_pipe.handles.borrow_mut().insert(conn_id, conn); _ = rpc_pipe + .channel .write(RustToPythonMessage::Acquired(conn_id, handle)) .await; } @@ -270,7 +267,10 @@ async fn run_and_block(config: PoolConfig, rpc_pipe: RpcPipe, stats_interval: f6 } Prune(conn_id, db) => { pool.drain_idle(&db).await; - _ = rpc_pipe.write(RustToPythonMessage::Pruned(conn_id)).await; + _ = rpc_pipe + .channel + .write(RustToPythonMessage::Pruned(conn_id)) + .await; } CompletedAsync(handle_id) => { rpc_pipe.async_ops.borrow_mut().remove(&handle_id); @@ -302,8 +302,6 @@ impl ConnPool { let level = get_python_logger_level(py, "edb.server.conn_pool")?; let min_idle_time_before_gc = min_idle_time_before_gc as usize; let new = py.allow_threads(|| { - let (txrp, rxrp) = std::sync::mpsc::channel(); - let (txpr, rxpr) = tokio::sync::mpsc::unbounded_channel(); let (txfd, rxfd) = std::sync::mpsc::channel(); thread::spawn(move || { initialize_logging_in_thread("edb.server.conn_pool", level); @@ -315,15 +313,13 @@ impl ConnPool { .build() .unwrap(); let _guard = rt.enter(); - let (txn, rxn) = tokio::net::unix::pipe::pipe().unwrap(); - let fd = rxn.into_nonblocking_fd().unwrap().into_raw_fd() as u64; - txfd.send(fd).unwrap(); + + let (rust, python) = new_python_channel(); + txfd.send(python).unwrap(); let local = LocalSet::new(); let rpc_pipe = RpcPipe { - python_to_rust: rxpr.into(), - rust_to_python: txrp, - rust_to_python_notify: txn.into(), + channel: rust, next_id: Default::default(), handles: Default::default(), async_ops: Default::default(), @@ -334,85 +330,46 @@ impl ConnPool { local.block_on(&rt, run_and_block(config, rpc_pipe, stats_interval)); }); - let notify_fd = rxfd.recv().unwrap(); + let channel = rxfd.recv().unwrap().into(); ConnPool { - python_to_rust: txpr, - rust_to_python: Mutex::new(rxrp), - notify_fd, + channel, } }); Ok(new) } #[getter] - fn _fd(&self) -> u64 { - self.notify_fd + fn _channel(&self) -> PyResult { + Ok(PythonChannel::new(self.channel.clone())) } fn _acquire(&self, id: u64, db: &str) -> PyResult<()> { - self.python_to_rust + self.channel .send(PythonToRustMessage::Acquire(id, db.to_owned())) .map_err(|_| internal_error("In shutdown")) } fn _release(&self, id: u64) -> PyResult<()> { - self.python_to_rust - .send(PythonToRustMessage::Release(id)) - .map_err(|_| internal_error("In shutdown")) + self.channel.send_err(PythonToRustMessage::Release(id)) } fn _discard(&self, id: u64) -> PyResult<()> { - self.python_to_rust - .send(PythonToRustMessage::Discard(id)) - .map_err(|_| internal_error("In shutdown")) + self.channel.send_err(PythonToRustMessage::Discard(id)) } fn _completed(&self, id: u64) -> PyResult<()> { - self.python_to_rust - .send(PythonToRustMessage::CompletedAsync(ConnHandleId(id))) - .map_err(|_| internal_error("In shutdown")) + self.channel + .send_err(PythonToRustMessage::CompletedAsync(ConnHandleId(id))) } fn _failed(&self, id: u64, _error: PyObject) -> PyResult<()> { - self.python_to_rust - .send(PythonToRustMessage::FailedAsync(ConnHandleId(id))) - .map_err(|_| internal_error("In shutdown")) + self.channel + .send_err(PythonToRustMessage::FailedAsync(ConnHandleId(id))) } fn _prune(&self, id: u64, db: &str) -> PyResult<()> { - self.python_to_rust - .send(PythonToRustMessage::Prune(id, db.to_owned())) - .map_err(|_| internal_error("In shutdown")) - } - - fn _read(&self, py: Python<'_>) -> PyResult> { - let Ok(msg) = self - .rust_to_python - .try_lock() - .expect("Unsafe thread access") - .try_recv() - else { - return Ok(py.None()); - }; - msg.to_object(py) - } - - fn _try_read(&self, py: Python<'_>) -> PyResult> { - let Ok(msg) = self - .rust_to_python - .try_lock() - .expect("Unsafe thread access") - .try_recv() - else { - return Ok(py.None()); - }; - msg.to_object(py) - } - - fn _close_pipe(&mut self) { - // Replace the channel with a dummy, closed one which will also - // signal the other side to exit. - self.rust_to_python = Mutex::new(std::sync::mpsc::channel().1); + self.channel + .send_err(PythonToRustMessage::Prune(id, db.to_owned())) } } diff --git a/rust/gel-http/src/python.rs b/rust/gel-http/src/python.rs index 184d3e6bd86..db9316f97ab 100644 --- a/rust/gel-http/src/python.rs +++ b/rust/gel-http/src/python.rs @@ -1,16 +1,20 @@ use eventsource_stream::Eventsource; -use futures::{future::poll_fn, TryStreamExt}; +use futures::TryStreamExt; use http::{HeaderMap, HeaderName, HeaderValue, Uri}; use http_body_util::BodyExt; -use pyo3::{exceptions::PyException, prelude::*, types::PyByteArray}; -use pyo3_util::logging::{get_python_logger_level, initialize_logging_in_thread}; +use pyo3::{ + exceptions::{PyException, PyValueError}, + prelude::*, + types::PyByteArray, +}; +use pyo3_util::{ + channel::{new_python_channel, PythonChannel, PythonChannelImpl, RustChannel}, + logging::{get_python_logger_level, initialize_logging_in_thread}, +}; use reqwest::Method; use scopeguard::{defer, guard, ScopeGuard}; use std::{ - cell::RefCell, collections::HashMap, - os::fd::IntoRawFd, - pin::Pin, rc::Rc, str::FromStr, sync::{Arc, Mutex}, @@ -18,11 +22,10 @@ use std::{ time::Duration, }; use tokio::{ - io::AsyncWrite, sync::{AcquireError, Semaphore, SemaphorePermit}, task::{JoinHandle, LocalSet}, }; -use tracing::{error, info, trace}; +use tracing::{info, trace}; use crate::cache::{Cache, CacheBefore}; @@ -33,6 +36,8 @@ const SSE_QUEUE_SIZE: usize = 100; type PythonConnId = u64; +type RpcPipe = RustChannel; + #[derive(Debug)] enum RustToPythonMessage { Response(PythonConnId, (u16, Vec, HashMap)), @@ -41,22 +46,26 @@ enum RustToPythonMessage { SSEEnd(PythonConnId), Error(PythonConnId, String), } -impl RustToPythonMessage { - fn to_object(&self, py: Python<'_>) -> PyResult { + +impl<'py> IntoPyObject<'py> for RustToPythonMessage { + type Target = PyAny; + type Output = Bound<'py, PyAny>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> PyResult> { use RustToPythonMessage::*; - trace!("Read: {self:?}"); - match self { + let res = match self { Error(conn, error) => (0, conn, error).into_pyobject(py), Response(conn, (status, body, headers)) => { - (1, conn, (status, PyByteArray::new(py, body), headers)).into_pyobject(py) + (1, conn, (status, PyByteArray::new(py, &body), headers)).into_pyobject(py) } SSEStart(conn, (status, headers)) => (2, conn, (status, headers)).into_pyobject(py), SSEEvent(conn, message) => { (3, conn, (&message.id, &message.data, &message.event)).into_pyobject(py) } SSEEnd(conn) => (4, conn, ()).into_pyobject(py), - } - .map(|e| e.into()) + }?; + Ok(res.into_any()) } } @@ -81,54 +90,13 @@ enum PythonToRustMessage { Ack(PythonConnId), } -type PipeSender = tokio::net::unix::pipe::Sender; - -struct RpcPipe { - rust_to_python_notify: RefCell, - rust_to_python: std::sync::mpsc::Sender, - python_to_rust: RefCell>, -} - -impl std::fmt::Debug for RpcPipe { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("RpcPipe") - } -} - -impl RpcPipe { - async fn write(&self, msg: RustToPythonMessage) -> Result<(), String> { - trace!("Rust -> Python: {msg:?}"); - self.rust_to_python.send(msg).map_err(|_| "Shutdown")?; - // If we're shutting down, this may fail (but that's OK) - poll_fn(|cx| { - let pipe = &mut *self.rust_to_python_notify.borrow_mut(); - let this = Pin::new(pipe); - this.poll_write(cx, &[0]) - }) - .await - .map_err(|_| "Shutdown")?; - Ok(()) - } -} - -#[pyclass] -struct Http { - python_to_rust: tokio::sync::mpsc::UnboundedSender, - rust_to_python: Mutex>, - notify_fd: u64, -} - -impl Drop for Http { - fn drop(&mut self) { - info!("Http dropped"); +impl<'py> FromPyObject<'py> for PythonToRustMessage { + fn extract_bound(_: &Bound<'py, PyAny>) -> PyResult { + // Unused for this class + Err(PyValueError::new_err("Not implemented")) } } -fn internal_error(message: &str) -> PyErr { - error!("{message}"); - InternalError::new_err(()) -} - /// If this is likely a stream, returns the `Stream` variant. /// Otherwise, returns the `Bytes` variant. enum MaybeResponse { @@ -478,8 +446,7 @@ async fn run_and_block(capacity: usize, rpc_pipe: RpcPipe) { let tasks = Arc::new(Mutex::new(HashMap::::new())); loop { - let Some(rpc) = poll_fn(|cx| rpc_pipe.python_to_rust.borrow_mut().poll_recv(cx)).await - else { + let Some(rpc) = rpc_pipe.recv().await else { info!("Http shutting down"); break; }; @@ -606,13 +573,9 @@ async fn execute( } } -impl Http { - fn send(&self, msg: PythonToRustMessage) -> PyResult<()> { - trace!("Python -> Rust: {msg:?}"); - self.python_to_rust - .send(msg) - .map_err(|_| internal_error("In shutdown")) - } +#[pyclass] +struct Http { + python: Arc>, } #[pymethods] @@ -624,8 +587,6 @@ impl Http { let level = get_python_logger_level(py, "edgedb.server.http")?; info!("Http::new(max_capacity={max_capacity})"); - let (txrp, rxrp) = std::sync::mpsc::channel(); - let (txpr, rxpr) = tokio::sync::mpsc::unbounded_channel(); let (txfd, rxfd) = std::sync::mpsc::channel(); thread::Builder::new() @@ -640,32 +601,23 @@ impl Http { .build() .unwrap(); let _guard = rt.enter(); - let (txn, rxn) = tokio::net::unix::pipe::pipe().unwrap(); - let fd = rxn.into_nonblocking_fd().unwrap().into_raw_fd() as u64; - txfd.send(fd).unwrap(); - let local = LocalSet::new(); - let rpc_pipe = RpcPipe { - python_to_rust: rxpr.into(), - rust_to_python: txrp, - rust_to_python_notify: txn.into(), - }; + let (rust, python) = new_python_channel(); + txfd.send(python).unwrap(); + let local = LocalSet::new(); - local.block_on(&rt, run_and_block(max_capacity, rpc_pipe)); + local.block_on(&rt, run_and_block(max_capacity, rust)); }) .expect("Failed to create HTTP thread"); - let notify_fd = rxfd.recv().unwrap(); Ok(Http { - python_to_rust: txpr, - rust_to_python: Mutex::new(rxrp), - notify_fd, + python: Arc::new(rxfd.recv().unwrap()), }) } #[getter] - fn _fd(&self) -> u64 { - self.notify_fd + fn _channel(&self) -> PyResult { + Ok(PythonChannel::new(self.python.clone())) } fn _request( @@ -677,7 +629,7 @@ impl Http { headers: Vec<(String, String)>, cache: bool, ) -> PyResult<()> { - self.send(PythonToRustMessage::Request( + self.python.send_err(PythonToRustMessage::Request( id, url, method, body, headers, cache, )) } @@ -690,52 +642,22 @@ impl Http { body: Vec, headers: Vec<(String, String)>, ) -> PyResult<()> { - self.send(PythonToRustMessage::RequestSse( + self.python.send_err(PythonToRustMessage::RequestSse( id, url, method, body, headers, )) } fn _close(&self, id: PythonConnId) -> PyResult<()> { - self.send(PythonToRustMessage::Close(id)) + self.python.send_err(PythonToRustMessage::Close(id)) } fn _ack_sse(&self, id: PythonConnId) -> PyResult<()> { - self.send(PythonToRustMessage::Ack(id)) + self.python.send_err(PythonToRustMessage::Ack(id)) } fn _update_limit(&self, limit: usize) -> PyResult<()> { - self.send(PythonToRustMessage::UpdateLimit(limit)) - } - - fn _read(&self, py: Python<'_>) -> PyResult { - let Ok(msg) = self - .rust_to_python - .try_lock() - .expect("Unsafe thread access") - .recv() - else { - return Ok(py.None()); - }; - msg.to_object(py) - } - - fn _try_read(&self, py: Python<'_>) -> PyResult { - let Ok(msg) = self - .rust_to_python - .try_lock() - .expect("Unsafe thread access") - .try_recv() - else { - return Ok(py.None()); - }; - msg.to_object(py) - } - - fn _close_pipe(&mut self) { - trace!("Closing pipe"); - // Replace the channel with a dummy, closed one which will also - // signal the other side to exit. - self.rust_to_python = Mutex::new(std::sync::mpsc::channel().1); + self.python + .send_err(PythonToRustMessage::UpdateLimit(limit)) } } diff --git a/rust/gel-stream/src/python/mod.rs b/rust/gel-stream/src/python/mod.rs new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/rust/gel-stream/src/python/mod.rs @@ -0,0 +1 @@ + diff --git a/rust/pgrust/src/python/async_connector.rs b/rust/pgrust/src/python/async_connector.rs new file mode 100644 index 00000000000..15c07d994dd --- /dev/null +++ b/rust/pgrust/src/python/async_connector.rs @@ -0,0 +1,7 @@ +use tokio::net::UnixSocket; + +struct AsyncConnector { + unix_socket: UnixSocket, +} + +impl AsyncConnector {} diff --git a/rust/pgrust/src/python.rs b/rust/pgrust/src/python/mod.rs similarity index 99% rename from rust/pgrust/src/python.rs rename to rust/pgrust/src/python/mod.rs index a31c515fbd4..7fcd55e79e1 100644 --- a/rust/pgrust/src/python.rs +++ b/rust/pgrust/src/python/mod.rs @@ -27,6 +27,8 @@ use std::collections::HashMap; use std::{borrow::Cow, path::Path}; use tracing::warn; +mod async_connector; + #[derive(Clone, Copy, PartialEq, Eq)] #[pyclass(eq, eq_int)] pub enum SSLMode { diff --git a/rust/pyo3_util/Cargo.toml b/rust/pyo3_util/Cargo.toml index 4f2e029a3f6..fb75983a496 100644 --- a/rust/pyo3_util/Cargo.toml +++ b/rust/pyo3_util/Cargo.toml @@ -12,6 +12,7 @@ workspace = true pyo3.workspace = true tracing.workspace = true tracing-subscriber.workspace = true +tokio.workspace = true scopeguard = "1" diff --git a/rust/pyo3_util/src/channel.rs b/rust/pyo3_util/src/channel.rs new file mode 100644 index 00000000000..b4748f4968f --- /dev/null +++ b/rust/pyo3_util/src/channel.rs @@ -0,0 +1,193 @@ +use std::{ + cell::RefCell, + future::poll_fn, + os::fd::IntoRawFd, + pin::Pin, + sync::{Arc, Mutex}, +}; + +use pyo3::{ + exceptions::PyException, prelude::*, BoundObject, FromPyObject, IntoPyObject, PyAny, PyResult, +}; +use tokio::io::AsyncWrite; +use tracing::{error, trace}; + +pyo3::create_exception!(_channel, InternalError, PyException); + +fn internal_error(message: &str) -> PyErr { + error!("{message}"); + InternalError::new_err(()) +} + +pub trait RustToPython: for<'py> IntoPyObject<'py> + Send + std::fmt::Debug {} +pub trait PythonToRust: for<'py> FromPyObject<'py> + Send + std::fmt::Debug {} + +impl RustToPython for T where T: for<'py> IntoPyObject<'py> + Send + std::fmt::Debug {} +impl PythonToRust for T where T: for<'py> FromPyObject<'py> + Send + std::fmt::Debug {} + +/// A channel that can be used to send and receive messages between Rust and Python. +pub struct RustChannel FromPyObject<'py>, TX: for<'py> IntoPyObject<'py> + Send> { + rust_to_python_notify: RefCell, + rust_to_python: std::sync::mpsc::Sender, + python_to_rust: RefCell>, +} + +impl RustChannel { + pub async fn recv(&self) -> Option { + let msg = self.python_to_rust.borrow_mut().recv().await; + msg + } + + pub async fn write(&self, msg: TX) -> Result<(), String> { + trace!("Rust -> Python: {msg:?}"); + self.rust_to_python.send(msg).map_err(|_| "Shutdown")?; + // If we're shutting down, this may fail (but that's OK) + poll_fn(|cx| { + let pipe = &mut *self.rust_to_python_notify.borrow_mut(); + let this = Pin::new(pipe); + this.poll_write(cx, &[0]) + }) + .await + .map_err(|_| "Shutdown")?; + Ok(()) + } +} + +pub struct PythonChannelImpl { + python_to_rust: tokio::sync::mpsc::UnboundedSender, + rust_to_python: Mutex>, + notify_fd: u64, +} + +impl PythonChannelImpl { + pub fn send(&self, msg: RX) -> Result<(), RX> { + self.python_to_rust.send(msg).map_err(|e| e.0) + } + + pub fn send_err(&self, msg: RX) -> PyResult<()> { + self.python_to_rust + .send(msg) + .map_err(|_| internal_error("In shutdown")) + } +} + +pub trait PythonChannelProtocol: Send + Sync { + fn _write<'py>(&self, py: Python<'py>, msg: Py) -> PyResult<()>; + fn _read<'py>(&self, py: Python<'py>) -> PyResult>; + fn _try_read<'py>(&self, py: Python<'py>) -> PyResult>; + fn _close_pipe(&mut self) -> (); + fn _fd(&self) -> u64; +} + +impl PythonChannelProtocol for Arc> { + fn _write<'py>(&self, py: Python<'py>, msg: Py) -> PyResult<()> { + let msg = msg.extract(py)?; + trace!("Python -> Rust: {msg:?}"); + self.python_to_rust + .send(msg) + .map_err(|_| internal_error("In shutdown")) + } + + fn _read<'py>(&self, py: Python<'py>) -> PyResult> { + let Ok(msg) = self + .rust_to_python + .try_lock() + .expect("Unsafe thread access") + .try_recv() + else { + return Ok(py.None().into_bound(py)); + }; + Ok(msg + .into_pyobject(py) + .map_err(|e| e.into())? + .into_bound() + .into_any()) + } + + fn _try_read<'py>(&self, py: Python<'py>) -> PyResult> { + let Ok(msg) = self + .rust_to_python + .try_lock() + .expect("Unsafe thread access") + .try_recv() + else { + return Ok(py.None().into_bound(py)); + }; + + Ok(msg + .into_pyobject(py) + .map_err(|e| e.into())? + .into_bound() + .into_any()) + } + + fn _close_pipe(&mut self) -> () { + *self + .rust_to_python + .try_lock() + .expect("Unsafe thread access") = std::sync::mpsc::channel().1; + } + + fn _fd(&self) -> u64 { + self.notify_fd + } +} + +#[pyclass] +pub struct PythonChannel { + _impl: Box, +} + +impl PythonChannel { + pub fn new(imp: T) -> Self { + Self { + _impl: Box::new(imp), + } + } +} + +#[pymethods] +impl PythonChannel { + fn _write<'py>(&self, py: Python<'py>, msg: Py) -> PyResult<()> { + self._impl._write(py, msg) + } + + fn _read<'py>(&self, py: Python<'py>) -> PyResult> { + self._impl._read(py) + } + + fn _try_read<'py>(&self, py: Python<'py>) -> PyResult> { + self._impl._try_read(py) + } + + fn _close_pipe(&mut self) { + // Replace the channel with a dummy, closed one which will also + // signal the other side to exit. + self._impl._close_pipe() + } + + #[getter] + fn _fd(&self) -> u64 { + self._impl._fd() + } +} + +/// Create a new Python <-> Rust channel from within a tokio runtime. +pub fn new_python_channel( +) -> (RustChannel, PythonChannelImpl) { + let (tx_sync, rx_sync) = std::sync::mpsc::channel(); + let (tx_async, rx_async) = tokio::sync::mpsc::unbounded_channel(); + let (tx_pipe, rx_pipe) = tokio::net::unix::pipe::pipe().unwrap(); + let notify_fd = rx_pipe.into_nonblocking_fd().unwrap().into_raw_fd() as u64; + let rust = RustChannel { + rust_to_python_notify: RefCell::new(tx_pipe), + rust_to_python: tx_sync, + python_to_rust: RefCell::new(rx_async), + }; + let python = PythonChannelImpl { + python_to_rust: tx_async, + rust_to_python: Mutex::new(rx_sync), + notify_fd, + }; + (rust, python) +} diff --git a/rust/pyo3_util/src/lib.rs b/rust/pyo3_util/src/lib.rs index 31348d2fdac..72b9a37e34c 100644 --- a/rust/pyo3_util/src/lib.rs +++ b/rust/pyo3_util/src/lib.rs @@ -1 +1,2 @@ +pub mod channel; pub mod logging;