Skip to content

Commit

Permalink
Add trio for step 4
Browse files Browse the repository at this point in the history
  • Loading branch information
ngokchaoho committed Dec 17, 2023
1 parent 2293ebf commit a02ca33
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
10 changes: 10 additions & 0 deletions pyredis/__main__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import asyncio
import trio

from pyredis.server import Server
from pyredis.asyncserver import RedisServerProtocol
Expand All @@ -20,6 +21,11 @@ async def amain(args):
await server.serve_forever()


async def tmain(args):
server = Server(args.port)
await server.run()


def main(args):
print(f"Starting PyRedis on port: {args.port}")

Expand All @@ -38,6 +44,7 @@ def main(args):
default=REDIS_DEFAULT_PORT,
)
parser.add_argument("--asyncio", action=argparse.BooleanOptionalAction)
parser.add_argument("--trio", action=argparse.BooleanOptionalAction)
parser.add_argument(
"-v",
"--verbose",
Expand All @@ -52,6 +59,9 @@ def main(args):
if args.asyncio:
logging.info("Using AsyncIO RedisServerProtocol")
asyncio.run(amain(args))
elif args.trio:
logging.info("Using Trio Stream API")
trio.run(tmain, args)
else:
logging.info("Using threading module for multi-threading")
main(args)
53 changes: 53 additions & 0 deletions pyredis/trioserver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from trio import SocketListener, serve_tcp, SocketStream
import logging
import trio

from pyredis.protocol import extract_frame_from_buffer, encode_message
from pyredis.commands import handle_command
from pyredis.datastore import DataStore

RECV_SIZE = 2048
log = logging.getLogger("pyredis")


class Server:
def __init__(self, port) -> None:
self.port = port
self._running = False
self._datastore = DataStore()

async def run(self):
self._running = True

async with trio.open_nursery() as nursery:
await nursery.start(
serve_tcp,
self.handle_client_connection,
port=self.port,
host="127.0.0.1",
)

async def handle_client_connection(self, client_stream: SocketStream):
buffer = bytearray()
try:
while True:
data = await client_stream.receive_some(RECV_SIZE)
log.info("Received data from client")
if not data:
log.info("Readched EOF")
break
buffer.extend(data)
frame, frame_size = extract_frame_from_buffer(buffer)
log.info("Extracted one frame from received data")
if frame:
buffer = buffer[frame_size:]
log.info("Processing one frame")
result = handle_command(frame, self._datastore)
await client_stream.send_all(encode_message(result))

finally:
log.info("Attempt to close stream")
await client_stream.aclose()

def stop(self):
self._running = False

0 comments on commit a02ca33

Please sign in to comment.