diff --git a/examples/config_templates/frame_reader/video.yaml b/examples/config_templates/frame_reader/video/ffmpeg.yaml similarity index 65% rename from examples/config_templates/frame_reader/video.yaml rename to examples/config_templates/frame_reader/video/ffmpeg.yaml index 3cc68ca..5107b03 100644 --- a/examples/config_templates/frame_reader/video.yaml +++ b/examples/config_templates/frame_reader/video/ffmpeg.yaml @@ -1,5 +1,5 @@ --- -_target_: rbyte.io.frame.VideoFrameReader +_target_: rbyte.io.frame.FfmpegFrameReader path: ??? threads: !!null resize_shorter_side: !!null diff --git a/examples/config_templates/frame_reader/video/vali.yaml b/examples/config_templates/frame_reader/video/vali.yaml new file mode 100644 index 0000000..af7a629 --- /dev/null +++ b/examples/config_templates/frame_reader/video/vali.yaml @@ -0,0 +1,5 @@ +--- +_target_: rbyte.io.frame.ValiGpuFrameReader +_convert_: all +path: ??? +pixel_format_chain: [NV12] diff --git a/examples/config_templates/read_frames.yaml b/examples/config_templates/read_frames.yaml index 660dd26..1cf3fde 100644 --- a/examples/config_templates/read_frames.yaml +++ b/examples/config_templates/read_frames.yaml @@ -4,11 +4,12 @@ defaults: - _self_ batch_size: 1 -application_id: rbyte-read-frames -entity_path: ??? +application_id: rbyte +entity_path: frames frame_config: Image: - color_model: RGB + pixel_format: !!null + color_model: !!null hydra: output_subdir: !!null diff --git a/justfile b/justfile index b1022ed..454061e 100644 --- a/justfile +++ b/justfile @@ -62,7 +62,7 @@ read-frames *ARGS: generate-example-config # rerun server and viewer rerun bind="0.0.0.0" port="9876" ws-server-port="9877" web-viewer-port="9090": - uv run rerun \ + RUST_LOG=debug uv run rerun \ --bind {{ bind }} \ --port {{ port }} \ --ws-server-port {{ ws-server-port }} \ diff --git a/pyproject.toml b/pyproject.toml index dcb780d..3054e04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rbyte" -version = "0.3.0" +version = "0.4.0" description = "Multimodal dataset library" authors = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] @@ -44,7 +44,7 @@ mcap = [ ] yaak = ["protobuf", "ptars>=0.0.2"] jpeg = ["simplejpeg>=1.7.6"] -video = ["video-reader-rs>=0.1.4"] +video = ["python-vali>=4.2.0.post0", "video-reader-rs>=0.1.5"] hdf5 = ["h5py>=3.12.1"] [project.scripts] diff --git a/src/rbyte/io/frame/__init__.py b/src/rbyte/io/frame/__init__.py index 97a88a5..b60c385 100644 --- a/src/rbyte/io/frame/__init__.py +++ b/src/rbyte/io/frame/__init__.py @@ -2,12 +2,6 @@ __all__ = ["DirectoryFrameReader"] -try: - from .video import VideoFrameReader -except ImportError: - pass -else: - __all__ += ["VideoFrameReader"] try: from .mcap import McapFrameReader @@ -22,3 +16,17 @@ pass else: __all__ += ["Hdf5FrameReader"] + +try: + from .video.ffmpeg_reader import FfmpegFrameReader +except ImportError: + pass +else: + __all__ += ["FfmpegFrameReader"] + +try: + from .video.vali_reader import ValiGpuFrameReader +except ImportError: + pass +else: + __all__ += ["ValiGpuFrameReader"] diff --git a/src/rbyte/io/frame/base.py b/src/rbyte/io/frame/base.py index 11cc562..2b2d1e3 100644 --- a/src/rbyte/io/frame/base.py +++ b/src/rbyte/io/frame/base.py @@ -7,5 +7,7 @@ @runtime_checkable class FrameReader(Protocol): - def read(self, indexes: Iterable[int]) -> Shaped[Tensor, "b h w c"]: ... + def read( + self, indexes: Iterable[int] + ) -> Shaped[Tensor, "b h w c"] | Shaped[Tensor, "b c h w"]: ... def get_available_indexes(self) -> Sequence[int]: ... diff --git a/src/rbyte/io/frame/video/__init__.py b/src/rbyte/io/frame/video/__init__.py index 405f40f..5e2a442 100644 --- a/src/rbyte/io/frame/video/__init__.py +++ b/src/rbyte/io/frame/video/__init__.py @@ -1,3 +1,17 @@ -from .reader import VideoFrameReader +__all__: list[str] = [] -__all__ = ["VideoFrameReader"] +try: + from .ffmpeg_reader import FfmpegFrameReader +except ImportError: + pass + +else: + __all__ += ["FfmpegFrameReader"] + +try: + from .vali_reader import ValiGpuFrameReader +except ImportError: + pass + +else: + __all__ += ["ValiGpuFrameReader"] diff --git a/src/rbyte/io/frame/video/reader.py b/src/rbyte/io/frame/video/ffmpeg_reader.py similarity index 77% rename from src/rbyte/io/frame/video/reader.py rename to src/rbyte/io/frame/video/ffmpeg_reader.py index a3ad0c1..55109cb 100644 --- a/src/rbyte/io/frame/video/reader.py +++ b/src/rbyte/io/frame/video/ffmpeg_reader.py @@ -1,23 +1,22 @@ from collections.abc import Callable, Iterable, Sequence from functools import partial -from os import PathLike from pathlib import Path from typing import override import torch import video_reader as vr from jaxtyping import UInt8 -from pydantic import NonNegativeInt, validate_call +from pydantic import FilePath, NonNegativeInt, validate_call from torch import Tensor from rbyte.io.frame.base import FrameReader -class VideoFrameReader(FrameReader): +class FfmpegFrameReader(FrameReader): @validate_call def __init__( self, - path: PathLike[str], + path: FilePath, threads: NonNegativeInt | None = None, resize_shorter_side: NonNegativeInt | None = None, with_fallback: bool | None = None, # noqa: FBT001 @@ -42,6 +41,6 @@ def read(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]: @override def get_available_indexes(self) -> Sequence[int]: - num_frames, *_ = vr.get_shape(self._path) # pyright: ignore[reportAttributeAccessIssue, reportUnknownVariableType, reportUnknownMemberType] + num_frames = int(vr.get_info(self._path)["frame_count"]) # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportUnknownMemberType] - return range(num_frames) # pyright: ignore[reportUnknownArgumentType] + return range(num_frames) diff --git a/src/rbyte/io/frame/video/vali_reader.py b/src/rbyte/io/frame/video/vali_reader.py new file mode 100644 index 0000000..154f7a4 --- /dev/null +++ b/src/rbyte/io/frame/video/vali_reader.py @@ -0,0 +1,116 @@ +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from itertools import pairwise +from typing import Annotated, override + +import more_itertools as mit +import python_vali as vali +import torch +from jaxtyping import Shaped +from pydantic import ( + BeforeValidator, + ConfigDict, + FilePath, + NonNegativeInt, + validate_call, +) +from structlog import get_logger +from torch import Tensor + +from rbyte.io.frame.base import FrameReader + +logger = get_logger(__name__) + +PixelFormat = Annotated[ + vali.PixelFormat, + BeforeValidator( + lambda x: x if isinstance(x, vali.PixelFormat) else getattr(vali.PixelFormat, x) + ), +] + + +class ValiGpuFrameReader(FrameReader): + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def __init__( + self, + path: FilePath, + gpu_id: NonNegativeInt = 0, + pixel_format_chain: tuple[PixelFormat, ...] = ( + vali.PixelFormat.RGB, + vali.PixelFormat.RGB_PLANAR, + ), + ) -> None: + super().__init__() + + self._gpu_id = gpu_id + + self._decoder = vali.PyDecoder( + input=path.resolve().as_posix(), opts={}, gpu_id=self._gpu_id + ) + + self._pixel_format_chain = ( + (self._decoder.Format, *pixel_format_chain) + if mit.first(pixel_format_chain, default=None) != self._decoder.Format + else pixel_format_chain + ) + + @cached_property + def _surface_converters( + self, + ) -> Mapping[tuple[vali.PixelFormat, vali.PixelFormat], vali.PySurfaceConverter]: + return { + (src_format, dst_format): vali.PySurfaceConverter( + src_format=src_format, dst_format=dst_format, gpu_id=self._gpu_id + ) + for src_format, dst_format in pairwise(self._pixel_format_chain) + } + + @cached_property + def _surfaces(self) -> Mapping[vali.PixelFormat, vali.Surface]: + return { + pixel_format: vali.Surface.Make( + format=pixel_format, + width=self._decoder.Width, + height=self._decoder.Height, + gpu_id=self._gpu_id, + ) + for pixel_format in self._pixel_format_chain + } + + def _read(self, index: int) -> Shaped[Tensor, "c h w"] | Shaped[Tensor, "h w c"]: + seek_ctx = vali.SeekContext(seek_frame=index) + success, details = self._decoder.DecodeSingleSurface( # pyright: ignore[reportUnknownMemberType] + self._surfaces[self._decoder.Format], seek_ctx + ) + if not success: + logger.error(msg := "failed to decode surface", details=details) + + raise RuntimeError(msg) + + for (src_format, dst_format), converter in self._surface_converters.items(): + success, details = converter.Run( # pyright: ignore[reportUnknownMemberType] + (src := self._surfaces[src_format]), (dst := self._surfaces[dst_format]) + ) + if not success: + logger.error( + msg := "failed to convert surface", + src=src, + dst=dst, + details=details, + ) + + raise RuntimeError(msg) + + surface = self._surfaces[self._pixel_format_chain[-1]] + + return torch.from_dlpack(surface).clone().detach() # pyright: ignore[reportPrivateImportUsage] + + @override + def read( + self, indexes: Iterable[int] + ) -> Shaped[Tensor, "b h w c"] | Shaped[Tensor, "b c h w"]: + return torch.stack([self._read(index) for index in indexes]) + + @override + def get_available_indexes(self) -> Sequence[int]: + return range(self._decoder.NumFrames) diff --git a/src/rbyte/scripts/read_frames.py b/src/rbyte/scripts/read_frames.py index bce80ac..f370f24 100644 --- a/src/rbyte/scripts/read_frames.py +++ b/src/rbyte/scripts/read_frames.py @@ -5,25 +5,34 @@ import numpy as np import numpy.typing as npt import rerun as rr -from hydra.utils import instantiate -from omegaconf import DictConfig -from pydantic import TypeAdapter +from omegaconf import DictConfig, OmegaConf +from pydantic import ConfigDict, NonNegativeInt from structlog import get_logger from structlog.contextvars import bound_contextvars from tqdm import tqdm +from rbyte.config.base import BaseModel, HydraConfig from rbyte.io.frame.base import FrameReader from rbyte.viz.loggers.rerun_logger import FrameConfig logger = get_logger(__name__) +class Config(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + frame_reader: HydraConfig[FrameReader] + frame_config: FrameConfig + application_id: str + entity_path: str + batch_size: NonNegativeInt = 1 + + @hydra.main(version_base=None) -def main(config: DictConfig) -> None: - frame_reader = cast(FrameReader, instantiate(config.frame_reader)) - frame_config = cast( - FrameConfig, TypeAdapter(FrameConfig).validate_python(config.frame_config) - ) +def main(_config: DictConfig) -> None: + config = Config.model_validate(OmegaConf.to_object(_config)) + frame_reader = config.frame_reader.instantiate() + frame_config = config.frame_config rr.init(config.application_id, spawn=True) @@ -32,47 +41,61 @@ def main(config: DictConfig) -> None: config.batch_size, strict=False, ): - with bound_contextvars(frame_config=frame_config): + tensor = frame_reader.read(frame_indexes) + + with bound_contextvars(frame_config=frame_config, shape=tensor.shape): match frame_config: case {rr.Image: image_format} | {rr.DepthImage: image_format}: - arr = cast( - npt.NDArray[Any], - frame_reader.read(frame_indexes).cpu().numpy(), # pyright: ignore[reportUnknownMemberType] + match ( + image_format.pixel_format, + image_format.color_model, + tensor.shape, + ): + case None, color_model, shape: + match color_model, shape: + case ( + (rr.ColorModel.L, (batch, height, width, 1)) + | (rr.ColorModel.RGB, (batch, height, width, 3)) + | (rr.ColorModel.RGBA, (batch, height, width, 4)) + ): + pass + + case ( + (rr.ColorModel.L, (batch, 1, height, width)) + | (rr.ColorModel.RGB, (batch, 3, height, width)) + | (rr.ColorModel.RGBA, (batch, 4, height, width)) + ): + tensor = tensor.permute(0, 2, 3, 1) + + case _: + logger.error("not implemented") + + raise NotImplementedError + + case rr.PixelFormat.NV12, _, (batch, dim, width): + height = int(dim / 1.5) + + case _: + logger.error("not implemented") + + raise NotImplementedError + + arr = cast(npt.NDArray[Any], tensor.cpu().numpy()) # pyright: ignore[reportUnknownMemberType] + image_format = rr.components.ImageFormat( + height=height, + width=width, + pixel_format=image_format.pixel_format, + color_model=image_format.color_model, + channel_datatype=rr.ChannelDatatype.from_np_dtype(arr.dtype), ) - with bound_contextvars(image_format=image_format, shape=arr.shape): - match ( - image_format.pixel_format, - image_format.color_model, - arr.shape, - ): - case None, rr.ColorModel(), (batch, height, width, _): - pass - - case rr.PixelFormat.NV12, None, (batch, dim, width): - height = int(dim / 1.5) - - case _: - logger.error("not implemented") - - raise NotImplementedError - - image_format = rr.components.ImageFormat( - height=height, - width=width, - pixel_format=image_format.pixel_format, - color_model=image_format.color_model, - channel_datatype=rr.ChannelDatatype.from_np_dtype( - arr.dtype - ), - ) - - components = [ - mit.one(frame_config).indicator(), - rr.components.ImageFormatBatch([image_format] * batch), - rr.components.ImageBufferBatch( - arr.reshape(batch, -1).view(np.uint8) - ), - ] + + components = [ + mit.one(frame_config).indicator(), + rr.components.ImageFormatBatch([image_format] * batch), + rr.components.ImageBufferBatch( + arr.reshape(batch, -1).view(np.uint8) + ), + ] case _: logger.error("not implemented") diff --git a/src/rbyte/viz/loggers/rerun_logger.py b/src/rbyte/viz/loggers/rerun_logger.py index 852d468..62daceb 100644 --- a/src/rbyte/viz/loggers/rerun_logger.py +++ b/src/rbyte/viz/loggers/rerun_logger.py @@ -53,7 +53,7 @@ class ImageFormat(BaseModel): @model_validator(mode="after") def validate_model(self: Self) -> Self: if not (bool(self.pixel_format) ^ bool(self.color_model)): - msg = "either pixel_format or color_model must be specified" + msg = "pixel_format xor color_model must be specified" raise ValueError(msg) return self