Skip to content

Commit

Permalink
Migrate to the new modules
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Nov 22, 2024
1 parent f737ca0 commit d79e421
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 49 deletions.
24 changes: 12 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions edb/server/_rust_native/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[package]
name = "edbrust"
name = "rust_native"
version = "0.1.0"
license = "MIT/Apache-2.0"
authors = ["MagicStack Inc. <[email protected]>"]
Expand All @@ -12,7 +12,7 @@ workspace = true
python_extension = ["pyo3/extension-module", "pyo3/serde"]

[dependencies]
pyo3 = { workspace = true, optional = true }
pyo3 = { workspace = true }
tokio.workspace = true
pyo3_util.workspace = true
conn_pool = { workspace = true, features = [ "python_extension" ] }
Expand Down
49 changes: 26 additions & 23 deletions edb/server/_rust_native/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,34 @@
use pyo3::{
pyfunction, pymodule,
pymodule,
types::{PyAnyMethods, PyModule, PyModuleMethods},
wrap_pyfunction, Bound, IntoPy, Py, PyAny, PyResult, Python,
Bound, PyResult, Python,
};

#[pymodule(name = "module")]
fn _rust_native(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
let child_module = PyModule::new_bound(py, "edb.server._rust_native.module._conn_pool")?;
conn_pool::python::_conn_pool(py, &child_module)?;
m.add("_conn_pool", &child_module)?;
py.import_bound("sys")?
.getattr("modules")?
.set_item("edb.server._rust_native.module._conn_pool", child_module)?;
const MODULE_PREFIX: &str = "edb.server._rust_native";

fn add_child_module(
py: Python,
parent: &Bound<PyModule>,
name: &str,
init_fn: fn(Python, &Bound<PyModule>) -> PyResult<()>,
) -> PyResult<()> {
let full_name = format!("{}.{}", MODULE_PREFIX, name);
let child_module = PyModule::new(py, &full_name)?;
init_fn(py, &child_module)?;
parent.add(name, &child_module)?;

// Add the child module to the sys.modules dictionary so it can be imported
// by name.
let sys_modules = py.import("sys")?.getattr("modules")?;
sys_modules.set_item(full_name, child_module)?;
Ok(())
}

let child_module = PyModule::new_bound(py, "edb.server._rust_native.module._pg_rust")?;
pgrust::python::_pg_rust(py, &child_module)?;
m.add("_pg_rust", &child_module)?;
py.import_bound("sys")?
.getattr("modules")?
.set_item("edb.server._rust_native.module._pg_rust", child_module)?;
#[pymodule]
fn _rust_native(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
add_child_module(py, m, "_conn_pool", conn_pool::python::_conn_pool)?;
add_child_module(py, m, "_pg_rust", pgrust::python::_pg_rust)?;
add_child_module(py, m, "_http", http::python::_http)?;

let child_module = PyModule::new_bound(py, "edb.server._rust_native.module._http")?;
http::python::_http(py, &child_module)?;
m.add("_http", &child_module)?;
py.import_bound("sys")?
.getattr("modules")?
.set_item("edb.server._rust_native.module._http", child_module)?;

Ok(())
}
18 changes: 9 additions & 9 deletions edb/server/connpool/pool2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import edb.server._conn_pool
import edb.server._rust_native._conn_pool as _rust
import asyncio
import time
import typing
Expand All @@ -26,7 +26,7 @@
from .config import logger
from edb.server import rust_async_channel

guard = edb.server._conn_pool.LoggingGuard()
guard = _rust.LoggingGuard()

# Connections must be hashable because we use them to reverse-lookup
# an internal ID.
Expand Down Expand Up @@ -86,7 +86,7 @@ def __call__(self, stats: Snapshot) -> None:


class Pool(typing.Generic[C]):
_pool: edb.server._conn_pool.ConnPool
_pool: _rust.ConnPool
_next_conn_id: int
_failed_connects: int
_failed_disconnects: int
Expand Down Expand Up @@ -123,7 +123,7 @@ def __init__(
)
self._connect = connect
self._disconnect = disconnect
self._pool = edb.server._conn_pool.ConnPool(
self._pool = _rust.ConnPool(
max_capacity, min_idle_time_before_gc, config.STATS_COLLECT_INTERVAL
)
self._max_capacity = max_capacity
Expand Down Expand Up @@ -360,11 +360,11 @@ def _build_snapshot(self, *, now: float) -> Snapshot:
v = stats['value']
block_snapshot = BlockSnapshot(
dbname=dbname,
nconns=v[edb.server._conn_pool.METRIC_ACTIVE],
nwaiters_avg=v[edb.server._conn_pool.METRIC_WAITING],
npending=v[edb.server._conn_pool.METRIC_CONNECTING]
+ v[edb.server._conn_pool.METRIC_RECONNECTING],
nwaiters=v[edb.server._conn_pool.METRIC_WAITING],
nconns=v[_rust.METRIC_ACTIVE],
nwaiters_avg=v[_rust.METRIC_WAITING],
npending=v[_rust.METRIC_CONNECTING]
+ v[_rust.METRIC_RECONNECTING],
nwaiters=v[_rust.METRIC_WAITING],
quota=stats['target'],
)
blocks.append(block_snapshot)
Expand Down
2 changes: 1 addition & 1 deletion edb/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import time
from http import HTTPStatus as HTTPStatus

from edb.server._http import Http
from edb.server._rust_native._http import Http
from . import rust_async_channel

logger = logging.getLogger("edb.server")
Expand Down
2 changes: 1 addition & 1 deletion edb/server/pgcon/rust_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from enum import Enum, auto
from dataclasses import dataclass

from edb.server import _pg_rust as pgrust
from edb.server._rust_native import _pg_rust as pgrust
from edb.server.pgconnparams import (
ConnectionParams,
SSLMode,
Expand Down
2 changes: 1 addition & 1 deletion edb/server/pgconnparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import platform
import warnings

from edb.server._pg_rust import PyConnectionParams
from edb.server._rust_native._pg_rust import PyConnectionParams

_system = platform.uname().system
if _system == 'Windows':
Expand Down

0 comments on commit d79e421

Please sign in to comment.