diff --git a/.gitignore b/.gitignore index 0d20b64..6da8eed 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ *.pyc +*.aof diff --git a/pyredis/__main__.py b/pyredis/__main__.py index 664c20e..34fee5f 100644 --- a/pyredis/__main__.py +++ b/pyredis/__main__.py @@ -9,6 +9,7 @@ from pyredis.asyncserver import RedisServerProtocol from pyredis.trioserver import TrioServer from pyredis.datastore import DataStore +from pyredis.persistence import AppendOnlyPersister REDIS_DEFAULT_PORT = 6379 @@ -18,13 +19,13 @@ def check_expiry_task(datastore): while True: datastore.remove_expired_keys() - sleep(0.1) + sleep(1) async def acheck_expiry_task(datastore): while True: datastore.remove_expired_keys() - await asyncio.sleep(0.1) + await asyncio.sleep(1) async def amain(args): @@ -34,10 +35,12 @@ async def amain(args): loop = asyncio.get_running_loop() - monitor_task = loop.create_task(acheck_expiry_task(datastore)) + loop.create_task(acheck_expiry_task(datastore)) + + persister = AppendOnlyPersister("ccdb.aof") server = await loop.create_server( - lambda: RedisServerProtocol(datastore), "127.0.0.1", args.port + lambda: RedisServerProtocol(datastore, persister), "127.0.0.1", args.port ) async with server: @@ -53,6 +56,7 @@ def main(args): log.info(f"Starting PyRedis on port: {args.port}") datastore = DataStore() + expiration_monitor = threading.Thread(target=check_expiry_task, args=(datastore,)) expiration_monitor.start() diff --git a/pyredis/asyncserver.py b/pyredis/asyncserver.py index b3b132f..7de043b 100644 --- a/pyredis/asyncserver.py +++ b/pyredis/asyncserver.py @@ -5,9 +5,10 @@ class RedisServerProtocol(asyncio.Protocol): - def __init__(self, datastore): + def __init__(self, datastore, persister): self.buffer = bytearray() self._datastore = datastore + self._persister = persister def connection_made(self, transport): self.transport = transport @@ -22,5 +23,5 @@ def data_received(self, data): if frame: self.buffer = self.buffer[frame_size:] - result = handle_command(frame, self._datastore) + result = handle_command(frame, self._datastore, self._persister) self.transport.write(encode_message(result)) diff --git a/pyredis/commands.py b/pyredis/commands.py index 7aa0e2c..2ecae62 100644 --- a/pyredis/commands.py +++ b/pyredis/commands.py @@ -21,7 +21,7 @@ def _handle_ping(command): return Error(data="ERR wrong number of arguments for 'ping' command") -def _handle_set(command, datastore): +def _handle_set(command, datastore, persister): length = len(command) if length >= 3: key = command[1].data.decode() @@ -29,6 +29,8 @@ def _handle_set(command, datastore): if length == 3: datastore[key] = value + if persister: + persister.log_command(command) return SimpleString("OK") elif length == 5: expiry_mode = command[3].data.decode() @@ -39,9 +41,13 @@ def _handle_set(command, datastore): if expiry_mode == "ex": datastore.set_with_expiry(key, value, expiry * 1000) + if persister: + persister.log_command(command) return SimpleString("OK") elif expiry_mode == "px": datastore.set_with_expiry(key, value, expiry) + if persister: + persister.log_command(command) return SimpleString("OK") return Error("ERR syntax error") @@ -70,41 +76,47 @@ def _handle_exists(command, datastore): return Error("ERR wrong number of arguments for 'exists' command") -def _handle_del(command, datastore): +def _handle_del(command, datastore, persister): if len(command) >= 2: found = 0 for key in command[1:]: if key.data.decode() in datastore: del datastore._data[key.data.decode()] found += 1 + if persister: + persister.log_command(command) return Integer(found) else: return Error("ERR wrong number of arguments for 'del' command") -def _handle_incr(command, datastore): +def _handle_incr(command, datastore, persister): if len(command) == 2: key = command[1].data.decode() try: value = datastore.incr(key) + if persister: + persister.log_command(command) + return Integer(value) except TypeError: return Error("ERR value is not an integer or out of range") - return Integer(value) return Error("ERR wrong number of arguments for 'incr' command") -def _handle_decr(command, datastore): +def _handle_decr(command, datastore, persister): if len(command) == 2: key = command[1].data.decode() try: value = datastore.decr(key) + if persister: + persister.log_command(command) + return Integer(value) except TypeError: return Error("ERR value is not an integer or out of range") - return Integer(value) return Error("ERR wrong number of arguments for 'decr' command") -def _handle_lpush(command, datastore): +def _handle_lpush(command, datastore, persister): if len(command) >= 2: count = 0 key = command[1].data.decode() @@ -113,6 +125,8 @@ def _handle_lpush(command, datastore): for c in command[2:]: item = c.data.decode() count = datastore.prepend(key, item) + if persister: + persister.log_command(command) return Integer(count) except TypeError: return Error( @@ -138,7 +152,7 @@ def _handle_lrange(command, datastore): return Error("ERR wrong number of arguments for 'lrange' command") -def _handle_rpush(command, datastore): +def _handle_rpush(command, datastore, persister): if len(command) >= 2: count = 0 key = command[1].data.decode() @@ -147,6 +161,8 @@ def _handle_rpush(command, datastore): for c in command[2:]: item = c.data.decode() count = datastore.append(key, item) + if persister: + persister.log_command(command) return Integer(count) except TypeError: return Error( @@ -162,7 +178,7 @@ def _handle_unrecognised_command(command, *args): ) -def handle_command(command, datastore): +def handle_command(command, datastore, persister): match command[0].data.decode().upper(): case "ECHO": return _handle_echo(command) @@ -171,21 +187,21 @@ def handle_command(command, datastore): return _handle_ping(command) case "SET": - return _handle_set(command, datastore) + return _handle_set(command, datastore, persister) case "GET": return _handle_get(command, datastore) case "EXISTS": return _handle_exists(command, datastore) case "DEL": - return _handle_del(command, datastore) + return _handle_del(command, datastore, persister) case "INCR": - return _handle_incr(command, datastore) + return _handle_incr(command, datastore, persister) case "DECR": - return _handle_decr(command, datastore) + return _handle_decr(command, datastore, persister) case "LPUSH": - return _handle_lpush(command, datastore) + return _handle_lpush(command, datastore, persister) case "RPUSH": - return _handle_rpush(command, datastore) + return _handle_rpush(command, datastore, persister) case "LRANGE": return _handle_lrange(command, datastore) return _handle_unrecognised_command(command) diff --git a/pyredis/persistence.py b/pyredis/persistence.py new file mode 100644 index 0000000..2d1639e --- /dev/null +++ b/pyredis/persistence.py @@ -0,0 +1,10 @@ +class AppendOnlyPersister: + def __init__(self, filename): + self._filename = filename + self._file = open(filename, mode="ab", buffering=0) + + def log_command(self, command): + self._file.write(f"*{len(command)}\r\n".encode()) + + for item in command: + self._file.write(item.resp_encode()) diff --git a/pyredis/types.py b/pyredis/types.py index 84c90a1..9638313 100644 --- a/pyredis/types.py +++ b/pyredis/types.py @@ -35,7 +35,14 @@ def resp_encode(self): if self.data is None: return "$-1\r\n".encode() else: - return f"${len(self.data)}\r\n{self.data}\r\n".encode() + if isinstance(self.data, str): + return f"${len(self.data)}\r\n{self.data}\r\n".encode() + else: + return ( + f"${len(self.data)}\r\n".encode() + + bytes(self.data) + + "\r\n".encode() + ) @dataclass diff --git a/tests/test_commands.py b/tests/test_commands.py index 4c6d504..7be5693 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -2,6 +2,7 @@ from time import sleep, time_ns from pyredis.commands import handle_command +from pyredis.persistence import AppendOnlyPersister from pyredis.datastore import DataStore from pyredis.types import Array, BulkString, Error, Integer, SimpleString @@ -15,6 +16,12 @@ def datastore(): return datastore +@pytest.fixture(scope="module") +def persister(): + persister = AppendOnlyPersister("test.aof") + return persister + + @pytest.mark.parametrize( "command, expected", [ @@ -145,12 +152,12 @@ def datastore(): ), ], ) -def test_handle_command(command, expected, datastore): - result = handle_command(command, datastore) +def test_handle_command(command, expected, datastore, persister): + result = handle_command(command, datastore, persister) assert result == expected -def test_get_with_expiry(datastore): +def test_get_with_expiry(datastore, persister): key = "key" value = "value" px = 100 @@ -162,15 +169,15 @@ def test_get_with_expiry(datastore): BulkString(b"px"), BulkString(f"{px}".encode()), ] - result = handle_command(command, datastore) + result = handle_command(command, datastore, persister) assert result == SimpleString("OK") sleep((px + 100) / 1000) command = [BulkString(b"get"), SimpleString(b"key")] - result = handle_command(command, datastore) + result = handle_command(command, datastore, persister) assert result == BulkString(None) -def test_set_with_expiry(): +def test_set_with_expiry(persister): datastore = DataStore() key = "key" value = "value" @@ -183,7 +190,7 @@ def test_set_with_expiry(): command = base_command[:] command.extend([BulkString(b"ex"), BulkString(f"{ex}".encode())]) expected_expiry = time_ns() + (ex * 10**9) - result = handle_command(command, datastore) + result = handle_command(command, datastore, persister) assert result == SimpleString("OK") stored = datastore._data[key] assert stored.value == value @@ -194,7 +201,7 @@ def test_set_with_expiry(): command = base_command[:] command.extend([BulkString(b"px"), BulkString(f"{px}".encode())]) expected_expiry = time_ns() + (ex * 10**6) - result = handle_command(command, datastore) + result = handle_command(command, datastore, persister) assert result == SimpleString("OK") stored = datastore._data[key] assert stored.value == value @@ -202,7 +209,7 @@ def test_set_with_expiry(): assert diff < 10000 -def test_get_with_expiry(): +def test_get_with_expiry(persister): datastore = DataStore() key = "key" value = "value" @@ -215,11 +222,11 @@ def test_get_with_expiry(): BulkString(b"px"), BulkString(f"{px}".encode()), ] - result = handle_command(command, datastore) + result = handle_command(command, datastore, persister) assert result == SimpleString("OK") sleep((px + 100) / 1000) command = [BulkString(b"get"), SimpleString(b"key")] - result = handle_command(command, datastore) + result = handle_command(command, datastore, persister) assert result == BulkString(None) @@ -236,50 +243,52 @@ def test_expire_on_read(datastore): # Incr Tests -def test_handle_incr_command_valid_key(): +def test_handle_incr_command_valid_key(persister): datastore = DataStore() result = handle_command( - Array([BulkString(b"incr"), SimpleString(b"ki")]), datastore + Array([BulkString(b"incr"), SimpleString(b"ki")]), datastore, persister ) assert result == Integer(1) result = handle_command( - Array([BulkString(b"incr"), SimpleString(b"ki")]), datastore + Array([BulkString(b"incr"), SimpleString(b"ki")]), datastore, persister ) assert result == Integer(2) # Decr Tests -def test_handle_decr(): +def test_handle_decr(persister): datastore = DataStore() result = handle_command( - Array([BulkString(b"incr"), SimpleString(b"kd")]), datastore + Array([BulkString(b"incr"), SimpleString(b"kd")]), datastore, persister ) assert result == Integer(1) result = handle_command( - Array([BulkString(b"incr"), SimpleString(b"kd")]), datastore + Array([BulkString(b"incr"), SimpleString(b"kd")]), datastore, persister ) assert result == Integer(2) result = handle_command( - Array([BulkString(b"decr"), SimpleString(b"kd")]), datastore + Array([BulkString(b"decr"), SimpleString(b"kd")]), datastore, persister ) assert result == Integer(1) result = handle_command( - Array([BulkString(b"decr"), SimpleString(b"kd")]), datastore + Array([BulkString(b"decr"), SimpleString(b"kd")]), datastore, persister ) assert result == Integer(0) # Lpush Tests -def test_handle_lpush_lrange(): +def test_handle_lpush_lrange(persister): datastore = DataStore() result = handle_command( Array([BulkString(b"lpush"), SimpleString(b"klp"), SimpleString(b"second")]), datastore, + persister, ) assert result == Integer(1) result = handle_command( Array([BulkString(b"lpush"), SimpleString(b"klp"), SimpleString(b"first")]), datastore, + persister, ) assert result == Integer(2) result = handle_command( @@ -292,21 +301,24 @@ def test_handle_lpush_lrange(): ] ), datastore, + persister, ) assert result == Array(data=[BulkString("first"), BulkString("second")]) # Rpush Tests -def test_handle_rpush_lrange(): +def test_handle_rpush_lrange(persister): datastore = DataStore() result = handle_command( Array([BulkString(b"rpush"), SimpleString(b"krp"), SimpleString(b"first")]), datastore, + persister, ) assert result == Integer(1) result = handle_command( Array([BulkString(b"rpush"), SimpleString(b"krp"), SimpleString(b"second")]), datastore, + persister, ) assert result == Integer(2) result = handle_command( @@ -319,6 +331,7 @@ def test_handle_rpush_lrange(): ] ), datastore, + persister, ) assert result == Array(data=[BulkString("first"), BulkString("second")])