Skip to content

Commit

Permalink
Rework the async communication with Rust
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Nov 18, 2024
1 parent 8ddfd73 commit fa6bd8f
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 48 deletions.
6 changes: 6 additions & 0 deletions edb/server/conn_pool/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,12 @@ impl ConnPool {
};
msg.to_object(py)
}

fn _close_pipe(&mut self) {
// Replace the channel with a dummy, closed one which will also
// signal the other side to exit.
(_, self.rust_to_python) = std::sync::mpsc::channel();
}
}

/// Ensure that logging does not outlive the Python runtime.
Expand Down
35 changes: 9 additions & 26 deletions edb/server/connpool/pool2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from . import config
from .config import logger
from edb.server import rust_async_channel

guard = edb.server._conn_pool.LoggingGuard()

Expand Down Expand Up @@ -101,7 +102,6 @@ class Pool(typing.Generic[C]):
_errors: dict[int, BaseException]
_conns_held: dict[C, int]
_loop: asyncio.AbstractEventLoop
_skip_reads: int
_counts: typing.Any
_stats_collector: typing.Optional[StatsCollector]

Expand Down Expand Up @@ -130,10 +130,12 @@ def __init__(self, *, connect: Connector[C],
self._errors = {}
self._conns_held = {}
self._prunes = {}
self._skip_reads = 0
self._channel = None

self._loop = asyncio.get_running_loop()
self._task = self._loop.create_task(self._boot(self._loop))
self._channel = rust_async_channel.RustAsyncChannel(self._pool, self._process_message)

self._task = self._loop.create_task(self._boot())

self._failed_connects = 0
self._failed_disconnects = 0
Expand Down Expand Up @@ -170,34 +172,15 @@ async def close(self) -> None:
self._pool = None
logger.info("Closed connection pool")

async def _boot(self, loop: asyncio.AbstractEventLoop) -> None:
async def _boot(self) -> None:
logger.info("Python-side connection pool booted")
reader = asyncio.StreamReader(loop=loop)
reader_protocol = asyncio.StreamReaderProtocol(reader)
fd = os.fdopen(self._pool._fd, 'rb')
transport, _ = await loop.connect_read_pipe(lambda: reader_protocol, fd)
try:
while len(await reader.read(1)) == 1:
if not self._pool or not self._task:
break
if self._skip_reads > 0:
self._skip_reads -= 1
continue
msg = self._pool._read()
if not msg:
break
self._process_message(msg)

await self._channel.run()
finally:
transport.close()
self._channel.close()

# Allow readers to skip the self-pipe for performing reads which may reduce
# latency a small degree. We'll still need to eventually pick up a self-pipe
# read but we increment a counter to skip at that point.
def _try_read(self) -> None:
while msg := self._pool._try_read():
self._skip_reads += 1
self._process_message(msg)
self._channel.read_hint()

def _process_message(self, msg: typing.Any) -> None:
# If we're closing, don't dispatch any operations
Expand Down
24 changes: 14 additions & 10 deletions edb/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def __init__(
self._stat_callback = stat_callback

def __del__(self) -> None:
self.close()
if self._task is not None:
logger.error(f"HttpClient {id(self)} was not closed")

def __enter__(self) -> HttpClient:
return self
Expand All @@ -102,9 +103,7 @@ def _ensure_task(self):
raise Exception("HttpClient was closed")
if self._task is None:
self._client = Http(self._limit)
self._task = self._loop.create_task(
self._boot(self._loop, self._client._fd)
)
self._task = self._loop.create_task(self._boot(self._loop))

def _ensure_client(self):
if self._client is None:
Expand All @@ -131,7 +130,6 @@ def _process_headers(self, headers: HeaderType) -> list[tuple[str, str]]:
return [(k, v) for k, v in headers.items()]
if isinstance(headers, list):
return headers
print(headers)
raise ValueError(f"Invalid headers type: {type(headers)}")

def _process_content(
Expand Down Expand Up @@ -297,13 +295,19 @@ async def stream_sse(
finally:
del self._requests[id]

async def _boot(self, loop: asyncio.AbstractEventLoop, fd: int) -> None:
async def _boot(self, loop: asyncio.AbstractEventLoop) -> None:
logger.info(f"HTTP client initialized, user_agent={self._user_agent}")
channel = rust_async_channel.RustAsyncChannel(fd, self._client._pipe, self._process_message)
try:
await channel.run()
finally:
channel.close()
channel = rust_async_channel.RustAsyncChannel(
self._client, self._process_message
)
try:
await channel.run()
finally:
channel.close()
except Exception as e:
logger.error(f"Error in HTTP client: {e}", exc_info=True)
raise

def _process_message(self, msg):
try:
Expand Down
6 changes: 6 additions & 0 deletions edb/server/http/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,12 @@ impl Http {
};
msg.to_object(py)
}

fn _close_pipe(&mut self) {
// Replace the channel with a dummy, closed one which will also
// signal the other side to exit.
(_, self.rust_to_python) = std::sync::mpsc::channel();
}
}

#[pymodule]
Expand Down
37 changes: 26 additions & 11 deletions edb/server/rust_async_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,35 +30,50 @@


class RustPipeProtocol(Protocol):
def _read(self) -> Tuple[Any, ...]:
...
def _read(self) -> Tuple[Any, ...]: ...

def _try_read(self) -> Optional[Tuple[Any, ...]]:
...
def _try_read(self) -> Optional[Tuple[Any, ...]]: ...

def _close_pipe(self) -> None: ...

_fd: int


class RustAsyncChannel[T: RustPipeProtocol]:
_buffered_reader: io.BufferedReader
_buffer: bytes
_buffer: bytes
_skip_reads: int
_closed: asyncio.Event

def __init__(self, fd: int, pipe: T, callback: Callable[[], Tuple[Any, ...]]) -> None:
self._buffered_reader = io.BufferedReader(fd)
def __init__(
self, pipe: T, callback: Callable[[], Tuple[Any, ...]]
) -> None:
fd = pipe._fd
self._buffered_reader = io.BufferedReader(io.FileIO(fd))
self._fd = fd
self._pipe = pipe
self._callback = callback
self._buffer = bytes(MAX_BATCH_SIZE)
self._buffer = bytearray(MAX_BATCH_SIZE)
self._skip_reads = 0
self._closed = asyncio.Event()

def __del__(self):
if not self._closed.is_set():
logger.error(f"RustAsyncChannel {id(self)} was not closed")

async def run(self):
asyncio.add_reader(self._fd, self._channel_read)
await self._closed.wait()
loop = asyncio.get_running_loop()
loop.add_reader(self._fd, self._channel_read)
try:
await self._closed.wait()
finally:
loop.remove_reader(self._fd)

def close(self):
if not self._closed.is_set():
self._pipe._close_pipe()
self._buffered_reader.close()
self._closed.set()
asyncio.remove_reader(self._fd)

def read_hint(self):
while msg := self._pipe._try_read():
Expand Down
1 change: 0 additions & 1 deletion tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ async def client_task():

assert is_closed

@unittest.skip("Hangs on CI")
async def test_sse_with_mock_server_close(self):
"""Try to close the server-side stream and see if the client detects
an end for the iterator. Note that this is technically not correct SSE:
Expand Down

0 comments on commit fa6bd8f

Please sign in to comment.