diff --git a/dabapush/Reader/NDJSONReader.py b/dabapush/Reader/NDJSONReader.py index 32d42ae..29ff064 100644 --- a/dabapush/Reader/NDJSONReader.py +++ b/dabapush/Reader/NDJSONReader.py @@ -1,7 +1,7 @@ """NDJSON Writer plug-in for dabapush""" # pylint: disable=R,I1101 -from typing import Iterator, List +from typing import Iterator import ujson @@ -14,10 +14,10 @@ def read_and_split( record: Record, flatten_records: bool = False, -) -> List[Record]: +) -> Iterator[Record]: """Reads a file and splits it into records by line.""" with record.payload.open("rt", encoding="utf8") as file: - children = [ + children = ( Record( uuid=f"{str(record.uuid)}:{str(line_number)}", payload=( @@ -28,10 +28,10 @@ def read_and_split( source=record, ) for line_number, line in enumerate(file) - ] - record.children.extend(children) - - return children + ) + for child in children: + record.children.append(child) + yield child class NDJSONReader(FileReader): diff --git a/dabapush/Record.py b/dabapush/Record.py index d668205..382118a 100644 --- a/dabapush/Record.py +++ b/dabapush/Record.py @@ -5,7 +5,7 @@ # pylint: disable=R0917, R0913 from datetime import datetime -from typing import Any, Callable, Dict, List, Literal, Optional, Self, Union +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Self, Union from uuid import uuid4 from loguru import logger as log @@ -92,9 +92,9 @@ def split( self, key: Optional[str] = None, id_key: Optional[str] = None, - func: Optional[Callable[[Self, ...], List[Self]]] = None, + func: Optional[Callable[[Self, ...], Iterable[Self]]] = None, **kwargs, - ) -> List[Self]: + ) -> Iterable[Self]: """Splits the record bases on either a keyword or a function. If a function is provided, it will be used to split the payload, even if you provide a key. If a key is provided, it will split the payload. @@ -134,22 +134,20 @@ def split( def _handle_key_split_(self, id_key, key): payload = self.payload # Get the payload, the original payload # will be set to None to free memory. - if key not in payload: - return [] - if not isinstance(payload[key], list): - return [] - split_payload = [ - Record( - **{ - "payload": value, - "uuid": value.get(id_key) if id_key else uuid4().hex, - "source": self, - } + if key in payload and isinstance(payload[key], list): + split_payload = ( + Record( + **{ + "payload": value, + "uuid": value.get(id_key) if id_key else uuid4().hex, + "source": self, + } + ) + for value in payload[key] ) - for value in payload[key] - ] - self.children.extend(split_payload) - return split_payload + for child in split_payload: + self.children.append(child) + yield child def to_log(self) -> Dict[str, Union[str, List[Dict[str, Any]]]]: """Return a loggable representation of the record."""