Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement test for backlog writing and make it pass. #61

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion dabapush/Reader/NDJSONReader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""NDJSON Writer plug-in for dabapush"""

import weakref

# pylint: disable=R,I1101
from typing import Iterator, List

Expand All @@ -25,7 +27,7 @@ def read_and_split(
if not flatten_records
else flatten(ujson.loads(line))
),
source=record,
source=weakref.ref(record),
)
for line_number, line in enumerate(file)
]
Expand Down
30 changes: 26 additions & 4 deletions dabapush/Record.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

EventHandler = Callable[[Self], None]
EventType = Literal["on_done", "on_error", "on_start"]
RecordState = Literal["done", "error", "start", "rejected"]


@dataclasses.dataclass
Expand Down Expand Up @@ -62,6 +63,7 @@ class Record:
event_handlers: Dict[EventType, List[EventHandler]] = dataclasses.field(
default_factory=dict
)
state: RecordState = "start"

def split(
self,
Expand Down Expand Up @@ -126,6 +128,7 @@ def _handle_key_split_(self, id_key, key):

def to_log(self) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
"""Return a loggable representation of the record."""
log.debug(f"Logging record {self.uuid}.")
if self.source:
source = self.source()
if not source:
Expand Down Expand Up @@ -160,11 +163,29 @@ def walk_tree(self, only_leafs=True) -> List[Self]:

def done(self):
"""Call the on_done event handler."""
# Signal parent that this record is done
self.state = "done"
log.debug(f"Record {self.uuid} is set as done.")
if self.source:
parent: Record = self.source()
if not parent:
log.critical(f"Source of record {self.uuid} is not available")
raise ValueError(f"Source of record {self.uuid} is not available")
parent.signal_done()
log.debug(f"Signaled parent {parent.uuid} of record {self.uuid}.")
self.__dispatch_event__("on_done")
# Clean up the record
self.children = []
self.source = None
self.payload = None

def signal_done(self):
"""Signal that a child record is done."""
# If all children are done, so is the parent.
_children_status_ = [child.state == "done" for child in self.children]
log.debug(
f"Signaled that children of {self.uuid} is done."
f" Children status: {list(zip(self.children, _children_status_))}"
)
if all(_children_status_):
self.done()
log.debug(f"Record {self.uuid} is done.")

def destroy(self):
"""Destroy the record and all its children."""
Expand All @@ -186,6 +207,7 @@ def __eq__(self, other: Self) -> bool:

def __dispatch_event__(self, event: EventType):
"""Dispatch an event to the event handlers."""
log.debug(f"Dispatching event '{event}' for '{self.uuid}'.")
for handler in self.event_handlers.get(event, []):
handler(self)

Expand Down
7 changes: 6 additions & 1 deletion dabapush/Writer/Writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import abc
from typing import Iterator, List

from loguru import logger as log

from ..Configuration.WriterConfiguration import WriterConfiguration
from ..Record import Record

Expand Down Expand Up @@ -35,13 +37,16 @@ def write(self, queue: Iterator[Record]) -> None:

Args:
queue (Iterator[Record]): Items to be consumed.

"""
for item in queue:
self.buffer.append(item)
if len(self.buffer) >= self.config.chunk_size:
self.persist()
log.debug(
f"Persisted {self.config.chunk_size} records. Setting to done."
)
for record in self.buffer:
log.debug(f"Setting record {record.uuid} as done.")
record.done()
self.buffer = []

Expand Down
4 changes: 2 additions & 2 deletions dabapush/Writer/stdout_writer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This module contains the STDOUTWriter and STDOUTWriterConfiguration classes."""

from ..Configuration.WriterConfiguration import WriterConfiguration
from .Writer import Writer

Expand All @@ -11,7 +12,6 @@ def __init__(self, config: "STDOUTWriterConfiguration"):

def persist(self):
last_rows = self.buffer
self.buffer = []
for row in last_rows:
print(row)

Expand All @@ -21,6 +21,6 @@ class STDOUTWriterConfiguration(WriterConfiguration):

yaml_tag = "!dabapush:STDOUTWriterConfiguration"

def get_writer(self):
def get_instance(self): # pylint: disable=W0221
"""Returns a STDOUTWriter instance."""
return STDOUTWriter(self)
37 changes: 37 additions & 0 deletions tests/Reader/test_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Test suite for the Reader class and its backlog."""

from pathlib import Path

import ujson

from dabapush import NDJSONReaderConfiguration, STDOUTWriterConfiguration


def test_backlogging(monkeypatch, tmp_path):
"""
Should write the records to a file if the log_path is set.
"""
monkeypatch.chdir(tmp_path) # Change the working directory to the temp path
# to isolate the automatically created backlog.
reader = NDJSONReaderConfiguration(
"test", id="testing", read_path=str(tmp_path / "data"), pattern="*.json"
).get_instance()
writer = STDOUTWriterConfiguration(
"test", id="testing", chunk_size=1
).get_instance()

records = [{"key": f"value_{n}"} for n in range(3)]
data_dir = tmp_path / "data"
data_dir.mkdir()
for file_num, record in enumerate(records):
with (data_dir / f"test{file_num}.json").open("wt", encoding="utf8") as f:
ujson.dump(record, f) # pylint: disable=I1101

writer.write(reader.read())

log_path = Path(".dabapush/test.jsonl")
assert log_path.exists()
with log_path.open("rt", encoding="utf8") as f:
content = f.readlines()
# assert content != []
assert len(content) == len(records)
3 changes: 2 additions & 1 deletion tests/Writer/test_stdout_writer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Test the STDOUTWriter class."""

from dabapush import STDOUTWriterConfiguration


def test_stdout_writer(capsys):
"""Should write to stdout."""
writer = STDOUTWriterConfiguration("stdout1").get_writer()
writer = STDOUTWriterConfiguration("stdout1").get_instance()
writer.buffer = ["test"]
writer.persist()

Expand Down