Skip to content

Commit

Permalink
feat: batch key selection (#32)
Browse files Browse the repository at this point in the history
- expose `Dataset.get_batch(keys=...)`, mimicking `TensorDict.select`
- bump `tensordict>=0.7.0`
  • Loading branch information
egorchakov authored Feb 10, 2025
1 parent 9585c74 commit df252dd
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 76 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
- id: pyupgrade

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.4
rev: v0.9.5
hooks:
- id: ruff
args: [--fix]
Expand Down
18 changes: 8 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
[project]
name = "rbyte"
version = "0.11.1"
version = "0.12.0"
description = "Multimodal PyTorch dataset library"
authors = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
maintainers = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
dependencies = [
"tensordict>=0.6.2",
"tensordict>=0.7.0",
"torch",
"numpy",
"polars>=1.21.0",
Expand Down Expand Up @@ -39,7 +39,7 @@ repo = "https://github.com/yaak-ai/rbyte"

[project.optional-dependencies]
build = ["hatchling>=1.25.0", "grpcio-tools>=1.62.0", "protoletariat==3.2.19"]
visualize = ["rerun-sdk[notebook]>=0.21.0"]
visualize = ["rerun-sdk[notebook]==0.21.0"]
mcap = [
"mcap>=1.2.1",
"mcap-ros2-support>=0.5.5",
Expand All @@ -53,7 +53,7 @@ video = [
"video-reader-rs>=0.2.2",
]
hdf5 = ["h5py>=3.12.1"]
rrd = ["rerun-sdk>=0.21.0", "pyarrow-stubs"]
rrd = ["rerun-sdk==0.21.0", "pyarrow-stubs"]

[project.scripts]
rbyte-visualize = 'rbyte.scripts.visualize:main'
Expand All @@ -66,20 +66,18 @@ requires = [
]
build-backend = "hatchling.build"

[tool.uv]
dev-dependencies = [
[dependency-groups]
dev = [
"wat-inspector>=0.4.3",
"lovely-tensors>=0.1.18",
"pudb>=2024.1.2",
"ipython>=8.30.0",
"ipython>=8.32.0",
"ipython-autoimport>=0.5",
"pytest>=8.3.3",
"pytest>=8.3.4",
"testbook>=0.4.2",
"ipykernel>=6.29.5",
]

[tool.uv.sources]

[tool.hatch.metadata]
allow-direct-references = true

Expand Down
27 changes: 27 additions & 0 deletions src/rbyte/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Literal

from tensordict import (
NonTensorData, # pyright: ignore[reportAttributeAccessIssue, reportUnknownVariableType]
TensorClass,
TensorDict,
)
from torch import Tensor


class BatchMeta(TensorClass, autocast=True): # pyright: ignore[reportGeneralTypeIssues, reportCallIssue]
    """Per-sample metadata carried alongside a batch's tensor data.

    Both fields default to ``None`` so that a batch can be constructed with
    only the requested keys (see ``Dataset.get_batch(keys=...)``).
    """

    # Dataset-level index of each sample; presumably one entry per batch
    # element — confirm against Dataset.get_batch, which fills it from the
    # sample_idx column.
    sample_idx: Tensor | None = None
    # Identifier of the input each sample came from; non-tensor payload
    # stored via tensordict's NonTensorData wrapper.
    input_id: NonTensorData | None = None # pyright: ignore[reportUnknownVariableType]


class Batch(TensorClass, autocast=True): # pyright: ignore[reportGeneralTypeIssues, reportCallIssue]
    """A batch of samples: optional tensor data plus optional metadata.

    Either field may be ``None`` when the corresponding top-level key was
    not requested at construction time.
    """

    # TensorDict of batched tensors, keyed by source/column id.
    data: TensorDict | None = None # pyright: ignore[reportIncompatibleMethodOverride]
    # Per-sample metadata (sample_idx, input_id); see BatchMeta.
    meta: BatchMeta | None = None


# Keys accepted by Dataset.get_batch(keys=...): a bare top-level key
# ("data" / "meta") selects everything under it, while a tuple
# ("data", <entry name>) or ("meta", "sample_idx" | "input_id") selects a
# single sub-entry of that field.
type BatchKeys = frozenset[
    Literal["data", "meta"]
    | tuple[Literal["data"], str]
    | tuple[Literal["meta"], Literal["sample_idx", "input_id"]]
]

# Default selection: both top-level keys, i.e. all data and all metadata.
BATCH_KEYS_DEFAULT: BatchKeys = frozenset(("data", "meta"))
3 changes: 0 additions & 3 deletions src/rbyte/batch/__init__.py

This file was deleted.

18 changes: 0 additions & 18 deletions src/rbyte/batch/batch.py

This file was deleted.

149 changes: 108 additions & 41 deletions src/rbyte/dataset.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from collections.abc import Mapping, Sequence
from enum import StrEnum, unique
from functools import cache
from typing import Annotated
from typing import Annotated, Literal, override

import polars as pl
import torch
from hydra.utils import instantiate
from pipefunc import Pipeline
from pydantic import Field, StringConstraints, validate_call
from pydantic import ConfigDict, Field, StringConstraints, validate_call
from structlog import get_logger
from structlog.contextvars import bound_contextvars
from tensordict import TensorDict
from torch.utils.data import Dataset as TorchDataset

from rbyte.batch import Batch, BatchMeta
from rbyte.batch import BATCH_KEYS_DEFAULT, Batch, BatchKeys, BatchMeta
from rbyte.config import BaseModel, HydraConfig
from rbyte.io.base import TensorSource
from rbyte.utils.tensor import pad_sequence
Expand Down Expand Up @@ -53,7 +53,14 @@ class Column(StrEnum):
source_index_column = "__source.index_column"


class Dataset(TorchDataset[TensorDict]):
class _ALL_TYPE: # noqa: N801
pass


_ALL = _ALL_TYPE()


class Dataset(TorchDataset[Batch]):
@validate_call(config=BaseModel.model_config)
def __init__(
self, inputs: Annotated[Mapping[Id, InputConfig], Field(min_length=1)]
Expand Down Expand Up @@ -136,52 +143,112 @@ def sources(self) -> pl.DataFrame:
def _get_source(self, config: str) -> TensorSource: # noqa: PLR6301
return HydraConfig[TensorSource].model_validate_json(config).instantiate()

def __getitems__(self, indexes: Sequence[int]) -> Batch: # noqa: PLW3201
samples = self.samples[indexes]
batch_size = [samples.height]
@validate_call(
config=ConfigDict(arbitrary_types_allowed=True, validate_default=False)
)
def get_batch(
self,
index: int | Sequence[int] | slice | range,
*,
keys: BatchKeys = BATCH_KEYS_DEFAULT,
) -> Batch:
subkeys: Mapping[Literal["data", "meta"], set[_ALL_TYPE | str]] = {
"data": set(),
"meta": set(),
}
for key in keys:
match key:
case "data" | "meta":
subkeys[key].add(_ALL)

source_idx_cols = self._sources[Column.source_index_column].unique()
case ("data" | "meta", _):
subkeys[key[0]].add(key[1])

sources = (
samples.lazy()
.join(self.sources.lazy(), on=Column.input_id, how="left")
.with_columns(
pl.coalesce(
pl.when(pl.col(Column.source_index_column) == idx_col).then(idx_col)
for idx_col in source_idx_cols
).alias(Column.source_idxs)
for v in subkeys.values():
if _ALL in v and len(v) > 1:
v.remove(_ALL)

samples = self.samples[index]
batch_size = [samples.height]

if subkeys_data := subkeys["data"]:
source_idx_cols = self._sources[Column.source_index_column].unique()
sources = (
samples.lazy()
.join(self.sources.lazy(), on=Column.input_id, how="left")
.with_columns(
pl.coalesce(
pl.when(pl.col(Column.source_index_column) == idx_col).then(
idx_col
)
for idx_col in source_idx_cols
).alias(Column.source_idxs)
)
.group_by(Column.source_id)
.agg(Column.source_config, Column.source_idxs)
.filter(
True
if _ALL in subkeys_data
else pl.col(Column.source_id).is_in(subkeys_data)
)
)
.group_by(Column.source_id)
.agg(Column.source_config, Column.source_idxs)
)

tensor_data: Mapping[str, torch.Tensor] = {
row[Column.source_id]: pad_sequence(
[
self._get_source(source)[idxs]
for (source, idxs) in zip(
row[Column.source_config], row[Column.source_idxs], strict=True
)
],
dim=1,
value=torch.nan,
source_data = {
row[Column.source_id]: pad_sequence(
[
self._get_source(source)[idxs]
for (source, idxs) in zip(
row[Column.source_config],
row[Column.source_idxs],
strict=True,
)
],
dim=1,
value=torch.nan,
)
for row in sources.collect().iter_rows(named=True)
}

sample_data_cols = (
pl.all()
if _ALL in subkeys_data
else pl.col(subkeys_data - source_data.keys()) # pyright: ignore[reportArgumentType]
).exclude(Column.sample_idx, Column.input_id)

sample_data = samples.select(sample_data_cols.to_physical()).to_dict(
as_series=False
)
for row in sources.collect().iter_rows(named=True)
}

sample_data: Mapping[str, Sequence[object]] = samples.select(
pl.exclude(Column.sample_idx, Column.input_id).to_physical()
).to_dict(as_series=False)
data = TensorDict(source_data | sample_data, batch_size=batch_size) # pyright: ignore[reportArgumentType]

else:
data = None

if subkeys_meta := subkeys["meta"]:
meta = BatchMeta(
sample_idx=(
samples[Column.sample_idx].to_torch()
if _ALL in subkeys_meta or "sample_idx" in subkeys_meta
else None
),
input_id=(
samples[Column.input_id].to_list()
if _ALL in subkeys_meta or "input_id" in subkeys_meta
else None
),
batch_size=batch_size,
)
else:
meta = None

data = TensorDict(tensor_data | sample_data, batch_size=batch_size) # pyright: ignore[reportArgumentType]
return Batch(data=data, meta=meta, batch_size=batch_size)

meta = BatchMeta(
sample_idx=samples[Column.sample_idx].to_torch(), # pyright: ignore[reportCallIssue]
input_id=samples[Column.input_id].to_list(), # pyright: ignore[reportCallIssue]
batch_size=batch_size, # pyright: ignore[reportCallIssue]
)
def __getitems__(self, index: Sequence[int]) -> Batch: # noqa: PLW3201
return self.get_batch(index)

return Batch(data=data, meta=meta, batch_size=batch_size) # pyright: ignore[reportCallIssue]
@override
def __getitem__(self, index: int) -> Batch:
return self.get_batch(index)

def __len__(self) -> int:
return len(self.samples)
2 changes: 1 addition & 1 deletion src/rbyte/io/yaak/idl-repo
4 changes: 2 additions & 2 deletions src/rbyte/viz/loggers/rerun_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def _build_components(

@override
def log(self, batch_idx: int, batch: Batch) -> None:
for i, sample in enumerate(batch.data): # pyright: ignore[reportUnknownVariableType]
with self._get_recording(batch.meta.input_id[i]): # pyright: ignore[reportUnknownArgumentType, reportIndexIssue]
for i, sample in enumerate(batch.data): # pyright: ignore[reportArgumentType, reportUnknownVariableType]
with self._get_recording(batch.meta.input_id[i]): # pyright: ignore[reportUnknownArgumentType, reportOptionalSubscript, reportUnknownMemberType, reportOptionalMemberAccess]
times: Sequence[TimeColumn] = [
column(
timeline=timeline,
Expand Down
Loading

0 comments on commit df252dd

Please sign in to comment.