diff --git a/.gitignore b/.gitignore index c259ee7..2fcda60 100644 --- a/.gitignore +++ b/.gitignore @@ -2,9 +2,12 @@ __pycache__/ *.py[oc] build/ +test_project/ +test_project/* dist/ wheels/ *.egg-info +.DS_Store # venv .venv diff --git a/pyproject.toml b/pyproject.toml index b15afa8..a052682 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "loguru>=0.7.2", "psycopg[binary]>=3.1.10", "redis>=5.0.0", + "httpx>=0.25.0", ] readme = "README.md" requires-python = ">= 3.8" diff --git a/requirements-dev.lock b/requirements-dev.lock index 84af0dc..b3a69e8 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -10,11 +10,14 @@ anyio==4.0.0 asgiref==3.7.2 black==23.9.1 +certifi==2023.7.22 click==8.1.7 dj-database-url==2.1.0 django==4.2.5 h11==0.14.0 +httpcore==0.18.0 httptools==0.6.0 +httpx==0.25.0 idna==3.4 loguru==0.7.2 mypy-extensions==1.0.0 diff --git a/requirements.lock b/requirements.lock index f47e45f..83bff58 100644 --- a/requirements.lock +++ b/requirements.lock @@ -9,11 +9,14 @@ -e file:. anyio==4.0.0 asgiref==3.7.2 +certifi==2023.7.22 click==8.1.7 dj-database-url==2.1.0 django==4.2.5 h11==0.14.0 +httpcore==0.18.0 httptools==0.6.0 +httpx==0.25.0 idna==3.4 loguru==0.7.2 psycopg==3.1.10 diff --git a/src/sse_relay_server/config.py b/src/sse_relay_server/config.py index d544bf5..10ea6c4 100644 --- a/src/sse_relay_server/config.py +++ b/src/sse_relay_server/config.py @@ -26,3 +26,6 @@ def get_postgres_url() -> str | None: def get_redis_url() -> str | None: return os.getenv("REDIS_URL") + +def get_last_messages_endpoint_url() : + return os.getenv("LAST_MESSAGES_ENDPOINT_URL") \ No newline at end of file diff --git a/src/sse_relay_server/gateways/_postgres.py b/src/sse_relay_server/gateways/_postgres.py index 89b1152..6e6ddb2 100644 --- a/src/sse_relay_server/gateways/_postgres.py +++ b/src/sse_relay_server/gateways/_postgres.py @@ -1,12 +1,14 @@ import json +import httpx from typing import AsyncGenerator import dj_database_url import psycopg from loguru import logger from sse_starlette import ServerSentEvent +from contextlib import suppress -from ..config import ConfigurationError +from ..config import ConfigurationError, get_last_messages_endpoint_url class PostgresGateway: @@ -26,11 +28,22 @@ def __init__(self, postgres_url: str) -> None: "host": parsed_params["HOST"], } - async def listen(self, channel: str) -> AsyncGenerator[ServerSentEvent, None]: + async def listen( + self, channel: str, last_id: str | None + ) -> AsyncGenerator[ServerSentEvent, None]: connection = await psycopg.AsyncConnection.connect( **self.db_params, autocommit=True, ) + if url:=get_last_messages_endpoint_url(): + if last_id: + response = httpx.get( + f"{url}/{last_id}/" + ) + with suppress(json.JSONDecodeError): + last_messages = response.json() + async for message in last_messages: + yield ServerSentEvent(**message) async with connection.cursor() as cursor: await cursor.execute(f"LISTEN {channel}") diff --git a/src/sse_relay_server/gateways/_redis.py b/src/sse_relay_server/gateways/_redis.py index 49788d0..0ddb1ab 100644 --- a/src/sse_relay_server/gateways/_redis.py +++ b/src/sse_relay_server/gateways/_redis.py @@ -1,18 +1,32 @@ import json +import httpx from typing import AsyncGenerator import redis import redis.asyncio as async_redis from loguru import logger +from contextlib import suppress from sse_starlette import ServerSentEvent +from ..config import get_last_messages_endpoint_url class RedisGateway: def __init__(self, redis_url: str) -> None: self.redis_url = redis_url - async def listen(self, channel: str) -> AsyncGenerator[ServerSentEvent, None]: + async def listen(self, channel: str, last_id: str|None) -> AsyncGenerator[ServerSentEvent, None]: r = async_redis.from_url(self.redis_url) + + if url:=get_last_messages_endpoint_url(): + if last_id: + response = httpx.get( + f"{url}/{last_id}/" + ) + with suppress(json.JSONDecodeError): + last_messages = response.json() + async for message in last_messages: + yield ServerSentEvent(**message) + async with r.pubsub() as pubsub: await pubsub.subscribe(channel) while True: diff --git a/src/sse_relay_server/main.py b/src/sse_relay_server/main.py index 47771b0..60320fe 100644 --- a/src/sse_relay_server/main.py +++ b/src/sse_relay_server/main.py @@ -29,8 +29,9 @@ async def generate_stop_event(): async def sse(request: Request): + last_id = request.query_params.get("LAST-EVENT-ID") if channel := request.query_params.get("channel"): - return EventSourceResponse(gateway.listen(channel)) + return EventSourceResponse(gateway.listen(channel, last_id)) else: return EventSourceResponse(generate_stop_event())