Skip to content

Commit

Permalink
perf(test_utils): replace sleep with async queue join when stopping t…
Browse files Browse the repository at this point in the history
…est stream (#78)

TestStreamClient uses async queues in the background, which can be awaited for all processing to be done. This speeds up tests starting and stopping the TestStreamClient by removing the 1s sleep.
  • Loading branch information
JeroennC authored Nov 23, 2022
1 parent c15f31f commit 4cffa7a
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 4 deletions.
14 changes: 13 additions & 1 deletion kstreams/test_utils/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from kstreams.types import Headers

from .structs import RecordMetadata, TopicPartition
from .topics import TopicManager
from .topics import Topic, TopicManager


class Base:
Expand All @@ -17,6 +17,8 @@ async def start(self):


class TestProducer(Base, Producer):
__test__ = False

async def send(
self,
topic_name: str,
Expand Down Expand Up @@ -62,11 +64,14 @@ async def fut():


class TestConsumer(Base, Consumer):
__test__ = False

def __init__(self, *topics: str, group_id: Optional[str] = None, **kwargs) -> None:
# copy the aiokafka behavior
self.topics: Tuple[str, ...] = topics
self._group_id: Optional[str] = group_id
self._assignment: List[TopicPartition] = []
self._previous_topic: Optional[Topic] = None
self.partitions_committed: Dict[TopicPartition, int] = {}

for topic_name in topics:
Expand Down Expand Up @@ -145,6 +150,11 @@ def partitions_for_topic(self, topic: str) -> Set:
async def getone(
self,
) -> Optional[ConsumerRecord]: # The return type must be fixed later on
if self._previous_topic:
# Assumes previous record retrieved through getone was completed
self._previous_topic.task_done()
self._previous_topic = None

topic = None
for topic_partition in self._assignment:
topic = TopicManager.get(topic_partition.topic)
Expand All @@ -155,5 +165,7 @@ async def getone(
if topic is not None:
consumer_record = await topic.get()
self._check_partition_assignments(consumer_record)
self._previous_topic = topic
return consumer_record

return None
6 changes: 3 additions & 3 deletions kstreams/test_utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from types import TracebackType
from typing import Any, Dict, List, Optional, Type

Expand All @@ -14,6 +13,8 @@


class TestStreamClient:
__test__ = False

def __init__(self, stream_engine: StreamEngine) -> None:
self.stream_engine = stream_engine

Expand All @@ -40,8 +41,7 @@ async def start(self) -> None:
async def stop(self) -> None:
# If there are streams, we must wait until all the messages are consumed
if self.stream_engine._streams:
while not TopicManager.all_messages_consumed():
await asyncio.sleep(1)
await TopicManager.join()

await self.stream_engine.stop()

Expand Down
13 changes: 13 additions & 0 deletions kstreams/test_utils/topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ async def put(self, event: ConsumerRecord) -> None:
async def get(self) -> ConsumerRecord:
return await self.queue.get()

def task_done(self) -> None:
self.queue.task_done()

async def join(self) -> None:
await self.queue.join()

def is_empty(self) -> bool:
return self.queue.empty()

Expand Down Expand Up @@ -106,6 +112,13 @@ def all_messages_consumed(cls) -> bool:
return False
return True

@classmethod
async def join(cls) -> None:
"""
Wait for all topic messages to be processed
"""
await asyncio.gather(*[topic.join() for topic in cls.topics.values()])

@classmethod
def clean(cls) -> None:
cls.topics = {}
3 changes: 3 additions & 0 deletions tests/test_stream_engine.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from unittest import mock

import pytest
Expand Down Expand Up @@ -168,6 +169,8 @@ async def stream(_):
Consumer.start.assert_awaited()
stream_engine._producer.start.assert_awaited()

await asyncio.sleep(0) # Allow stream coroutine to run once

await stream_engine.stop()
stream_engine._producer.stop.assert_awaited()
Consumer.stop.assert_awaited()
Expand Down

0 comments on commit 4cffa7a

Please sign in to comment.