Fix/22 #64 (Open)
wants to merge 4 commits into master
4 changes: 3 additions & 1 deletion .gitignore
@@ -3,5 +3,7 @@ test.db
 .coverage
 .pytest_cache/
 .mypy_cache/
-starlette.egg-info/
+broadcaster.egg-info/
 venv/
+build/
+dist/
8 changes: 3 additions & 5 deletions broadcaster/_backends/kafka.py
@@ -1,4 +1,3 @@
-import asyncio
 import typing
 from urllib.parse import urlparse
 
@@ -14,9 +13,8 @@ def __init__(self, url: str):
         self._consumer_channels: typing.Set = set()
 
     async def connect(self) -> None:
-        loop = asyncio.get_event_loop()
-        self._producer = AIOKafkaProducer(loop=loop, bootstrap_servers=self._servers)
-        self._consumer = AIOKafkaConsumer(loop=loop, bootstrap_servers=self._servers)
+        self._producer = AIOKafkaProducer(bootstrap_servers=self._servers)
+        self._consumer = AIOKafkaConsumer(bootstrap_servers=self._servers)
         await self._producer.start()
         await self._consumer.start()
 
@@ -29,7 +27,7 @@ async def subscribe(self, channel: str) -> None:
         self._consumer.subscribe(topics=self._consumer_channels)
 
     async def unsubscribe(self, channel: str) -> None:
-        await self._consumer.unsubscribe()
+        self._consumer.unsubscribe()
 
     async def publish(self, channel: str, message: typing.Any) -> None:
         await self._producer.send_and_wait(channel, message.encode("utf8"))
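Note on the kafka.py change: newer aiokafka releases no longer accept a loop= argument (the clients bind to the running event loop when started), and AIOKafkaConsumer.unsubscribe() is a plain synchronous method, so awaiting it fails. A minimal standalone sketch of the updated usage, with a placeholder broker address and topic name:

import asyncio

from aiokafka import AIOKafkaConsumer, AIOKafkaProducer


async def main() -> None:
    # No loop= argument: both clients attach to the running loop in start().
    producer = AIOKafkaProducer(bootstrap_servers="localhost:9092")
    consumer = AIOKafkaConsumer(bootstrap_servers="localhost:9092")
    await producer.start()
    await consumer.start()
    try:
        consumer.subscribe(topics=["chatroom"])
        await producer.send_and_wait("chatroom", b"hello")
        consumer.unsubscribe()  # synchronous, not awaited
    finally:
        await consumer.stop()
        await producer.stop()


asyncio.run(main())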
10 changes: 7 additions & 3 deletions broadcaster/_backends/postgres.py
@@ -13,19 +13,23 @@ def __init__(self, url: str):
 
     async def connect(self) -> None:
         self._conn = await asyncpg.connect(self._url)
+        self._lock = asyncio.Lock()
         self._listen_queue: asyncio.Queue = asyncio.Queue()
 
     async def disconnect(self) -> None:
         await self._conn.close()
 
     async def subscribe(self, channel: str) -> None:
-        await self._conn.add_listener(channel, self._listener)
+        async with self._lock:
+            await self._conn.add_listener(channel, self._listener)
 
     async def unsubscribe(self, channel: str) -> None:
-        await self._conn.remove_listener(channel, self._listener)
+        async with self._lock:
+            await self._conn.remove_listener(channel, self._listener)
 
     async def publish(self, channel: str, message: str) -> None:
-        await self._conn.execute("SELECT pg_notify($1, $2);", channel, message)
+        async with self._lock:
+            await self._conn.execute("SELECT pg_notify($1, $2);", channel, message)
 
     def _listener(self, *args: Any) -> None:
         connection, pid, channel, payload = args
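Note on the postgres.py change: a single asyncpg connection cannot run overlapping operations, so concurrent publish or subscribe calls on the shared connection can fail with an InterfaceError along the lines of "another operation is in progress"; the asyncio.Lock serializes access. A small self-contained sketch of the same pattern, with a placeholder DSN:

import asyncio

import asyncpg

DSN = "postgres://postgres:postgres@localhost:5432/broadcaster"  # placeholder


async def main() -> None:
    conn = await asyncpg.connect(DSN)
    lock = asyncio.Lock()

    async def notify(channel: str, message: str) -> None:
        # Serialize use of the shared connection across concurrent tasks.
        async with lock:
            await conn.execute("SELECT pg_notify($1, $2);", channel, message)

    await asyncio.gather(*(notify("chatroom", f"message {i}") for i in range(5)))
    await conn.close()


asyncio.run(main())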
4 changes: 2 additions & 2 deletions requirements.txt
@@ -11,14 +11,14 @@ wheel
 
 # Tests & Linting
 autoflake
-black==20.8b1
+black==22.3.0
 coverage==5.3
 flake8
 flake8-bugbear
 flake8-pie==0.5.*
 isort==5.*
 mypy
-pytest==5.*
+pytest==7.*
 pytest-asyncio
 pytest-trio
 trio
2 changes: 1 addition & 1 deletion scripts/check
@@ -11,4 +11,4 @@ set -x
 ${PREFIX}black --check --diff --target-version=py37 $SOURCE_FILES
 ${PREFIX}flake8 $SOURCE_FILES
 ${PREFIX}mypy $SOURCE_FILES
-${PREFIX}isort --check --diff --project=httpx $SOURCE_FILES
+${PREFIX}isort --check --diff --project=broadcaster $SOURCE_FILES
4 changes: 2 additions & 2 deletions scripts/lint
@@ -4,10 +4,10 @@ export PREFIX=""
 if [ -d 'venv' ] ; then
     export PREFIX="venv/bin/"
 fi
-export SOURCE_FILES="httpx tests"
+export SOURCE_FILES="broadcaster tests"
 
 set -x
 
 ${PREFIX}autoflake --in-place --recursive $SOURCE_FILES
-${PREFIX}isort --project=httpx $SOURCE_FILES
+${PREFIX}isort --project=broadcaster $SOURCE_FILES
 ${PREFIX}black --target-version=py37 $SOURCE_FILES
3 changes: 2 additions & 1 deletion setup.cfg
@@ -16,9 +16,10 @@ combine_as_imports = True
 
 [tool:pytest]
 addopts = -rxXs
+asyncio_mode = strict
 markers =
     copied_from(source, changes=None): mark test as copied from somewhere else, along with a description of changes made to accommodate e.g. our test setup
 
 [coverage:run]
 omit = venv/*
-include = httpx/*, tests/*
+include = broadcaster/*, tests/*
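Note on asyncio_mode = strict: in strict mode pytest-asyncio only handles tests and fixtures that opt in explicitly, which is why the tests below keep @pytest.mark.asyncio and the fixture is declared with pytest_asyncio.fixture. A minimal sketch of what strict mode expects; fixture and test names here are illustrative:

import pytest
import pytest_asyncio


@pytest_asyncio.fixture
async def greeting():
    # Async fixtures must be declared via pytest_asyncio.fixture in strict mode.
    return "hello"


@pytest.mark.asyncio
async def test_greeting(greeting):
    # Async tests must carry the asyncio marker in strict mode.
    assert greeting == "hello"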
59 changes: 59 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,59 @@
+"""Check for #22"""
+
+import asyncio
+import functools
+
+import pytest_asyncio
+
+from broadcaster import Broadcast
+from broadcaster._backends.kafka import KafkaBackend
+
+
+async def __has_topic_now(client, topic):
+    if await client.force_metadata_update():
+        if topic in client.cluster.topics():
+            print(f'Topic "{topic}" exists')
+            return True
+    return False
+
+
+async def __wait_has_topic(client, topic, *, timeout_sec=5):
+    poll_time_sec = 1 / 10000
+    from datetime import datetime
+
+    pre = datetime.now()
+    while True:
+        if (datetime.now() - pre).total_seconds() >= timeout_sec:
+            raise ValueError(f'No topic "{topic}" exists')
+        if await __has_topic_now(client, topic):
+            return
+        await asyncio.sleep(poll_time_sec)
+
+
+def kafka_backend_setup(kafka_backend):
+    """Block until consumer client contains the topic"""
+    subscribe_impl = kafka_backend.subscribe
+
+    @functools.wraps(subscribe_impl)
+    async def subscribe(channel: str) -> None:
+        await subscribe_impl(channel)
+        await __wait_has_topic(kafka_backend._consumer._client, channel)
+
+    kafka_backend.subscribe = subscribe
+
+
+BROADCASTS_SETUP = {
+    KafkaBackend: kafka_backend_setup,
+}
+
+
+@pytest_asyncio.fixture(scope="function")
+async def setup_broadcast(request):
+    url = request.param
+    async with Broadcast(url) as broadcast:
+        backend = broadcast._backend
+        for klass, setup in BROADCASTS_SETUP.items():
+            if isinstance(backend, klass):
+                setup(backend)
+                break
+        yield broadcast
58 changes: 18 additions & 40 deletions tests/test_broadcast.py
@@ -1,45 +1,23 @@
-import pytest
-
-from broadcaster import Broadcast
-
-
-@pytest.mark.asyncio
-async def test_memory():
-    async with Broadcast("memory://") as broadcast:
-        async with broadcast.subscribe("chatroom") as subscriber:
-            await broadcast.publish("chatroom", "hello")
-            event = await subscriber.get()
-            assert event.channel == "chatroom"
-            assert event.message == "hello"
-
-
-@pytest.mark.asyncio
-async def test_redis():
-    async with Broadcast("redis://localhost:6379") as broadcast:
-        async with broadcast.subscribe("chatroom") as subscriber:
-            await broadcast.publish("chatroom", "hello")
-            event = await subscriber.get()
-            assert event.channel == "chatroom"
-            assert event.message == "hello"
+from uuid import uuid4
 
+import pytest
 
-@pytest.mark.asyncio
-async def test_postgres():
-    async with Broadcast(
-        "postgres://postgres:postgres@localhost:5432/broadcaster"
-    ) as broadcast:
-        async with broadcast.subscribe("chatroom") as subscriber:
-            await broadcast.publish("chatroom", "hello")
-            event = await subscriber.get()
-            assert event.channel == "chatroom"
-            assert event.message == "hello"
+URLS = [
+    ("memory://",),
+    ("redis://localhost:6379",),
+    ("postgres://postgres:postgres@localhost:5432/broadcaster",),
+    ("kafka://localhost:9092",),
+]
 
 
 @pytest.mark.asyncio
-async def test_kafka():
-    async with Broadcast("kafka://localhost:9092") as broadcast:
-        async with broadcast.subscribe("chatroom") as subscriber:
-            await broadcast.publish("chatroom", "hello")
-            event = await subscriber.get()
-            assert event.channel == "chatroom"
-            assert event.message == "hello"
+@pytest.mark.parametrize(["setup_broadcast"], URLS, indirect=True)
+async def test_broadcast(setup_broadcast):
+    uid = uuid4()
+    channel = f"chatroom-{uid}"
+    msg = f"hello {uid}"
+    async with setup_broadcast.subscribe(channel) as subscriber:
+        await setup_broadcast.publish(channel, msg)
+        event = await subscriber.get()
+        assert event.channel == channel
+        assert event.message == msg
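Note on indirect=True: the parametrize values are not handed to the test function directly; each URL tuple is routed to the setup_broadcast fixture and arrives there as request.param (see tests/conftest.py above). A stripped-down illustration of the mechanism, with made-up fixture and test names:

import pytest


@pytest.fixture
def connection(request):
    # With indirect=True the parametrized value shows up here as request.param.
    return f"connected to {request.param}"


@pytest.mark.parametrize(["connection"], [("memory://",), ("redis://",)], indirect=True)
def test_connection(connection):
    assert connection.startswith("connected to ")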
62 changes: 62 additions & 0 deletions tests/test_concurrent.py
@@ -0,0 +1,62 @@
+"""Check for #22"""
+import asyncio
+from uuid import uuid4
+
+import pytest
+
+MESSAGES = ["hello", "goodbye"]
+
+URLS = [
+    ("memory://",),
+    ("redis://localhost:6379",),
+    ("postgres://postgres:postgres@localhost:5432/broadcaster",),
+    ("kafka://localhost:9092",),
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(["setup_broadcast"], URLS, indirect=True)
+async def test_broadcast(setup_broadcast):
+    uid = uuid4()
+    channel = f"chatroom-{uid}"
+    msgs = [f"{msg} {uid}" for msg in MESSAGES]
+    async with setup_broadcast.subscribe(channel) as subscriber:
+        to_publish = [setup_broadcast.publish(channel, msg) for msg in msgs]
+
+        await asyncio.gather(*to_publish)
+        for msg in msgs:
+            event = await subscriber.get()
+            assert event.channel == channel
+            assert event.message == msg
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(["setup_broadcast"], URLS, indirect=True)
+async def test_sub(setup_broadcast):
+    uid = uuid4()
+    channel1 = f"chatroom-{uid}1"
+    channel2 = f"chatroom-{uid}2"
+
+    to_sub = [
+        setup_broadcast._backend.subscribe(channel1),
+        setup_broadcast._backend.subscribe(channel2),
+    ]
+    await asyncio.gather(*to_sub)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(["setup_broadcast"], URLS, indirect=True)
+async def test_unsub(setup_broadcast):
+    uid = uuid4()
+    channel1 = f"chatroom-{uid}1"
+    channel2 = f"chatroom-{uid}2"
+
+    await setup_broadcast._backend.subscribe(channel1)
+    await setup_broadcast._backend.subscribe(channel2)
+
+    to_unsub = [
+        setup_broadcast._backend.unsubscribe(channel1),
+        setup_broadcast._backend.unsubscribe(channel2),
+    ]
+
+    await asyncio.gather(*to_unsub)