Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: sample builders #24

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
---
repos:
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.22
rev: v0.23
hooks:
- id: validate-pyproject

Expand All @@ -11,14 +11,14 @@ repos:
- id: pyupgrade

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.2
rev: v0.7.4
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

- repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror
rev: 1.21.0
rev: 1.21.1
hooks:
- id: basedpyright

Expand Down
7 changes: 2 additions & 5 deletions config/_templates/dataset/carla.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,6 @@ inputs:
#@ end

sample_builder:
_target_: rbyte.sample.GreedySampleBuilder
_target_: rbyte.RollingWindowSampleBuilder
index_column: _idx_
length: 1
stride: 1
min_step: 1
filter: !!null
period: 1i
7 changes: 2 additions & 5 deletions config/_templates/dataset/mimicgen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ inputs:
#@ end

sample_builder:
_target_: rbyte.sample.GreedySampleBuilder
_target_: rbyte.RollingWindowSampleBuilder
index_column: _idx_
length: 1
stride: 1
min_step: 1
filter: !!null
period: 1i
10 changes: 3 additions & 7 deletions config/_templates/dataset/nuscenes/mcap.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@ inputs:
fields:
#@ for topic in camera_topics.values():
(@=topic@):
_idx_:
log_time:
_target_: polars.Datetime
time_unit: ns

_idx_:
#@ end

/odom:
Expand Down Expand Up @@ -94,9 +93,6 @@ inputs:
#@ end

sample_builder:
_target_: rbyte.sample.GreedySampleBuilder
_target_: rbyte.RollingWindowSampleBuilder
index_column: mcap/(@=camera_topics.values()[0]@)/_idx_
length: 1
stride: 1
min_step: 1
filter: !!null
period: 1i
8 changes: 2 additions & 6 deletions config/_templates/dataset/nuscenes/rrd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,6 @@ inputs:
#@ end

sample_builder:
_target_: rbyte.sample.GreedySampleBuilder
_target_: rbyte.RollingWindowSampleBuilder
index_column: rrd/(@=camera_entities.values()[0]@)/_idx_
length: 1
stride: 1
min_step: 1
filter: !!null

period: 1i
12 changes: 6 additions & 6 deletions config/_templates/dataset/yaak.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ inputs:
_target_: polars.Datetime
time_unit: ns

frame_idx: polars.UInt32
frame_idx: polars.Int32
camera_name:
_target_: polars.Enum
categories:
Expand Down Expand Up @@ -130,10 +130,10 @@ inputs:
#@ end

sample_builder:
_target_: rbyte.sample.GreedySampleBuilder
_target_: rbyte.FixedWindowSampleBuilder
index_column: meta/ImageMetadata.(@=cameras[0]@)/frame_idx
length: 1
stride: 1
min_step: 1
every: 6i
period: 6i
filter: |
array_mean(`meta/VehicleMotion/speed`) > 40
array_length(`meta/ImageMetadata.(@=cameras[0]@)/time_stamp`) == 6
and array_mean(`meta/VehicleMotion/speed`) > 40
23 changes: 12 additions & 11 deletions config/_templates/dataset/zod.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ inputs:
index_column: camera_front_blur/timestamp
source:
_target_: rbyte.io.PathTensorSource
path: "${data_dir}/zod/sequences/000002_short/camera_front_blur/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.jpg"
path: "${data_dir}/sequences/000002_short/camera_front_blur/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.jpg"
decoder:
_target_: simplejpeg.decode_jpeg
_partial_: true
Expand All @@ -21,15 +21,15 @@ inputs:
index_column: lidar_velodyne/timestamp
source:
_target_: rbyte.io.NumpyTensorSource
path: "${data_dir}/zod/sequences/000002_short/lidar_velodyne/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.npy"
path: "${data_dir}/sequences/000002_short/lidar_velodyne/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.npy"
select: ["x", "y", "z"]

table_builder:
_target_: rbyte.io.TableBuilder
_convert_: all
readers:
camera_front_blur:
path: "${data_dir}/zod/sequences/000002_short/camera_front_blur/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.jpg"
path: "${data_dir}/sequences/000002_short/camera_front_blur/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.jpg"
reader:
_target_: rbyte.io.PathTableReader
_recursive_: false
Expand All @@ -39,7 +39,7 @@ inputs:
time_unit: ns

lidar_velodyne:
path: "${data_dir}/zod/sequences/000002_short/lidar_velodyne/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.npy"
path: "${data_dir}/sequences/000002_short/lidar_velodyne/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.npy"
reader:
_target_: rbyte.io.PathTableReader
_recursive_: false
Expand All @@ -49,7 +49,7 @@ inputs:
time_unit: ns

vehicle_data:
path: "${data_dir}/zod/sequences/000002_short/vehicle_data.hdf5"
path: "${data_dir}/sequences/000002_short/vehicle_data.hdf5"
reader:
_target_: rbyte.io.Hdf5TableReader
_recursive_: false
Expand Down Expand Up @@ -82,7 +82,7 @@ inputs:
timestamp:
method: asof
strategy: nearest
tolerance: 50ms
tolerance: 100ms

vehicle_data:
ego_vehicle_controls:
Expand All @@ -91,17 +91,17 @@ inputs:
timestamp/nanoseconds/value:
method: asof
strategy: nearest
tolerance: 50ms
tolerance: 100ms

