Skip to content

Commit

Permalink
feat: cacheable YaakMetadataDataFrameBuilder (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
egorchakov authored Feb 3, 2025
1 parent ffd14dc commit 9585c74
Show file tree
Hide file tree
Showing 18 changed files with 48 additions and 79 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ repos:
- id: pyupgrade

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.2
rev: v0.9.4
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

- repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror
rev: 1.24.0
rev: 1.26.0
hooks:
- id: basedpyright

Expand Down
68 changes: 34 additions & 34 deletions config/_templates/dataset/yaak.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,40 @@ inputs:
functions:
- _target_: pipefunc.PipeFunc
func:
_target_: hydra.utils.get_method
path: rbyte.io.build_yaak_metadata_dataframe
_target_: rbyte.io.YaakMetadataDataFrameBuilder
fields:
rbyte.io.yaak.proto.sensor_pb2.ImageMetadata:
time_stamp:
_target_: polars.Datetime
time_unit: ns

frame_idx:
_target_: polars.Int32

camera_name:
_target_: polars.Enum
categories:
- cam_front_center
- cam_front_left
- cam_front_right
- cam_left_forward
- cam_right_forward
- cam_left_backward
- cam_right_backward
- cam_rear

rbyte.io.yaak.proto.can_pb2.VehicleMotion:
time_stamp:
_target_: polars.Datetime
time_unit: ns

speed:
_target_: polars.Float32

gear:
_target_: polars.Enum
categories: ["0", "1", "2", "3"]

output_name: output
scope: metadata
cache: true
Expand Down Expand Up @@ -172,36 +204,4 @@ inputs:
kwargs:
metadata:
path: ${data_dir}/(@=input_id@)/metadata.log
fields:
rbyte.io.yaak.proto.sensor_pb2.ImageMetadata:
time_stamp:
_target_: polars.Datetime
time_unit: ns

frame_idx:
_target_: polars.Int32

camera_name:
_target_: polars.Enum
categories:
- cam_front_center
- cam_front_left
- cam_front_right
- cam_left_forward
- cam_right_forward
- cam_left_backward
- cam_right_backward
- cam_rear

rbyte.io.yaak.proto.can_pb2.VehicleMotion:
time_stamp:
_target_: polars.Datetime
time_unit: ns

speed:
_target_: polars.Float32

gear:
_target_: polars.Enum
categories: ["0", "1", "2", "3"]
#@ end
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[project]
name = "rbyte"
version = "0.11.0"
version = "0.11.1"
description = "Multimodal PyTorch dataset library"
authors = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
maintainers = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
dependencies = [
"tensordict>=0.6.2",
"torch",
"numpy",
"polars>=1.18.0",
"polars>=1.21.0",
"pydantic>=2.10.2",
"more-itertools>=10.5.0",
"hydra-core>=1.3.2",
Expand All @@ -18,7 +18,8 @@ dependencies = [
"parse>=1.20.2",
"structlog>=24.4.0",
"tqdm>=4.66.5",
"pipefunc>=0.50.0",
"pipefunc>=0.53.0",
"xxhash>=3.5.0",
]
readme = "README.md"
requires-python = ">=3.12,<3.13"
Expand Down
4 changes: 2 additions & 2 deletions src/rbyte/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@
__all__ += ["VideoDataFrameBuilder"]

try:
from .yaak import YaakMetadataDataFrameBuilder, build_yaak_metadata_dataframe
from .yaak import YaakMetadataDataFrameBuilder
except ImportError:
pass
else:
__all__ += ["YaakMetadataDataFrameBuilder", "build_yaak_metadata_dataframe"]
__all__ += ["YaakMetadataDataFrameBuilder"]
2 changes: 0 additions & 2 deletions src/rbyte/io/_json/dataframe_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

@final
class JsonDataFrameBuilder:
__name__ = __qualname__

@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def __init__(self, fields: Fields) -> None:
self._fields = fields
Expand Down
2 changes: 0 additions & 2 deletions src/rbyte/io/_mcap/dataframe_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ class SpecialField(StrEnum):

@final
class McapDataFrameBuilder:
__name__ = __qualname__

@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def __init__(
self,
Expand Down
2 changes: 0 additions & 2 deletions src/rbyte/io/dataframe/aligner.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ class MergeConfig(BaseModel):

@final
class DataFrameAligner:
__name__ = __qualname__

@validate_call
def __init__(self, *, fields: Fields, separator: str = "/") -> None:
self._fields = fields
Expand Down
2 changes: 0 additions & 2 deletions src/rbyte/io/dataframe/concater.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

@final
class DataFrameConcater:
__name__ = __qualname__

@validate_call
def __init__(
self, method: ConcatMethod = "horizontal", separator: str | None = None
Expand Down
2 changes: 0 additions & 2 deletions src/rbyte/io/dataframe/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

@final
class DataFrameFilter:
__name__ = __qualname__

def __init__(self, predicate: str) -> None:
self._query = f"select * from self where {predicate}" # noqa: S608

Expand Down
2 changes: 0 additions & 2 deletions src/rbyte/io/dataframe/fps_resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

@final
class DataFrameFpsResampler:
__name__ = __qualname__

IDX_COL = uuid4().hex

@validate_call
Expand Down
2 changes: 0 additions & 2 deletions src/rbyte/io/dataframe/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

@final
class DataFrameIndexer:
__name__ = __qualname__

@validate_call
def __init__(self, name: str) -> None:
self._fn = partial(pl.DataFrame.with_row_index, name=name)
Expand Down
2 changes: 0 additions & 2 deletions src/rbyte/io/hdf5/dataframe_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

@final
class Hdf5DataFrameBuilder:
__name__ = __qualname__

@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def __init__(self, fields: Fields) -> None:
self._fields = fields
Expand Down
2 changes: 0 additions & 2 deletions src/rbyte/io/path/dataframe_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@

@final
class PathDataFrameBuilder:
__name__ = __qualname__

@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def __init__(self, fields: Fields) -> None:
self._fields = fields
Expand Down
2 changes: 0 additions & 2 deletions src/rbyte/io/rrd/dataframe_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ class Column(StrEnum):

@final
class RrdDataFrameBuilder:
__name__ = __qualname__

@validate_call
def __init__(
self, index: str, contents: Mapping[str, Sequence[str] | None]
Expand Down
2 changes: 0 additions & 2 deletions src/rbyte/io/video/dataframe_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

@final
class VideoDataFrameBuilder:
__name__ = __qualname__

@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def __init__(self, fields: Fields) -> None:
self._fields = fields
Expand Down
7 changes: 2 additions & 5 deletions src/rbyte/io/yaak/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from .dataframe_builder import (
YaakMetadataDataFrameBuilder,
build_yaak_metadata_dataframe,
)
from .dataframe_builder import YaakMetadataDataFrameBuilder

__all__ = ["YaakMetadataDataFrameBuilder", "build_yaak_metadata_dataframe"]
__all__ = ["YaakMetadataDataFrameBuilder"]
13 changes: 4 additions & 9 deletions src/rbyte/io/yaak/dataframe_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pydantic import ConfigDict, ImportString, validate_call
from structlog import get_logger
from tqdm import tqdm
from xxhash import xxh3_64_hexdigest as digest

from .message_iterator import YaakMetadataMessageIterator
from .proto import sensor_pb2
Expand All @@ -31,14 +32,15 @@

@final
class YaakMetadataDataFrameBuilder:
__name__ = __qualname__

# Keyword-only constructor. Arguments are validated by pydantic's
# `validate_call` (with arbitrary types allowed), and `fields` is kept on the
# instance as `_fields`, which `__pipefunc_hash__` later hashes.
# NOTE(review): `fields` presumably maps protobuf message types to
# {column name: polars dtype}, as in config/_templates/dataset/yaak.yaml —
# confirm against the `Fields` type alias.
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def __init__(self, *, fields: Fields) -> None:
super().__init__()

self._fields = fields

# Returns the xxh3_64 hex digest of `str(self._fields)`, i.e. a hash that is
# derived only from the configured fields.
# NOTE(review): presumably pipefunc looks up this hook when computing cache
# keys for a callable object — confirm against the pipefunc caching docs.
# NOTE(review): relies on `str(self._fields)` being stable across runs and
# distinct for distinct field configurations — verify for the dtype objects used.
def __pipefunc_hash__(self) -> str:  # noqa: PLW3201
return digest(str(self._fields))

def __call__(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]:
with Path(path).open("rb") as _f, mmap(_f.fileno(), 0, access=ACCESS_READ) as f:
handler_pool = HandlerPool()
Expand Down Expand Up @@ -77,10 +79,3 @@ def __call__(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]:
}

return dfs


# Function-style wrapper around `YaakMetadataDataFrameBuilder`: builds the
# builder from `fields` and immediately calls it with `path`.
# Both inputs are exposed as keyword-only arguments so that the call is
# cacheable by pipefunc.
def build_yaak_metadata_dataframe(
*, path: PathLike[str], fields: Fields
) -> Mapping[str, pl.DataFrame]:
return YaakMetadataDataFrameBuilder(fields=fields)(path)
2 changes: 0 additions & 2 deletions src/rbyte/sample/fixed_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ class FixedWindowSampleBuilder:
https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.group_by_dynamic
"""

__name__ = __qualname__

@validate_call
def __init__( # noqa: PLR0913
self,
Expand Down

0 comments on commit 9585c74

Please sign in to comment.