diff --git a/Cargo.lock b/Cargo.lock index 37d53be81157..289756992974 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -543,18 +543,6 @@ dependencies = [ "thiserror", ] -[[package]] -name = "edbrust" -version = "0.1.0" -dependencies = [ - "conn_pool", - "http 0.1.0", - "pgrust", - "pyo3", - "pyo3_util", - "tokio", -] - [[package]] name = "edgedb-errors" version = "0.4.2" @@ -2010,6 +1998,18 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "rust_native" +version = "0.1.0" +dependencies = [ + "conn_pool", + "http 0.1.0", + "pgrust", + "pyo3", + "pyo3_util", + "tokio", +] + [[package]] name = "rustc-demangle" version = "0.1.24" diff --git a/edb/server/_rust_native/Cargo.toml b/edb/server/_rust_native/Cargo.toml index e4bd8d0e2185..08b6927a7974 100644 --- a/edb/server/_rust_native/Cargo.toml +++ b/edb/server/_rust_native/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "edbrust" +name = "rust_native" version = "0.1.0" license = "MIT/Apache-2.0" authors = ["MagicStack Inc. "] @@ -12,7 +12,7 @@ workspace = true python_extension = ["pyo3/extension-module", "pyo3/serde"] [dependencies] -pyo3 = { workspace = true, optional = true } +pyo3 = { workspace = true } tokio.workspace = true pyo3_util.workspace = true conn_pool = { workspace = true, features = [ "python_extension" ] } diff --git a/edb/server/_rust_native/src/lib.rs b/edb/server/_rust_native/src/lib.rs index bae2b70cdfab..a34aae794628 100644 --- a/edb/server/_rust_native/src/lib.rs +++ b/edb/server/_rust_native/src/lib.rs @@ -1,31 +1,34 @@ use pyo3::{ - pyfunction, pymodule, + pymodule, types::{PyAnyMethods, PyModule, PyModuleMethods}, - wrap_pyfunction, Bound, IntoPy, Py, PyAny, PyResult, Python, + Bound, PyResult, Python, }; -#[pymodule(name = "module")] -fn _rust_native(py: Python, m: &Bound) -> PyResult<()> { - let child_module = PyModule::new_bound(py, "edb.server._rust_native.module._conn_pool")?; - conn_pool::python::_conn_pool(py, &child_module)?; - m.add("_conn_pool", &child_module)?; - py.import_bound("sys")? - .getattr("modules")? - .set_item("edb.server._rust_native.module._conn_pool", child_module)?; +const MODULE_PREFIX: &str = "edb.server._rust_native"; + +fn add_child_module( + py: Python, + parent: &Bound, + name: &str, + init_fn: fn(Python, &Bound) -> PyResult<()>, +) -> PyResult<()> { + let full_name = format!("{}.{}", MODULE_PREFIX, name); + let child_module = PyModule::new(py, &full_name)?; + init_fn(py, &child_module)?; + parent.add(name, &child_module)?; + + // Add the child module to the sys.modules dictionary so it can be imported + // by name. + let sys_modules = py.import("sys")?.getattr("modules")?; + sys_modules.set_item(full_name, child_module)?; + Ok(()) +} - let child_module = PyModule::new_bound(py, "edb.server._rust_native.module._pg_rust")?; - pgrust::python::_pg_rust(py, &child_module)?; - m.add("_pg_rust", &child_module)?; - py.import_bound("sys")? - .getattr("modules")? - .set_item("edb.server._rust_native.module._pg_rust", child_module)?; +#[pymodule] +fn _rust_native(py: Python, m: &Bound) -> PyResult<()> { + add_child_module(py, m, "_conn_pool", conn_pool::python::_conn_pool)?; + add_child_module(py, m, "_pg_rust", pgrust::python::_pg_rust)?; + add_child_module(py, m, "_http", http::python::_http)?; - let child_module = PyModule::new_bound(py, "edb.server._rust_native.module._http")?; - http::python::_http(py, &child_module)?; - m.add("_http", &child_module)?; - py.import_bound("sys")? - .getattr("modules")? - .set_item("edb.server._rust_native.module._http", child_module)?; - Ok(()) } diff --git a/edb/server/connpool/pool2.py b/edb/server/connpool/pool2.py index f5b12bdfad50..67a9997478af 100644 --- a/edb/server/connpool/pool2.py +++ b/edb/server/connpool/pool2.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import edb.server._conn_pool +import edb.server._rust_native._conn_pool as _rust import asyncio import time import typing @@ -26,7 +26,7 @@ from .config import logger from edb.server import rust_async_channel -guard = edb.server._conn_pool.LoggingGuard() +guard = _rust.LoggingGuard() # Connections must be hashable because we use them to reverse-lookup # an internal ID. @@ -86,7 +86,7 @@ def __call__(self, stats: Snapshot) -> None: class Pool(typing.Generic[C]): - _pool: edb.server._conn_pool.ConnPool + _pool: _rust.ConnPool _next_conn_id: int _failed_connects: int _failed_disconnects: int @@ -123,7 +123,7 @@ def __init__( ) self._connect = connect self._disconnect = disconnect - self._pool = edb.server._conn_pool.ConnPool( + self._pool = _rust.ConnPool( max_capacity, min_idle_time_before_gc, config.STATS_COLLECT_INTERVAL ) self._max_capacity = max_capacity @@ -360,11 +360,11 @@ def _build_snapshot(self, *, now: float) -> Snapshot: v = stats['value'] block_snapshot = BlockSnapshot( dbname=dbname, - nconns=v[edb.server._conn_pool.METRIC_ACTIVE], - nwaiters_avg=v[edb.server._conn_pool.METRIC_WAITING], - npending=v[edb.server._conn_pool.METRIC_CONNECTING] - + v[edb.server._conn_pool.METRIC_RECONNECTING], - nwaiters=v[edb.server._conn_pool.METRIC_WAITING], + nconns=v[_rust.METRIC_ACTIVE], + nwaiters_avg=v[_rust.METRIC_WAITING], + npending=v[_rust.METRIC_CONNECTING] + + v[_rust.METRIC_RECONNECTING], + nwaiters=v[_rust.METRIC_WAITING], quota=stats['target'], ) blocks.append(block_snapshot) diff --git a/edb/server/http.py b/edb/server/http.py index ae37b148af21..ed6ef329dba6 100644 --- a/edb/server/http.py +++ b/edb/server/http.py @@ -37,7 +37,7 @@ import time from http import HTTPStatus as HTTPStatus -from edb.server._http import Http +from edb.server._rust_native._http import Http from . import rust_async_channel logger = logging.getLogger("edb.server") diff --git a/edb/server/pgcon/rust_transport.py b/edb/server/pgcon/rust_transport.py index f3ae6b81eb43..cc37d724b103 100644 --- a/edb/server/pgcon/rust_transport.py +++ b/edb/server/pgcon/rust_transport.py @@ -37,7 +37,7 @@ from enum import Enum, auto from dataclasses import dataclass -from edb.server import _pg_rust as pgrust +from edb.server._rust_native import _pg_rust as pgrust from edb.server.pgconnparams import ( ConnectionParams, SSLMode, diff --git a/edb/server/pgconnparams.py b/edb/server/pgconnparams.py index 4bd1214ecfd5..cd5b7af7ddde 100644 --- a/edb/server/pgconnparams.py +++ b/edb/server/pgconnparams.py @@ -22,7 +22,7 @@ import platform import warnings -from edb.server._pg_rust import PyConnectionParams +from edb.server._rust_native._pg_rust import PyConnectionParams _system = platform.uname().system if _system == 'Windows':