acceleration_pedal/ratio/unitless/value:
method: asof
strategy: nearest
tolerance: 50ms
tolerance: 100ms

steering_wheel_angle/angle/radians/value:
method: asof
strategy: nearest
tolerance: 50ms
tolerance: 100ms

satellite:
key: timestamp/nanoseconds/value
Expand All @@ -110,5 +110,6 @@ inputs:
method: interp

sample_builder:
_target_: rbyte.sample.GreedySampleBuilder
length: 1
_target_: rbyte.FixedWindowSampleBuilder
index_column: camera_front_blur/timestamp
every: 300ms
4 changes: 2 additions & 2 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ generate-config:
--strict

test *ARGS: generate-config
uv run pytest --capture=no {{ ARGS }}
uv run --all-extras pytest --capture=no {{ ARGS }}

notebook FILE *ARGS: sync generate-config
uv run --with=jupyter,jupyterlab-vim,rerun-notebook jupyter lab {{ FILE }} {{ ARGS }}
uv run --all-extras --with=jupyter,jupyterlab-vim,rerun-notebook jupyter lab {{ FILE }} {{ ARGS }}

[group('scripts')]
visualize *ARGS: generate-config
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "rbyte"
version = "0.7.0"
version = "0.8.0"
description = "Multimodal PyTorch dataset library"
authors = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
maintainers = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
Expand Down
8 changes: 7 additions & 1 deletion src/rbyte/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from importlib.metadata import version

from .dataset import Dataset
from .sample import FixedWindowSampleBuilder, RollingWindowSampleBuilder

__version__ = version(__package__ or __name__)

__all__ = ["Dataset", "__version__"]
__all__ = [
"Dataset",
"FixedWindowSampleBuilder",
"RollingWindowSampleBuilder",
"__version__",
]
12 changes: 5 additions & 7 deletions src/rbyte/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ def __init__(
super().__init__()

_sample_builder = sample_builder.instantiate()
samples: Mapping[str, pl.LazyFrame] = {}
samples: Mapping[str, pl.DataFrame] = {}
for input_id, input_cfg in inputs.items():
with bound_contextvars(input_id=input_id):
table = input_cfg.table_builder.instantiate().build().lazy()
table = input_cfg.table_builder.instantiate().build()
samples[input_id] = _sample_builder.build(table)
logger.debug(
"built samples",
rows=table.select(pl.len()).collect().item(),
samples=samples[input_id].select(pl.len()).collect().item(),
rows=table.select(pl.len()).item(),
samples=samples[input_id].select(pl.len()).item(),
)

input_id_enum = pl.Enum(sorted(samples))
Expand All @@ -85,12 +85,11 @@ def __init__(
)
.sort(Column.input_id)
.with_row_index(Column.sample_idx)
.collect()
.rechunk()
)

self._sources: pl.DataFrame = (
pl.LazyFrame(
pl.DataFrame(
[
{
Column.input_id: input_id,
Expand All @@ -112,7 +111,6 @@ def __init__(
.explode(k)
.unnest(k)
.select(Column.input_id, pl.exclude(Column.input_id).name.prefix(f"{k}."))
.collect()
.rechunk()
)

Expand Down
5 changes: 3 additions & 2 deletions src/rbyte/sample/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .greedy_builder import GreedySampleBuilder
from .fixed_window import FixedWindowSampleBuilder
from .rolling_window import RollingWindowSampleBuilder

__all__ = ["GreedySampleBuilder"]
__all__ = ["FixedWindowSampleBuilder", "RollingWindowSampleBuilder"]
2 changes: 1 addition & 1 deletion src/rbyte/sample/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@

@runtime_checkable
class SampleBuilder(Protocol):
def build(self, source: pl.LazyFrame) -> pl.LazyFrame: ...
def build(self, source: pl.DataFrame) -> pl.DataFrame: ...
54 changes: 54 additions & 0 deletions src/rbyte/sample/fixed_window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from datetime import timedelta
from typing import Literal, override
from uuid import uuid4

import polars as pl
from polars._typing import ClosedInterval
from pydantic import validate_call

from .base import SampleBuilder


class FixedWindowSampleBuilder(SampleBuilder):
"""
Build samples using fixed (potentially overlapping) windows based on a temporal or
integer column.

https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.group_by_dynamic
"""

@validate_call
def __init__(
self,
*,
index_column: str,
every: str | timedelta,
period: str | timedelta | None = None,
closed: ClosedInterval = "left",
filter: str | None = None, # noqa: A002
) -> None:
self._index_column: pl.Expr = pl.col(index_column)
self._every: str | timedelta = every
self._period: str | timedelta | None = period
self._closed: ClosedInterval = closed
self._filter: str | Literal[True] = filter if filter is not None else True

@override
def build(self, source: pl.DataFrame) -> pl.DataFrame:
return (
source.sort(self._index_column)
.with_columns(self._index_column.alias(_index_column := uuid4().hex))
.group_by_dynamic(
index_column=_index_column,
every=self._every,
period=self._period,
closed=self._closed,
label="datapoint",
start_by="datapoint",
)
.agg(pl.all())
.sql(f"select * from self where ({self._filter})") # noqa: S608
.filter(self._index_column.list.len() > 0)
.sort(_index_column)
.drop(_index_column)
)
Loading
Loading