From 9585c748fa4d64b95b1b829a66c48a55cac11f83 Mon Sep 17 00:00:00 2001 From: Evgenii Gorchakov Date: Mon, 3 Feb 2025 14:25:23 +0100 Subject: [PATCH] feat: cacheable YaakMetadataDataFrameBuilder (#34) --- .pre-commit-config.yaml | 4 +- config/_templates/dataset/yaak.yaml | 68 ++++++++++++------------- pyproject.toml | 7 +-- src/rbyte/io/__init__.py | 4 +- src/rbyte/io/_json/dataframe_builder.py | 2 - src/rbyte/io/_mcap/dataframe_builder.py | 2 - src/rbyte/io/dataframe/aligner.py | 2 - src/rbyte/io/dataframe/concater.py | 2 - src/rbyte/io/dataframe/filter.py | 2 - src/rbyte/io/dataframe/fps_resampler.py | 2 - src/rbyte/io/dataframe/indexer.py | 2 - src/rbyte/io/hdf5/dataframe_builder.py | 2 - src/rbyte/io/path/dataframe_builder.py | 2 - src/rbyte/io/rrd/dataframe_builder.py | 2 - src/rbyte/io/video/dataframe_builder.py | 2 - src/rbyte/io/yaak/__init__.py | 7 +-- src/rbyte/io/yaak/dataframe_builder.py | 13 ++--- src/rbyte/sample/fixed_window.py | 2 - 18 files changed, 48 insertions(+), 79 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4550f0..57aab5c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,14 +11,14 @@ repos: - id: pyupgrade - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.2 + rev: v0.9.4 hooks: - id: ruff args: [--fix] - id: ruff-format - repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror - rev: 1.24.0 + rev: 1.26.0 hooks: - id: basedpyright diff --git a/config/_templates/dataset/yaak.yaml b/config/_templates/dataset/yaak.yaml index c3a7d17..c05c3d6 100644 --- a/config/_templates/dataset/yaak.yaml +++ b/config/_templates/dataset/yaak.yaml @@ -36,8 +36,40 @@ inputs: functions: - _target_: pipefunc.PipeFunc func: - _target_: hydra.utils.get_method - path: rbyte.io.build_yaak_metadata_dataframe + _target_: rbyte.io.YaakMetadataDataFrameBuilder + fields: + rbyte.io.yaak.proto.sensor_pb2.ImageMetadata: + time_stamp: + _target_: polars.Datetime + time_unit: ns + + frame_idx: + _target_: polars.Int32 + + camera_name: + _target_: polars.Enum + categories: + - cam_front_center + - cam_front_left + - cam_front_right + - cam_left_forward + - cam_right_forward + - cam_left_backward + - cam_right_backward + - cam_rear + + rbyte.io.yaak.proto.can_pb2.VehicleMotion: + time_stamp: + _target_: polars.Datetime + time_unit: ns + + speed: + _target_: polars.Float32 + + gear: + _target_: polars.Enum + categories: ["0", "1", "2", "3"] + output_name: output scope: metadata cache: true @@ -172,36 +204,4 @@ inputs: kwargs: metadata: path: ${data_dir}/(@=input_id@)/metadata.log - fields: - rbyte.io.yaak.proto.sensor_pb2.ImageMetadata: - time_stamp: - _target_: polars.Datetime - time_unit: ns - - frame_idx: - _target_: polars.Int32 - - camera_name: - _target_: polars.Enum - categories: - - cam_front_center - - cam_front_left - - cam_front_right - - cam_left_forward - - cam_right_forward - - cam_left_backward - - cam_right_backward - - cam_rear - - rbyte.io.yaak.proto.can_pb2.VehicleMotion: - time_stamp: - _target_: polars.Datetime - time_unit: ns - - speed: - _target_: polars.Float32 - - gear: - _target_: polars.Enum - categories: ["0", "1", "2", "3"] #@ end diff --git a/pyproject.toml b/pyproject.toml index 136add9..4651cd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rbyte" -version = "0.11.0" +version = "0.11.1" description = "Multimodal PyTorch dataset library" authors = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] @@ -8,7 +8,7 @@ dependencies = [ "tensordict>=0.6.2", "torch", "numpy", - "polars>=1.18.0", + "polars>=1.21.0", "pydantic>=2.10.2", "more-itertools>=10.5.0", "hydra-core>=1.3.2", @@ -18,7 +18,8 @@ dependencies = [ "parse>=1.20.2", "structlog>=24.4.0", "tqdm>=4.66.5", - "pipefunc>=0.50.0", + "pipefunc>=0.53.0", + "xxhash>=3.5.0", ] readme = "README.md" requires-python = ">=3.12,<3.13" diff --git a/src/rbyte/io/__init__.py b/src/rbyte/io/__init__.py index 723eb84..8af4d9f 100644 --- a/src/rbyte/io/__init__.py +++ b/src/rbyte/io/__init__.py @@ -57,8 +57,8 @@ __all__ += ["VideoDataFrameBuilder"] try: - from .yaak import YaakMetadataDataFrameBuilder, build_yaak_metadata_dataframe + from .yaak import YaakMetadataDataFrameBuilder except ImportError: pass else: - __all__ += ["YaakMetadataDataFrameBuilder", "build_yaak_metadata_dataframe"] + __all__ += ["YaakMetadataDataFrameBuilder"] diff --git a/src/rbyte/io/_json/dataframe_builder.py b/src/rbyte/io/_json/dataframe_builder.py index 79fe202..4c000b3 100644 --- a/src/rbyte/io/_json/dataframe_builder.py +++ b/src/rbyte/io/_json/dataframe_builder.py @@ -19,8 +19,6 @@ @final class JsonDataFrameBuilder: - __name__ = __qualname__ - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__(self, fields: Fields) -> None: self._fields = fields diff --git a/src/rbyte/io/_mcap/dataframe_builder.py b/src/rbyte/io/_mcap/dataframe_builder.py index 2dbc8df..d095528 100644 --- a/src/rbyte/io/_mcap/dataframe_builder.py +++ b/src/rbyte/io/_mcap/dataframe_builder.py @@ -45,8 +45,6 @@ class SpecialField(StrEnum): @final class McapDataFrameBuilder: - __name__ = __qualname__ - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__( self, diff --git a/src/rbyte/io/dataframe/aligner.py b/src/rbyte/io/dataframe/aligner.py index ac0e34b..eafd482 100644 --- a/src/rbyte/io/dataframe/aligner.py +++ b/src/rbyte/io/dataframe/aligner.py @@ -45,8 +45,6 @@ class MergeConfig(BaseModel): @final class DataFrameAligner: - __name__ = __qualname__ - @validate_call def __init__(self, *, fields: Fields, separator: str = "/") -> None: self._fields = fields diff --git a/src/rbyte/io/dataframe/concater.py b/src/rbyte/io/dataframe/concater.py index 709040d..47078fd 100644 --- a/src/rbyte/io/dataframe/concater.py +++ b/src/rbyte/io/dataframe/concater.py @@ -8,8 +8,6 @@ @final class DataFrameConcater: - __name__ = __qualname__ - @validate_call def __init__( self, method: ConcatMethod = "horizontal", separator: str | None = None diff --git a/src/rbyte/io/dataframe/filter.py b/src/rbyte/io/dataframe/filter.py index 258b29d..df00251 100644 --- a/src/rbyte/io/dataframe/filter.py +++ b/src/rbyte/io/dataframe/filter.py @@ -5,8 +5,6 @@ @final class DataFrameFilter: - __name__ = __qualname__ - def __init__(self, predicate: str) -> None: self._query = f"select * from self where {predicate}" # noqa: S608 diff --git a/src/rbyte/io/dataframe/fps_resampler.py b/src/rbyte/io/dataframe/fps_resampler.py index 93df618..0056f80 100644 --- a/src/rbyte/io/dataframe/fps_resampler.py +++ b/src/rbyte/io/dataframe/fps_resampler.py @@ -8,8 +8,6 @@ @final class DataFrameFpsResampler: - __name__ = __qualname__ - IDX_COL = uuid4().hex @validate_call diff --git a/src/rbyte/io/dataframe/indexer.py b/src/rbyte/io/dataframe/indexer.py index 1797a73..5c34937 100644 --- a/src/rbyte/io/dataframe/indexer.py +++ b/src/rbyte/io/dataframe/indexer.py @@ -8,8 +8,6 @@ @final class DataFrameIndexer: - __name__ = __qualname__ - @validate_call def __init__(self, name: str) -> None: self._fn = partial(pl.DataFrame.with_row_index, name=name) diff --git a/src/rbyte/io/hdf5/dataframe_builder.py b/src/rbyte/io/hdf5/dataframe_builder.py index 3f59a7e..b8b29c9 100644 --- a/src/rbyte/io/hdf5/dataframe_builder.py +++ b/src/rbyte/io/hdf5/dataframe_builder.py @@ -18,8 +18,6 @@ @final class Hdf5DataFrameBuilder: - __name__ = __qualname__ - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__(self, fields: Fields) -> None: self._fields = fields diff --git a/src/rbyte/io/path/dataframe_builder.py b/src/rbyte/io/path/dataframe_builder.py index e2e0850..d289d0f 100644 --- a/src/rbyte/io/path/dataframe_builder.py +++ b/src/rbyte/io/path/dataframe_builder.py @@ -22,8 +22,6 @@ @final class PathDataFrameBuilder: - __name__ = __qualname__ - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__(self, fields: Fields) -> None: self._fields = fields diff --git a/src/rbyte/io/rrd/dataframe_builder.py b/src/rbyte/io/rrd/dataframe_builder.py index 0b93c47..1691b3f 100644 --- a/src/rbyte/io/rrd/dataframe_builder.py +++ b/src/rbyte/io/rrd/dataframe_builder.py @@ -17,8 +17,6 @@ class Column(StrEnum): @final class RrdDataFrameBuilder: - __name__ = __qualname__ - @validate_call def __init__( self, index: str, contents: Mapping[str, Sequence[str] | None] diff --git a/src/rbyte/io/video/dataframe_builder.py b/src/rbyte/io/video/dataframe_builder.py index 980aa5c..32adf74 100644 --- a/src/rbyte/io/video/dataframe_builder.py +++ b/src/rbyte/io/video/dataframe_builder.py @@ -18,8 +18,6 @@ @final class VideoDataFrameBuilder: - __name__ = __qualname__ - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__(self, fields: Fields) -> None: self._fields = fields diff --git a/src/rbyte/io/yaak/__init__.py b/src/rbyte/io/yaak/__init__.py index 73fb791..845e6a0 100644 --- a/src/rbyte/io/yaak/__init__.py +++ b/src/rbyte/io/yaak/__init__.py @@ -1,6 +1,3 @@ -from .dataframe_builder import ( - YaakMetadataDataFrameBuilder, - build_yaak_metadata_dataframe, -) +from .dataframe_builder import YaakMetadataDataFrameBuilder -__all__ = ["YaakMetadataDataFrameBuilder", "build_yaak_metadata_dataframe"] +__all__ = ["YaakMetadataDataFrameBuilder"] diff --git a/src/rbyte/io/yaak/dataframe_builder.py b/src/rbyte/io/yaak/dataframe_builder.py index d4d7172..d7e1430 100644 --- a/src/rbyte/io/yaak/dataframe_builder.py +++ b/src/rbyte/io/yaak/dataframe_builder.py @@ -17,6 +17,7 @@ from pydantic import ConfigDict, ImportString, validate_call from structlog import get_logger from tqdm import tqdm +from xxhash import xxh3_64_hexdigest as digest from .message_iterator import YaakMetadataMessageIterator from .proto import sensor_pb2 @@ -31,14 +32,15 @@ @final class YaakMetadataDataFrameBuilder: - __name__ = __qualname__ - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__(self, *, fields: Fields) -> None: super().__init__() self._fields = fields + def __pipefunc_hash__(self) -> str: # noqa: PLW3201 + return digest(str(self._fields)) + def __call__(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: with Path(path).open("rb") as _f, mmap(_f.fileno(), 0, access=ACCESS_READ) as f: handler_pool = HandlerPool() @@ -77,10 +79,3 @@ def __call__(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: } return dfs - - -# exposing all kwargs so its cacheable by pipefunc -def build_yaak_metadata_dataframe( - *, path: PathLike[str], fields: Fields -) -> Mapping[str, pl.DataFrame]: - return YaakMetadataDataFrameBuilder(fields=fields)(path) diff --git a/src/rbyte/sample/fixed_window.py b/src/rbyte/sample/fixed_window.py index 8ca8dda..305c489 100644 --- a/src/rbyte/sample/fixed_window.py +++ b/src/rbyte/sample/fixed_window.py @@ -16,8 +16,6 @@ class FixedWindowSampleBuilder: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.group_by_dynamic """ - __name__ = __qualname__ - @validate_call def __init__( # noqa: PLR0913 self,