From 1e05044fcac4a19851125899c12c610a65f71c73 Mon Sep 17 00:00:00 2001
From: Philipp Kessling
Date: Sat, 30 Nov 2024 00:14:58 +0100
Subject: [PATCH] Implement test for backlog writing and make it pass.
---
dabapush/Reader/NDJSONReader.py | 4 +++-
dabapush/Record.py | 30 ++++++++++++++++++++----
dabapush/Writer/Writer.py | 7 +++++-
dabapush/Writer/stdout_writer.py | 4 ++--
tests/Reader/test_reader.py | 37 ++++++++++++++++++++++++++++++
tests/Writer/test_stdout_writer.py | 3 ++-
6 files changed, 76 insertions(+), 9 deletions(-)
create mode 100644 tests/Reader/test_reader.py
diff --git a/dabapush/Reader/NDJSONReader.py b/dabapush/Reader/NDJSONReader.py
index 7f75c82..2fe8e62 100644
--- a/dabapush/Reader/NDJSONReader.py
+++ b/dabapush/Reader/NDJSONReader.py
@@ -1,5 +1,7 @@
"""NDJSON Writer plug-in for dabapush"""
+import weakref
+
# pylint: disable=R,I1101
from typing import Iterator, List
@@ -25,7 +27,7 @@ def read_and_split(
if not flatten_records
else flatten(ujson.loads(line))
),
- source=record,
+ source=weakref.ref(record),
)
for line_number, line in enumerate(file)
]
diff --git a/dabapush/Record.py b/dabapush/Record.py
index 2e82c96..06471da 100644
--- a/dabapush/Record.py
+++ b/dabapush/Record.py
@@ -12,6 +12,7 @@
EventHandler = Callable[[Self], None]
EventType = Literal["on_done", "on_error", "on_start"]
+RecordState = Literal["done", "error", "start", "rejected"]
@dataclasses.dataclass
@@ -62,6 +63,7 @@ class Record:
event_handlers: Dict[EventType, List[EventHandler]] = dataclasses.field(
default_factory=dict
)
+ state: RecordState = "start"
def split(
self,
@@ -126,6 +128,7 @@ def _handle_key_split_(self, id_key, key):
def to_log(self) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
"""Return a loggable representation of the record."""
+ log.debug(f"Logging record {self.uuid}.")
if self.source:
source = self.source()
if not source:
@@ -160,11 +163,29 @@ def walk_tree(self, only_leafs=True) -> List[Self]:
def done(self):
"""Call the on_done event handler."""
+ # Signal parent that this record is done
+ self.state = "done"
+ log.debug(f"Record {self.uuid} is set as done.")
+ if self.source:
+ parent: Record = self.source()
+ if not parent:
+ log.critical(f"Source of record {self.uuid} is not available")
+ raise ValueError(f"Source of record {self.uuid} is not available")
+ parent.signal_done()
+ log.debug(f"Signaled parent {parent.uuid} of record {self.uuid}.")
self.__dispatch_event__("on_done")
- # Clean up the record
- self.children = []
- self.source = None
- self.payload = None
+
+    def signal_done(self):
+        """Re-evaluate this record after one of its children finished.
+
+        Collects the state of every record in `self.children`; once all of
+        them report the state "done", this record is marked done as well by
+        calling `done()`.
+        """
+        # Pair each child's uuid with its status. Logging the uuid instead of
+        # the whole child record keeps the debug message small: formatting
+        # every child in full on each signal repeats that work every time a
+        # single child finishes.
+        _children_status_ = [
+            (child.uuid, child.state == "done") for child in self.children
+        ]
+        log.debug(
+            f"Signaled that children of {self.uuid} is done."
+            f" Children status: {_children_status_}"
+        )
+        # If all children are done, so is the parent.
+        if all(is_done for _, is_done in _children_status_):
+            self.done()
+            log.debug(f"Record {self.uuid} is done.")
def destroy(self):
"""Destroy the record and all its children."""
@@ -186,6 +207,7 @@ def __eq__(self, other: Self) -> bool:
def __dispatch_event__(self, event: EventType):
"""Dispatch an event to the event handlers."""
+ log.debug(f"Dispatching event '{event}' for '{self.uuid}'.")
for handler in self.event_handlers.get(event, []):
handler(self)
diff --git a/dabapush/Writer/Writer.py b/dabapush/Writer/Writer.py
index 91b3957..29a83f0 100644
--- a/dabapush/Writer/Writer.py
+++ b/dabapush/Writer/Writer.py
@@ -8,6 +8,8 @@
import abc
from typing import Iterator, List
+from loguru import logger as log
+
from ..Configuration.WriterConfiguration import WriterConfiguration
from ..Record import Record
@@ -35,13 +37,16 @@ def write(self, queue: Iterator[Record]) -> None:
Args:
queue (Iterator[Record]): Items to be consumed.
-
"""
for item in queue:
self.buffer.append(item)
if len(self.buffer) >= self.config.chunk_size:
self.persist()
+ log.debug(
+ f"Persisted {self.config.chunk_size} records. Setting to done."
+ )
for record in self.buffer:
+ log.debug(f"Setting record {record.uuid} as done.")
record.done()
self.buffer = []
diff --git a/dabapush/Writer/stdout_writer.py b/dabapush/Writer/stdout_writer.py
index c3dbe6c..18fb0a2 100644
--- a/dabapush/Writer/stdout_writer.py
+++ b/dabapush/Writer/stdout_writer.py
@@ -1,4 +1,5 @@
"""This module contains the STDOUTWriter and STDOUTWriterConfiguration classes."""
+
from ..Configuration.WriterConfiguration import WriterConfiguration
from .Writer import Writer
@@ -11,7 +12,6 @@ def __init__(self, config: "STDOUTWriterConfiguration"):
def persist(self):
last_rows = self.buffer
- self.buffer = []
for row in last_rows:
print(row)
@@ -21,6 +21,6 @@ class STDOUTWriterConfiguration(WriterConfiguration):
yaml_tag = "!dabapush:STDOUTWriterConfiguration"
- def get_writer(self):
+ def get_instance(self): # pylint: disable=W0221
"""Returns a STDOUTWriter instance."""
return STDOUTWriter(self)
diff --git a/tests/Reader/test_reader.py b/tests/Reader/test_reader.py
new file mode 100644
index 0000000..3897079
--- /dev/null
+++ b/tests/Reader/test_reader.py
@@ -0,0 +1,37 @@
+"""Test suite for the Reader class and its backlog."""
+
+from pathlib import Path
+
+import ujson
+
+from dabapush import NDJSONReaderConfiguration, STDOUTWriterConfiguration
+
+
+def test_backlogging(monkeypatch, tmp_path):
+    """
+    Reading three one-record files and writing them out should leave a
+    backlog at `.dabapush/test.jsonl` holding one line per record.
+    """
+    # The backlog is created automatically under a path relative to the
+    # working directory, so run inside the temp path to keep it isolated.
+    monkeypatch.chdir(tmp_path)
+    source_dir = tmp_path / "data"
+    reader = NDJSONReaderConfiguration(
+        "test", id="testing", read_path=str(source_dir), pattern="*.json"
+    ).get_instance()
+    writer = STDOUTWriterConfiguration(
+        "test", id="testing", chunk_size=1
+    ).get_instance()
+
+    payloads = [{"key": f"value_{index}"} for index in range(3)]
+    source_dir.mkdir()
+    for index, payload in enumerate(payloads):
+        target = source_dir / f"test{index}.json"
+        with target.open("wt", encoding="utf8") as handle:
+            ujson.dump(payload, handle)  # pylint: disable=I1101
+
+    writer.write(reader.read())
+
+    backlog = Path(".dabapush/test.jsonl")
+    assert backlog.exists()
+    with backlog.open("rt", encoding="utf8") as handle:
+        lines = list(handle)
+    assert len(lines) == len(payloads)
diff --git a/tests/Writer/test_stdout_writer.py b/tests/Writer/test_stdout_writer.py
index c49c0e7..7033c36 100644
--- a/tests/Writer/test_stdout_writer.py
+++ b/tests/Writer/test_stdout_writer.py
@@ -1,10 +1,11 @@
"""Test the STDOUTWriter class."""
+
from dabapush import STDOUTWriterConfiguration
def test_stdout_writer(capsys):
"""Should write to stdout."""
- writer = STDOUTWriterConfiguration("stdout1").get_writer()
+ writer = STDOUTWriterConfiguration("stdout1").get_instance()
writer.buffer = ["test"]
writer.persist()