From 1e05044fcac4a19851125899c12c610a65f71c73 Mon Sep 17 00:00:00 2001 From: Philipp Kessling Date: Sat, 30 Nov 2024 00:14:58 +0100 Subject: [PATCH] Implement test for backlog writing and make it pass. --- dabapush/Reader/NDJSONReader.py | 4 +++- dabapush/Record.py | 30 ++++++++++++++++++++---- dabapush/Writer/Writer.py | 7 +++++- dabapush/Writer/stdout_writer.py | 4 ++-- tests/Reader/test_reader.py | 37 ++++++++++++++++++++++++++++++ tests/Writer/test_stdout_writer.py | 3 ++- 6 files changed, 76 insertions(+), 9 deletions(-) create mode 100644 tests/Reader/test_reader.py diff --git a/dabapush/Reader/NDJSONReader.py b/dabapush/Reader/NDJSONReader.py index 7f75c82..2fe8e62 100644 --- a/dabapush/Reader/NDJSONReader.py +++ b/dabapush/Reader/NDJSONReader.py @@ -1,5 +1,7 @@ """NDJSON Writer plug-in for dabapush""" +import weakref + # pylint: disable=R,I1101 from typing import Iterator, List @@ -25,7 +27,7 @@ def read_and_split( if not flatten_records else flatten(ujson.loads(line)) ), - source=record, + source=weakref.ref(record), ) for line_number, line in enumerate(file) ] diff --git a/dabapush/Record.py b/dabapush/Record.py index 2e82c96..06471da 100644 --- a/dabapush/Record.py +++ b/dabapush/Record.py @@ -12,6 +12,7 @@ EventHandler = Callable[[Self], None] EventType = Literal["on_done", "on_error", "on_start"] +RecordState = Literal["done", "error", "start", "rejected"] @dataclasses.dataclass @@ -62,6 +63,7 @@ class Record: event_handlers: Dict[EventType, List[EventHandler]] = dataclasses.field( default_factory=dict ) + state: RecordState = "start" def split( self, @@ -126,6 +128,7 @@ def _handle_key_split_(self, id_key, key): def to_log(self) -> Dict[str, Union[str, List[Dict[str, Any]]]]: """Return a loggable representation of the record.""" + log.debug(f"Logging record {self.uuid}.") if self.source: source = self.source() if not source: @@ -160,11 +163,29 @@ def walk_tree(self, only_leafs=True) -> List[Self]: def done(self): """Call the on_done event handler.""" + # Signal parent that this record is done + self.state = "done" + log.debug(f"Record {self.uuid} is set as done.") + if self.source: + parent: Record = self.source() + if not parent: + log.critical(f"Source of record {self.uuid} is not available") + raise ValueError(f"Source of record {self.uuid} is not available") + parent.signal_done() + log.debug(f"Signaled parent {parent.uuid} of record {self.uuid}.") self.__dispatch_event__("on_done") - # Clean up the record - self.children = [] - self.source = None - self.payload = None + + def signal_done(self): + """Signal that a child record is done.""" + # If all children are done, so is the parent. + _children_status_ = [child.state == "done" for child in self.children] + log.debug( + f"Signaled that children of {self.uuid} is done." + f" Children status: {list(zip(self.children, _children_status_))}" + ) + if all(_children_status_): + self.done() + log.debug(f"Record {self.uuid} is done.") def destroy(self): """Destroy the record and all its children.""" @@ -186,6 +207,7 @@ def __eq__(self, other: Self) -> bool: def __dispatch_event__(self, event: EventType): """Dispatch an event to the event handlers.""" + log.debug(f"Dispatching event '{event}' for '{self.uuid}'.") for handler in self.event_handlers.get(event, []): handler(self) diff --git a/dabapush/Writer/Writer.py b/dabapush/Writer/Writer.py index 91b3957..29a83f0 100644 --- a/dabapush/Writer/Writer.py +++ b/dabapush/Writer/Writer.py @@ -8,6 +8,8 @@ import abc from typing import Iterator, List +from loguru import logger as log + from ..Configuration.WriterConfiguration import WriterConfiguration from ..Record import Record @@ -35,13 +37,16 @@ def write(self, queue: Iterator[Record]) -> None: Args: queue (Iterator[Record]): Items to be consumed. - """ for item in queue: self.buffer.append(item) if len(self.buffer) >= self.config.chunk_size: self.persist() + log.debug( + f"Persisted {self.config.chunk_size} records. Setting to done." + ) for record in self.buffer: + log.debug(f"Setting record {record.uuid} as done.") record.done() self.buffer = [] diff --git a/dabapush/Writer/stdout_writer.py b/dabapush/Writer/stdout_writer.py index c3dbe6c..18fb0a2 100644 --- a/dabapush/Writer/stdout_writer.py +++ b/dabapush/Writer/stdout_writer.py @@ -1,4 +1,5 @@ """This module contains the STDOUTWriter and STDOUTWriterConfiguration classes.""" + from ..Configuration.WriterConfiguration import WriterConfiguration from .Writer import Writer @@ -11,7 +12,6 @@ def __init__(self, config: "STDOUTWriterConfiguration"): def persist(self): last_rows = self.buffer - self.buffer = [] for row in last_rows: print(row) @@ -21,6 +21,6 @@ class STDOUTWriterConfiguration(WriterConfiguration): yaml_tag = "!dabapush:STDOUTWriterConfiguration" - def get_writer(self): + def get_instance(self): # pylint: disable=W0221 """Returns a STDOUTWriter instance.""" return STDOUTWriter(self) diff --git a/tests/Reader/test_reader.py b/tests/Reader/test_reader.py new file mode 100644 index 0000000..3897079 --- /dev/null +++ b/tests/Reader/test_reader.py @@ -0,0 +1,37 @@ +"""Test suite for the Reader class and its backlog.""" + +from pathlib import Path + +import ujson + +from dabapush import NDJSONReaderConfiguration, STDOUTWriterConfiguration + + +def test_backlogging(monkeypatch, tmp_path): + """ + Should write the records to a file if the log_path is set. + """ + monkeypatch.chdir(tmp_path) # Change the working directory to the temp path + # to isolate the automatically created backlog. + reader = NDJSONReaderConfiguration( + "test", id="testing", read_path=str(tmp_path / "data"), pattern="*.json" + ).get_instance() + writer = STDOUTWriterConfiguration( + "test", id="testing", chunk_size=1 + ).get_instance() + + records = [{"key": f"value_{n}"} for n in range(3)] + data_dir = tmp_path / "data" + data_dir.mkdir() + for file_num, record in enumerate(records): + with (data_dir / f"test{file_num}.json").open("wt", encoding="utf8") as f: + ujson.dump(record, f) # pylint: disable=I1101 + + writer.write(reader.read()) + + log_path = Path(".dabapush/test.jsonl") + assert log_path.exists() + with log_path.open("rt", encoding="utf8") as f: + content = f.readlines() + # assert content != [] + assert len(content) == len(records) diff --git a/tests/Writer/test_stdout_writer.py b/tests/Writer/test_stdout_writer.py index c49c0e7..7033c36 100644 --- a/tests/Writer/test_stdout_writer.py +++ b/tests/Writer/test_stdout_writer.py @@ -1,10 +1,11 @@ """Test the STDOUTWriter class.""" + from dabapush import STDOUTWriterConfiguration def test_stdout_writer(capsys): """Should write to stdout.""" - writer = STDOUTWriterConfiguration("stdout1").get_writer() + writer = STDOUTWriterConfiguration("stdout1").get_instance() writer.buffer = ["test"] writer.persist()