Skip to content

Commit

Permalink
Add expiry for Step 5 to async and multithreading server
Browse files Browse the repository at this point in the history
  • Loading branch information
ngokchaoho committed Dec 17, 2023
1 parent 33d427f commit 66bb439
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 20 deletions.
40 changes: 33 additions & 7 deletions pyredis/__main__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,42 @@
import argparse
import asyncio
import trio
import logging
import threading
from time import sleep

from pyredis.server import Server
from pyredis.asyncserver import RedisServerProtocol
import logging
from pyredis.datastore import DataStore


REDIS_DEFAULT_PORT = 6379
log = logging.getLogger("pyredis")


def check_expiry_task(datastore):
while True:
datastore.remove_expired_keys()
sleep(0.1)


async def acheck_expiry_task(datastore):
while True:
datastore.remove_expired_keys()
await asyncio.sleep(0.1)


async def amain(args):
log.info(f"Starting Pyredis on port: {args.port}")

datastore = DataStore()

loop = asyncio.get_running_loop()

monitor_task = loop.create_task(acheck_expiry_task(datastore))

server = await loop.create_server(
lambda: RedisServerProtocol(), "127.0.0.1", args.port
lambda: RedisServerProtocol(datastore), "127.0.0.1", args.port
)

async with server:
Expand All @@ -27,9 +49,13 @@ async def tmain(args):


def main(args):
print(f"Starting PyRedis on port: {args.port}")
log.info(f"Starting PyRedis on port: {args.port}")

server = Server(args.port)
datastore = DataStore()
expiration_monitor = threading.Thread(target=check_expiry_task, args=(datastore,))
expiration_monitor.start()

server = Server(args.port, datastore)
server.run()


Expand Down Expand Up @@ -57,11 +83,11 @@ def main(args):
logging.basicConfig(level=args.loglevel)

if args.asyncio:
logging.info("Using AsyncIO RedisServerProtocol")
log.info("Using AsyncIO RedisServerProtocol")
asyncio.run(amain(args))
elif args.trio:
logging.info("Using Trio Stream API")
log.info("Using Trio Stream API")
trio.run(tmain, args)
else:
logging.info("Using threading module for multi-threading")
log.info("Using threading module for multi-threading")
main(args)
4 changes: 2 additions & 2 deletions pyredis/asyncserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@


class RedisServerProtocol(asyncio.Protocol):
def __init__(self):
def __init__(self, datastore):
self.buffer = bytearray()
self._datastore = DataStore()
self._datastore = datastore

def connection_made(self, transport):
self.transport = transport
Expand Down
1 change: 0 additions & 1 deletion pyredis/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def main(args):
while True:
data = client_socket.recv(RECV_SIZE)
buffer.extend(data)

frame, frame_size = extract_frame_from_buffer(buffer)

if frame:
Expand Down
40 changes: 34 additions & 6 deletions pyredis/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
from typing import Any
from time import time

import random
import logging


EXPIRY_TEST_SAMPLE_SIZE = 20
log = logging.getLogger("pyredis")


@dataclass
class DataEntry:
Expand Down Expand Up @@ -30,12 +37,12 @@ def __init__(self, initial_data=None):

def __getitem__(self, key):
with self._lock:
log.info("Try to get key %s", key)
item = self._data[key]

log.info("key exist %s, checking expiry", key)
# if key expired
if item.expiry and item.expiry < int(time() * 1000):
del self._data[key]
raise KeyError
if self.check_expiry(key, item):
raise KeyError # catched in _handle_get

return item.value

Expand All @@ -48,5 +55,26 @@ def set_with_expiry(self, key, value, expiry: int):
calculated_expiry = int(time() * 1000) + expiry # in miliseconds
self._data[key] = DataEntry(value, calculated_expiry)

def check_expiry(datastore):
pass
def check_expiry(self, key: str, value: DataEntry) -> bool:
# if key expired then delete
if value.expiry and value.expiry < int(time() * 1000):
log.info("%s key expired", key)
del self._data[key]
return True
else:
return False

def remove_expired_keys(self):
expired_count = 0
with self._lock:
keys = random.sample(
sorted(self._data), min(EXPIRY_TEST_SAMPLE_SIZE, len(self._data))
)

for key in keys:
if self.check_expiry(key, self._data[key]):
expired_count += 1
# after release lock
# if more than
if expired_count > EXPIRY_TEST_SAMPLE_SIZE * 0.25:
self.remove_expired_keys()
2 changes: 1 addition & 1 deletion pyredis/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def extract_frame_from_buffer(buffer):
data_size = int(paylood)
# NULL bulk String
if data_size == -1:
return None, 5
return BulkString(None), 5
content_saparator = buffer.find(_MSG_SEPARATOR, separator + 1)
if (
data_size >= 0
Expand Down
4 changes: 2 additions & 2 deletions pyredis/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@


class Server:
def __init__(self, port) -> None:
def __init__(self, port, datastore) -> None:
self.port = port
self._running = False
self._datastore = DataStore()
self._datastore = datastore

def run(self):
self._running = True
Expand Down
2 changes: 1 addition & 1 deletion tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
(b"$12\r\nHello, World\r\n", (BulkString(b"Hello, World"), 19)),
(b"$12\r\nHello\r\nWorld\r\n", (BulkString(b"Hello\r\nWorld"), 19)),
(b"$0\r\n\r\n", (BulkString(b""), 6)),
(b"$-1\r\n", (None, 5)),
(b"$-1\r\n", (BulkString(None), 5)),
# Test case for Arrays
(b"*0", (None, 0)),
(b"*0\r\n", (Array([]), 4)),
Expand Down

0 comments on commit 66bb439

Please sign in to comment.