Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle Redis pub/sub subscribe errors #147

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions broadcaster/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Broadcast:
def __init__(self, url: str | None = None, *, backend: BroadcastBackend | None = None) -> None:
assert url or backend, "Either `url` or `backend` must be provided."
self._backend = backend or self._create_backend(cast(str, url))
self._subscribers: dict[str, set[asyncio.Queue[Event | None]]] = {}
self._subscribers: dict[str, set[asyncio.Queue[Event | BaseException | None]]] = {}

def _create_backend(self, url: str) -> BroadcastBackend:
parsed_url = urlparse(url)
Expand Down Expand Up @@ -69,10 +69,23 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
async def connect(self) -> None:
await self._backend.connect()
self._listener_task = asyncio.create_task(self._listener())
self._listener_task.add_done_callback(self.drop)

def drop(self, task: asyncio.Task[None]) -> None:
try:
exc = task.exception()
except asyncio.CancelledError:
pass
else:
for queues in self._subscribers.values():
for queue in queues:
queue.put_nowait(exc)

async def disconnect(self) -> None:
if self._listener_task.done():
self._listener_task.result()
exc = self._listener_task.exception()
if exc is None:
self._listener_task.result()
else:
self._listener_task.cancel()
await self._backend.disconnect()
Expand All @@ -88,7 +101,7 @@ async def publish(self, channel: str, message: Any) -> None:

@asynccontextmanager
async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]:
queue: asyncio.Queue[Event | None] = asyncio.Queue()
queue: asyncio.Queue[Event | BaseException | None] = asyncio.Queue()

try:
if not self._subscribers.get(channel):
Expand All @@ -107,7 +120,7 @@ async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]:


class Subscriber:
def __init__(self, queue: asyncio.Queue[Event | None]) -> None:
def __init__(self, queue: asyncio.Queue[Event | BaseException | None]) -> None:
self._queue = queue

async def __aiter__(self) -> AsyncGenerator[Event | None, None]:
Expand All @@ -119,6 +132,8 @@ async def __aiter__(self) -> AsyncGenerator[Event | None, None]:

async def get(self) -> Event:
item = await self._queue.get()
if isinstance(item, BaseException):
raise item
if item is None:
raise Unsubscribed()
return item
18 changes: 16 additions & 2 deletions broadcaster/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ def __init__(self, url: str):
self._conn = redis.Redis.from_url(url)
self._pubsub = self._conn.pubsub()
self._ready = asyncio.Event()
self._queue: asyncio.Queue[Event] = asyncio.Queue()
self._queue: asyncio.Queue[Event | BaseException | None] = asyncio.Queue()
self._listener: asyncio.Task[None] | None = None

async def connect(self) -> None:
self._listener = asyncio.create_task(self._pubsub_listener())
self._listener.add_done_callback(self.drop)
await self._pubsub.connect()

async def disconnect(self) -> None:
Expand All @@ -27,6 +28,14 @@ async def disconnect(self) -> None:
if self._listener is not None:
self._listener.cancel()

def drop(self, task: asyncio.Task[None]) -> None:
try:
exc = task.exception()
except asyncio.CancelledError:
pass
else:
self._queue.put_nowait(exc)

async def subscribe(self, channel: str) -> None:
self._ready.set()
await self._pubsub.subscribe(channel)
Expand All @@ -38,7 +47,12 @@ async def publish(self, channel: str, message: typing.Any) -> None:
await self._conn.publish(channel, message)

async def next_published(self) -> Event:
return await self._queue.get()
result = await self._queue.get()
if result is None:
raise RuntimeError
if isinstance(result, BaseException):
raise result
return result

async def _pubsub_listener(self) -> None:
# redis-py does not listen to the pubsub connection if there are no channels subscribed
Expand Down
40 changes: 40 additions & 0 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing

import pytest
import redis

from broadcaster import Broadcast, BroadcastBackend, Event
from broadcaster.backends.kafka import KafkaBackend
Expand Down Expand Up @@ -56,6 +57,45 @@ async def test_redis():
assert event.message == "hello"


@pytest.mark.asyncio
async def test_redis_server_disconnect():
with pytest.raises(redis.ConnectionError) as exc:
async with Broadcast("redis://localhost:6379") as broadcast:
async with broadcast.subscribe("chatroom") as subscriber:
await broadcast.publish("chatroom", "hello")
await broadcast._backend._conn.connection_pool.aclose() # type: ignore[attr-defined]
event = await subscriber.get()
assert event.channel == "chatroom"
assert event.message == "hello"
await subscriber.get()
assert False

assert exc.value.args == ("Connection closed by server.",)


@pytest.mark.asyncio
async def test_redis_does_not_log_loop_error_messages_if_subscribing(caplog):
async with Broadcast("redis://localhost:6379") as broadcast:
async with broadcast.subscribe("chatroom") as subscriber:
await broadcast.publish("chatroom", "hello")
event = await subscriber.get()
assert event.channel == "chatroom"
assert event.message == "hello"

assert caplog.messages == []


@pytest.mark.asyncio
async def test_redis_does_not_log_loop_error_messages_if_not_subscribing(caplog):
async with Broadcast("redis://localhost:6379") as broadcast:
await broadcast.publish("chatroom", "hello")

# Give the loop an opportunity to catch any errors before checking
# the logs.
await asyncio.sleep(0.1)
assert caplog.messages == []


@pytest.mark.asyncio
async def test_redis_stream():
async with Broadcast("redis-stream://localhost:6379") as broadcast:
Expand Down
Loading