From e635682fdfbba74ffa37c657108412cd2ddbe953 Mon Sep 17 00:00:00 2001
From: janbjorge
Date: Tue, 20 Feb 2024 13:36:30 +0100
Subject: [PATCH] Deferred Listener Connection and README Update (#5)

---
 .github/workflows/ci.yml      |  9 ++++
 README.md                     | 40 ++++++++++++------
 src/pgcachewatch/listeners.py | 80 ++++++++++++++++++++---------------
 tests/test_decoraters.py      |  6 +--
 tests/test_fastapi.py         |  3 +-
 tests/test_integration.py     | 21 ++++-----
 tests/test_listeners.py       |  5 ++-
 tests/test_strategies.py      | 11 +++--
 tests/test_utils.py           | 17 +++----
 9 files changed, 115 insertions(+), 77 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 27f232a..89357cf 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -9,6 +9,7 @@ on:
 jobs:
   ci:
     strategy:
+      fail-fast: false
       matrix:
         python-version: ["3.10", "3.11", "3.12"]
         postgres-version: ["14", "15", "16"]
@@ -54,3 +55,11 @@ jobs:
 
       - name: Full test
         run: pytest -v
+
+  check:
+    name: Check test matrix passed.
+    needs: ci
+    runs-on: ubuntu-latest
+    steps:
+      - name: Check status
+        run: echo "All tests passed; ready to merge."
diff --git a/README.md b/README.md
index d43b103..4762b92 100644
--- a/README.md
+++ b/README.md
@@ -29,24 +29,40 @@ pgcachewatch install
 Example showing how to use PGCacheWatch for cache invalidation in a FastAPI app
 
 ```python
+import contextlib
+import typing
+
 import asyncpg
 from fastapi import FastAPI
 from pgcachewatch import decorators, listeners, models, strategies
 
-app = FastAPI()
+listener = listeners.PGEventQueue()
+
 
-async def setup_app(channel: models.PGChannel) -> FastAPI:
+@contextlib.asynccontextmanager
+async def app_setup_teardown(_: FastAPI) -> typing.AsyncGenerator[None, None]:
     conn = await asyncpg.connect()
-    listener = await listeners.PGEventQueue.create(channel, conn)
+    await listener.connect(conn, models.PGChannel("ch_pgcachewatch_table_change"))
+    yield
+    await conn.close()
+
 
-    @decorators.cache(strategy=strategies.Greedy(listener=listener))
-    async def cached_query():
-        # Simulate a database query
-        return {"data": "query result"}
+APP = FastAPI(lifespan=app_setup_teardown)
 
-    @app.get("/data")
-    async def get_data():
-        return await cached_query()
 
-    return app
-```
\ No newline at end of file
+
+# Only allow cache refresh after an update
+@decorators.cache(
+    strategy=strategies.Gready(
+        listener=listener,
+        predicate=lambda x: x.operation == "update",
+    )
+)
+async def cached_query() -> dict[str, str]:
+    # Simulate a database query
+    return {"data": "query result"}
+
+
+@APP.get("/data")
+async def get_data() -> dict:
+    return await cached_query()
+```
diff --git a/src/pgcachewatch/listeners.py b/src/pgcachewatch/listeners.py
index 791e08e..9155cbc 100644
--- a/src/pgcachewatch/listeners.py
+++ b/src/pgcachewatch/listeners.py
@@ -12,7 +12,7 @@ def _critical_termination_listener(*_: object, **__: object) -> None:
     # Must be defined in the global namespace, as asyncpg keeps
     # a set of functions to call. This will then happen only once, as
     # all instances will point to the same function.
-    logging.critical("Connection is closed / terminated!")
+    logging.critical("Connection is closed / terminated.")
 
 
 class PGEventQueue(asyncio.Queue[models.Event]):
@@ -23,48 +23,59 @@ class PGEventQueue(asyncio.Queue[models.Event]):
     def __init__(
         self,
-        pgchannel: models.PGChannel,
-        pgconn: asyncpg.Connection,
         max_size: int = 0,
         max_latency: datetime.timedelta = datetime.timedelta(milliseconds=500),
-        _called_by_create: bool = False,
     ) -> None:
-        """
-        Initializes the PGEventQueue instance. Use the create() classmethod to
-        instantiate.
-        """
-        if not _called_by_create:
-            raise RuntimeError(
-                "Use classmethod create(...) to instantiate PGEventQueue."
-            )
         super().__init__(maxsize=max_size)
 
-        self._pg_channel = pgchannel
-        self._pg_connection = pgconn
+        self._pg_channel: None | models.PGChannel = None
+        self._pg_connection: None | asyncpg.Connection = None
         self._max_latency = max_latency
 
-    @classmethod
-    async def create(
-        cls,
-        pgchannel: models.PGChannel,
-        pgconn: asyncpg.Connection,
-        maxsize: int = 0,
-        max_latency: datetime.timedelta = datetime.timedelta(milliseconds=500),
-    ) -> "PGEventQueue":
-        """
-        Creates and initializes a new PGEventQueue instance, connecting to the specified
-        PostgreSQL channel. Returns the initialized PGEventQueue instance.
+    async def connect(
+        self,
+        connection: asyncpg.Connection,
+        channel: models.PGChannel,
+    ) -> None:
         """
-        me = cls(
-            pgchannel=pgchannel,
-            pgconn=pgconn,
-            max_size=maxsize,
-            max_latency=max_latency,
-            _called_by_create=True,
+        Asynchronously connects the PGEventQueue to a specified
+        PostgreSQL channel and connection.
+
+        This method establishes a listener on a PostgreSQL channel
+        using the provided connection. It is designed to be called
+        once per PGEventQueue instance to ensure a one-to-one relationship
+        between the event queue and a database channel. If an attempt is
+        made to connect a PGEventQueue instance to more than one channel
+        or connection, a RuntimeError is raised to enforce this constraint.
+
+        Parameters:
+        - connection: asyncpg.Connection
+            The asyncpg connection object to be used for listening to database events.
+        - channel: models.PGChannel
+            The database channel to listen on for events.
+
+        Raises:
+        - RuntimeError: If the PGEventQueue is already connected to a
+          channel or connection.
+
+        Usage:
+        ```python
+        await pg_event_queue.connect(
+            connection=your_asyncpg_connection,
+            channel=your_pg_channel,
         )
-        me._pg_connection.add_termination_listener(_critical_termination_listener)
-        await me._pg_connection.add_listener(me._pg_channel, me.parse_and_put)  # type: ignore[arg-type]
+        ```
+        """
+        if self._pg_channel or self._pg_connection:
+            raise RuntimeError(
+                "PGEventQueue instance is already connected to a channel and/or "
+                "connection. Each PGEventQueue instance supports only one "
+                "channel and connection."
+            )
 
-        return me
+        self._pg_channel = channel
+        self._pg_connection = connection
+        self._pg_connection.add_termination_listener(_critical_termination_listener)
+        await self._pg_connection.add_listener(self._pg_channel, self.parse_and_put)  # type: ignore[arg-type]
 
     def parse_and_put(
         self,
@@ -87,6 +98,7 @@ def parse_and_put(
         except Exception:
             logging.exception("Unable to parse `%s`.", payload)
         else:
+            logging.info("Received event: %s on %s", parsed, channel)
             try:
                 self.put_nowait(parsed)
             except Exception:
diff --git a/tests/test_decoraters.py b/tests/test_decoraters.py
index 9e6f062..01a2ec6 100644
--- a/tests/test_decoraters.py
+++ b/tests/test_decoraters.py
@@ -10,10 +10,8 @@
 @pytest.mark.parametrize("N", (4, 16, 64, 512))
 async def test_gready_cache_decorator(N: int, pgconn: asyncpg.Connection) -> None:
     statistics = collections.Counter[str]()
-    listener = await listeners.PGEventQueue.create(
-        models.PGChannel("test_cache_decorator"),
-        pgconn=pgconn,
-    )
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, models.PGChannel("test_cache_decorator"))
 
     @decorators.cache(
         strategy=strategies.Gready(listener=listener),
diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py
index b2609c2..430c166 100644
--- a/tests/test_fastapi.py
+++ b/tests/test_fastapi.py
@@ -14,7 +14,8 @@ async def fastapitestapp(
 ) -> fastapi.FastAPI:
     app = fastapi.FastAPI()
 
-    listener = await listeners.PGEventQueue.create(channel, pgconn)
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, channel)
 
     @decorators.cache(strategy=strategies.Gready(listener=listener))
     async def slow_db_read() -> dict:
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 66a6cde..5d76109 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -37,10 +37,8 @@ async def test_2_caching(
     pgpool: asyncpg.Pool,
 ) -> None:
     statistics = collections.Counter[str]()
-    listener = await listeners.PGEventQueue.create(
-        models.PGChannel("test_2_caching"),
-        pgconn=pgconn,
-    )
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, models.PGChannel("test_2_caching"))
 
     cnt = 0
 
@@ -64,9 +62,10 @@ async def test_3_cache_invalidation_update(
     pgpool: asyncpg.Pool,
 ) -> None:
     statistics = collections.Counter[str]()
-    listener = await listeners.PGEventQueue.create(
+    listener = listeners.PGEventQueue()
+    await listener.connect(
+        pgconn,
         models.PGChannel("ch_pgcachewatch_table_change"),
-        pgconn=pgconn,
     )
 
     @decorators.cache(
@@ -97,9 +96,10 @@ async def test_3_cache_invalidation_insert(
     pgpool: asyncpg.Pool,
 ) -> None:
     statistics = collections.Counter[str]()
-    listener = await listeners.PGEventQueue.create(
+    listener = listeners.PGEventQueue()
+    await listener.connect(
+        pgconn,
         models.PGChannel("ch_pgcachewatch_table_change"),
-        pgconn=pgconn,
     )
 
     @decorators.cache(
@@ -131,9 +131,10 @@ async def test_3_cache_invalidation_delete(
     pgpool: asyncpg.Pool,
 ) -> None:
     statistics = collections.Counter[str]()
-    listener = await listeners.PGEventQueue.create(
+    listener = listeners.PGEventQueue()
+    await listener.connect(
+        pgconn,
         models.PGChannel("ch_pgcachewatch_table_change"),
-        pgconn=pgconn,
     )
 
     @decorators.cache(
diff --git a/tests/test_listeners.py b/tests/test_listeners.py
index a29b084..2c32006 100644
--- a/tests/test_listeners.py
+++ b/tests/test_listeners.py
@@ -16,7 +16,8 @@ async def test_eventqueue_and_pglistner(
     pgpool: asyncpg.Pool,
 ) -> None:
     channel = models.PGChannel(f"test_eventqueue_and_pglistner_{N}_{operation}")
-    eq = await listeners.PGEventQueue.create(channel, pgconn)
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, channel)
 
     for _ in range(N):
         await utils.emit_event(
@@ -32,7 +33,7 @@ async def test_eventqueue_and_pglistner(
     evnets = list[models.Event]()
     while True:
         try:
-            evnets.append(eq.get_nowait())
+            evnets.append(listener.get_nowait())
         except asyncio.QueueEmpty:
             break
 
diff --git a/tests/test_strategies.py b/tests/test_strategies.py
index b4e0309..fd5c39c 100644
--- a/tests/test_strategies.py
+++ b/tests/test_strategies.py
@@ -9,7 +9,10 @@
 @pytest.mark.parametrize("N", (4, 16, 64))
 async def test_gready_strategy(N: int, pgconn: asyncpg.Connection) -> None:
     channel = models.PGChannel("test_gready_strategy")
-    listener = await listeners.PGEventQueue.create(channel, pgconn)
+
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, channel)
+
     strategy = strategies.Gready(
         listener=listener,
         predicate=lambda e: e.operation == "insert",
@@ -47,7 +50,8 @@ async def test_windowed_strategy(
     pgconn: asyncpg.Connection,
 ) -> None:
     channel = models.PGChannel("test_windowed_strategy")
-    listener = await listeners.PGEventQueue.create(channel, pgconn)
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, channel)
     strategy = strategies.Windowed(
         listener=listener, window=["insert", "update", "delete"]
     )
@@ -111,7 +115,8 @@ async def test_timed_strategy(
     pgconn: asyncpg.Connection,
 ) -> None:
     channel = models.PGChannel("test_timed_strategy")
-    listener = await listeners.PGEventQueue.create(channel, pgconn)
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, channel)
     strategy = strategies.Timed(listener=listener, timedelta=dt)
 
     # Bursts spaced out according to the min dt required to trigger a refresh.
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 693666a..09064cc 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -17,9 +17,8 @@ async def test_emit_event(
     pgpool: asyncpg.Pool,
 ) -> None:
     channel = "test_emit_event"
-    listener = await listeners.PGEventQueue.create(
-        models.PGChannel(channel), pgconn=pgconn
-    )
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, models.PGChannel(channel))
     await asyncio.gather(
         *[
             utils.emit_event(
@@ -47,10 +46,8 @@ async def test_pick_until_deadline_max_iter(
     pgconn: asyncpg.Connection,
 ) -> None:
     channel = "test_pick_until_deadline_max_iter"
-    listener = await listeners.PGEventQueue.create(
-        models.PGChannel(channel),
-        pgconn=pgconn,
-    )
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, models.PGChannel(channel))
 
     items = list(range(max_iter * 2))
     for item in items:
@@ -87,10 +84,8 @@ async def test_pick_until_deadline_max_time(
     pgconn: asyncpg.Connection,
 ) -> None:
     channel = "test_pick_until_deadline_max_time"
-    listener = await listeners.PGEventQueue.create(
-        models.PGChannel(channel),
-        pgconn=pgconn,
-    )
+    listener = listeners.PGEventQueue()
+    await listener.connect(pgconn, models.PGChannel(channel))
 
     x = -1
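
Reviewer note: for trying the new deferred-connection API outside FastAPI, here is a minimal standalone sketch. The DSN is a placeholder, and it assumes the pgcachewatch triggers have been installed (`pgcachewatch install`) so table changes are published on `ch_pgcachewatch_table_change`.

```python
import asyncio

import asyncpg

from pgcachewatch import listeners, models


async def main() -> None:
    # Placeholder DSN; point this at a database where
    # `pgcachewatch install` has created the NOTIFY triggers.
    conn = await asyncpg.connect("postgresql://localhost/postgres")

    # New deferred style introduced by this patch: construct the queue
    # first, attach it to a connection and channel later (e.g. inside
    # a FastAPI lifespan handler).
    listener = listeners.PGEventQueue()
    await listener.connect(conn, models.PGChannel("ch_pgcachewatch_table_change"))

    # PGEventQueue subclasses asyncio.Queue, so the usual queue methods
    # apply: block until the first table-change event arrives, then print it.
    print(await listener.get())

    await conn.close()


if __name__ == "__main__":
    asyncio.run(main())
```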