diff --git a/config/_templates/dataset/mimicgen.yaml b/config/_templates/dataset/mimicgen.yaml index 392f6cd..a0f93c6 100644 --- a/config/_templates/dataset/mimicgen.yaml +++ b/config/_templates/dataset/mimicgen.yaml @@ -42,6 +42,7 @@ samples: executor: _target_: concurrent.futures.ThreadPoolExecutor + storage: dict pipeline: _target_: pipefunc.Pipeline validate_type_annotations: false diff --git a/config/_templates/dataset/nuscenes/mcap.yaml b/config/_templates/dataset/nuscenes/mcap.yaml index 235fa55..886ba8b 100644 --- a/config/_templates/dataset/nuscenes/mcap.yaml +++ b/config/_templates/dataset/nuscenes/mcap.yaml @@ -43,6 +43,7 @@ samples: executor: _target_: concurrent.futures.ThreadPoolExecutor + storage: dict pipeline: _target_: pipefunc.Pipeline validate_type_annotations: false diff --git a/config/_templates/dataset/nuscenes/rrd.yaml b/config/_templates/dataset/nuscenes/rrd.yaml index a46c0ee..1239073 100644 --- a/config/_templates/dataset/nuscenes/rrd.yaml +++ b/config/_templates/dataset/nuscenes/rrd.yaml @@ -43,6 +43,7 @@ samples: executor: _target_: concurrent.futures.ThreadPoolExecutor + storage: dict pipeline: _target_: pipefunc.Pipeline validate_type_annotations: false diff --git a/config/_templates/dataset/yaak.yaml b/config/_templates/dataset/yaak.yaml index e1eaa70..8cb723b 100644 --- a/config/_templates/dataset/yaak.yaml +++ b/config/_templates/dataset/yaak.yaml @@ -40,6 +40,7 @@ samples: executor: _target_: concurrent.futures.ThreadPoolExecutor + storage: dict pipeline: _target_: pipefunc.Pipeline validate_type_annotations: false diff --git a/config/_templates/dataset/zod.yaml b/config/_templates/dataset/zod.yaml index b04b93e..dd6f5d3 100644 --- a/config/_templates/dataset/zod.yaml +++ b/config/_templates/dataset/zod.yaml @@ -33,6 +33,7 @@ samples: executor: _target_: concurrent.futures.ThreadPoolExecutor + storage: dict pipeline: _target_: pipefunc.Pipeline validate_type_annotations: false diff --git a/pyproject.toml b/pyproject.toml index 1997f1c..96945d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rbyte" -version = "0.14.0" +version = "0.14.1" description = "Multimodal PyTorch dataset library" authors = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] @@ -16,7 +16,7 @@ dependencies = [ "cachetools>=5.5.1", "structlog>=25.1.0", "tqdm>=4.67.1", - "pipefunc[autodoc]>=0.56.0", + "pipefunc[autodoc]>=0.57.0", "xxhash>=3.5.0", ] readme = "README.md" diff --git a/src/rbyte/dataset.py b/src/rbyte/dataset.py index d198700..3c54685 100644 --- a/src/rbyte/dataset.py +++ b/src/rbyte/dataset.py @@ -2,15 +2,13 @@ from concurrent.futures import Executor from enum import StrEnum, unique from functools import cache -from pathlib import Path -from typing import Annotated, Any, Literal, override +from typing import Annotated, Any, ClassVar, Literal, override import polars as pl import torch from optree import tree_map, tree_structure, tree_transpose from pipefunc import Pipeline -from pipefunc._pipeline._types import OUTPUT_TYPE, StorageType -from pipefunc.map import run_map +from pipefunc._pipeline._types import OUTPUT_TYPE from pydantic import ConfigDict, StringConstraints, validate_call from structlog import get_logger from tensordict import TensorDict @@ -39,20 +37,14 @@ class SourceConfig(BaseModel): class PipelineConfig(BaseModel): + model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow") + pipeline: HydraConfig[Pipeline] inputs: Mapping[str, Any] - run_folder: str | Path | None = None - parallel: bool = True executor: ( HydraConfig[Executor] | dict[OUTPUT_TYPE, HydraConfig[Executor]] | None ) = None - chunksizes: int | dict[OUTPUT_TYPE, int] | None = None - storage: StorageType = "dict" - persist_memory: bool = True - cleanup: bool = True - fixed_indices: dict[str, int | slice] | None = None - auto_subpipeline: bool = False - show_progress: bool = False + return_results: Literal[True] = True @unique @@ -226,8 +218,7 @@ def _build_samples(cls, samples: PipelineConfig) -> pl.DataFrame: samples.executor, # pyright: ignore[reportArgumentType] ) - results = run_map( - pipeline=pipeline, + results = pipeline.map( inputs=inputs, executor=executor, **samples.model_dump(exclude={"pipeline", "inputs", "executor"}),