diff --git a/pyredis/__main__.py b/pyredis/__main__.py index 1e1ec60..a2dcf5a 100644 --- a/pyredis/__main__.py +++ b/pyredis/__main__.py @@ -1,5 +1,6 @@ import argparse import asyncio +import trio from pyredis.server import Server from pyredis.asyncserver import RedisServerProtocol @@ -20,6 +21,11 @@ async def amain(args): await server.serve_forever() +async def tmain(args): + server = Server(args.port) + await server.run() + + def main(args): print(f"Starting PyRedis on port: {args.port}") @@ -38,6 +44,7 @@ def main(args): default=REDIS_DEFAULT_PORT, ) parser.add_argument("--asyncio", action=argparse.BooleanOptionalAction) + parser.add_argument("--trio", action=argparse.BooleanOptionalAction) parser.add_argument( "-v", "--verbose", @@ -52,6 +59,9 @@ def main(args): if args.asyncio: logging.info("Using AsyncIO RedisServerProtocol") asyncio.run(amain(args)) + elif args.trio: + logging.info("Using Trio Stream API") + trio.run(tmain, args) else: logging.info("Using threading module for multi-threading") main(args) diff --git a/pyredis/trioserver.py b/pyredis/trioserver.py new file mode 100644 index 0000000..dfdfc1a --- /dev/null +++ b/pyredis/trioserver.py @@ -0,0 +1,53 @@ +from trio import SocketListener, serve_tcp, SocketStream +import logging +import trio + +from pyredis.protocol import extract_frame_from_buffer, encode_message +from pyredis.commands import handle_command +from pyredis.datastore import DataStore + +RECV_SIZE = 2048 +log = logging.getLogger("pyredis") + + +class Server: + def __init__(self, port) -> None: + self.port = port + self._running = False + self._datastore = DataStore() + + async def run(self): + self._running = True + + async with trio.open_nursery() as nursery: + await nursery.start( + serve_tcp, + self.handle_client_connection, + port=self.port, + host="127.0.0.1", + ) + + async def handle_client_connection(self, client_stream: SocketStream): + buffer = bytearray() + try: + while True: + data = await client_stream.receive_some(RECV_SIZE) + log.info("Received data from client") + if not data: + log.info("Readched EOF") + break + buffer.extend(data) + frame, frame_size = extract_frame_from_buffer(buffer) + log.info("Extracted one frame from received data") + if frame: + buffer = buffer[frame_size:] + log.info("Processing one frame") + result = handle_command(frame, self._datastore) + await client_stream.send_all(encode_message(result)) + + finally: + log.info("Attempt to close stream") + await client_stream.aclose() + + def stop(self): + self._running = False