Skip to content

Commit

Permalink
feat: add VALI-based (GPU) video reader
Browse files Browse the repository at this point in the history
  • Loading branch information
egorchakov committed Oct 14, 2024
1 parent aed93a1 commit 8a1d8bb
Show file tree
Hide file tree
Showing 12 changed files with 237 additions and 69 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
_target_: rbyte.io.frame.VideoFrameReader
_target_: rbyte.io.frame.FfmpegFrameReader
path: ???
threads: !!null
resize_shorter_side: !!null
Expand Down
5 changes: 5 additions & 0 deletions examples/config_templates/frame_reader/video/vali.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
_target_: rbyte.io.frame.ValiGpuFrameReader
_convert_: all
path: ???
pixel_format_chain: [NV12]
7 changes: 4 additions & 3 deletions examples/config_templates/read_frames.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ defaults:
- _self_

batch_size: 1
application_id: rbyte-read-frames
entity_path: ???
application_id: rbyte
entity_path: frames
frame_config:
Image:
color_model: RGB
pixel_format: !!null
color_model: !!null

hydra:
output_subdir: !!null
Expand Down
2 changes: 1 addition & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ read-frames *ARGS: generate-example-config

# rerun server and viewer
rerun bind="0.0.0.0" port="9876" ws-server-port="9877" web-viewer-port="9090":
uv run rerun \
RUST_LOG=debug uv run rerun \
--bind {{ bind }} \
--port {{ port }} \
--ws-server-port {{ ws-server-port }} \
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "rbyte"
version = "0.3.0"
version = "0.4.0"
description = "Multimodal dataset library"
authors = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
maintainers = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
Expand Down Expand Up @@ -44,7 +44,7 @@ mcap = [
]
yaak = ["protobuf", "ptars>=0.0.2"]
jpeg = ["simplejpeg>=1.7.6"]
video = ["video-reader-rs>=0.1.4"]
video = ["python-vali>=4.2.0.post0", "video-reader-rs>=0.1.5"]
hdf5 = ["h5py>=3.12.1"]

[project.scripts]
Expand Down
20 changes: 14 additions & 6 deletions src/rbyte/io/frame/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,6 @@

__all__ = ["DirectoryFrameReader"]

try:
from .video import VideoFrameReader
except ImportError:
pass
else:
__all__ += ["VideoFrameReader"]

try:
from .mcap import McapFrameReader
Expand All @@ -22,3 +16,17 @@
pass
else:
__all__ += ["Hdf5FrameReader"]

try:
from .video.ffmpeg_reader import FfmpegFrameReader
except ImportError:
pass
else:
__all__ += ["FfmpegFrameReader"]

try:
from .video.vali_reader import ValiGpuFrameReader
except ImportError:
pass
else:
__all__ += ["ValiGpuFrameReader"]
4 changes: 3 additions & 1 deletion src/rbyte/io/frame/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@

@runtime_checkable
class FrameReader(Protocol):
def read(self, indexes: Iterable[int]) -> Shaped[Tensor, "b h w c"]: ...
def read(
self, indexes: Iterable[int]
) -> Shaped[Tensor, "b h w c"] | Shaped[Tensor, "b c h w"]: ...
def get_available_indexes(self) -> Sequence[int]: ...
18 changes: 16 additions & 2 deletions src/rbyte/io/frame/video/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
from .reader import VideoFrameReader
__all__: list[str] = []

__all__ = ["VideoFrameReader"]
try:
from .ffmpeg_reader import FfmpegFrameReader
except ImportError:
pass

else:
__all__ += ["FfmpegFrameReader"]

try:
from .vali_reader import ValiGpuFrameReader
except ImportError:
pass

else:
__all__ += ["ValiGpuFrameReader"]
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
from collections.abc import Callable, Iterable, Sequence
from functools import partial
from os import PathLike
from pathlib import Path
from typing import override

import torch
import video_reader as vr
from jaxtyping import UInt8
from pydantic import NonNegativeInt, validate_call
from pydantic import FilePath, NonNegativeInt, validate_call
from torch import Tensor

from rbyte.io.frame.base import FrameReader


class VideoFrameReader(FrameReader):
class FfmpegFrameReader(FrameReader):
@validate_call
def __init__(
self,
path: PathLike[str],
path: FilePath,
threads: NonNegativeInt | None = None,
resize_shorter_side: NonNegativeInt | None = None,
with_fallback: bool | None = None, # noqa: FBT001
Expand All @@ -42,6 +41,6 @@ def read(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]:

@override
def get_available_indexes(self) -> Sequence[int]:
num_frames, *_ = vr.get_shape(self._path) # pyright: ignore[reportAttributeAccessIssue, reportUnknownVariableType, reportUnknownMemberType]
num_frames = int(vr.get_info(self._path)["frame_count"]) # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportUnknownMemberType]

return range(num_frames) # pyright: ignore[reportUnknownArgumentType]
return range(num_frames)
116 changes: 116 additions & 0 deletions src/rbyte/io/frame/video/vali_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from itertools import pairwise
from typing import Annotated, override

import more_itertools as mit
import python_vali as vali
import torch
from jaxtyping import Shaped
from pydantic import (
BeforeValidator,
ConfigDict,
FilePath,
NonNegativeInt,
validate_call,
)
from structlog import get_logger
from torch import Tensor

from rbyte.io.frame.base import FrameReader

logger = get_logger(__name__)


def _to_pixel_format(value: vali.PixelFormat | str) -> vali.PixelFormat:
    """Coerce *value* into a ``vali.PixelFormat`` member.

    Existing ``vali.PixelFormat`` instances are passed through unchanged; any
    other value is looked up as an attribute name on ``vali.PixelFormat``
    (e.g. ``"NV12"`` as written in a YAML config).
    """
    if isinstance(value, vali.PixelFormat):
        return value

    return getattr(vali.PixelFormat, value)


# Pixel format type for validated call signatures: the coercion above runs
# before pydantic's own validation of the annotated type.
PixelFormat = Annotated[vali.PixelFormat, BeforeValidator(_to_pixel_format)]


class ValiGpuFrameReader(FrameReader):
    """Frame reader that decodes video frames through VALI on a GPU.

    Each frame is decoded into a surface in the decoder's own pixel format and
    then converted step by step along a chain of pixel formats; the surface of
    the last format in the chain is what gets returned as a tensor.
    """

    @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
    def __init__(
        self,
        path: FilePath,
        gpu_id: NonNegativeInt = 0,
        pixel_format_chain: tuple[PixelFormat, ...] = (
            vali.PixelFormat.RGB,
            vali.PixelFormat.RGB_PLANAR,
        ),
    ) -> None:
        """Open a decoder for the video at *path*.

        Args:
            path: Existing video file to decode.
            gpu_id: Identifier of the GPU handed to the decoder, the surface
                converters and the surfaces.
            pixel_format_chain: Pixel formats to convert through, in order.
                The decoder's own format is prepended unless the chain already
                starts with it.
        """
        super().__init__()

        self._gpu_id = gpu_id

        video_path = path.resolve().as_posix()
        self._decoder = vali.PyDecoder(input=video_path, opts={}, gpu_id=self._gpu_id)

        # The chain must begin at the format the decoder actually produces,
        # because decoding always writes into that format's surface.
        requested_head = mit.first(pixel_format_chain, default=None)
        if requested_head != self._decoder.Format:
            chain = (self._decoder.Format, *pixel_format_chain)
        else:
            chain = pixel_format_chain

        self._pixel_format_chain = chain

    @cached_property
    def _surface_converters(
        self,
    ) -> Mapping[tuple[vali.PixelFormat, vali.PixelFormat], vali.PySurfaceConverter]:
        """Converters keyed by ``(source, destination)`` format, in chain order.

        One converter is built for every pair of consecutive formats in the
        chain; the mapping is created on first access and then reused.
        """
        converters: dict[
            tuple[vali.PixelFormat, vali.PixelFormat], vali.PySurfaceConverter
        ] = {}

        for step in pairwise(self._pixel_format_chain):
            src_format, dst_format = step
            converters[step] = vali.PySurfaceConverter(
                src_format=src_format, dst_format=dst_format, gpu_id=self._gpu_id
            )

        return converters

    @cached_property
    def _surfaces(self) -> Mapping[vali.PixelFormat, vali.Surface]:
        """Surfaces keyed by pixel format, one per format in the chain.

        Every surface uses the decoder's width and height; the mapping is
        created on first access and then reused for every frame.
        """
        surfaces: dict[vali.PixelFormat, vali.Surface] = {}

        for pixel_format in self._pixel_format_chain:
            surfaces[pixel_format] = vali.Surface.Make(
                format=pixel_format,
                width=self._decoder.Width,
                height=self._decoder.Height,
                gpu_id=self._gpu_id,
            )

        return surfaces

    def _read(self, index: int) -> Shaped[Tensor, "c h w"] | Shaped[Tensor, "h w c"]:
        """Decode the frame at *index* and return it as a tensor.

        The frame is decoded into the surface of the decoder's format and then
        run through every converter in chain order. The tensor layout is
        presumably determined by the last pixel format in the chain (planar vs.
        interleaved) -- confirm against the VALI documentation.

        Raises:
            RuntimeError: If decoding or any conversion step reports failure.
        """
        seek_ctx = vali.SeekContext(seek_frame=index)
        decoded = self._surfaces[self._decoder.Format]
        success, details = self._decoder.DecodeSingleSurface(decoded, seek_ctx)  # pyright: ignore[reportUnknownMemberType]
        if not success:
            msg = "failed to decode surface"
            logger.error(msg, details=details)

            raise RuntimeError(msg)

        for (src_format, dst_format), converter in self._surface_converters.items():
            src = self._surfaces[src_format]
            dst = self._surfaces[dst_format]
            success, details = converter.Run(src, dst)  # pyright: ignore[reportUnknownMemberType]
            if success:
                continue

            msg = "failed to convert surface"
            logger.error(msg, src=src, dst=dst, details=details)

            raise RuntimeError(msg)

        final_surface = self._surfaces[self._pixel_format_chain[-1]]

        # The surfaces are reused for every frame, so the data is cloned
        # (presumably so the result does not alias the surface's memory).
        frame = torch.from_dlpack(final_surface)  # pyright: ignore[reportPrivateImportUsage]

        return frame.clone().detach()

    @override
    def read(
        self, indexes: Iterable[int]
    ) -> Shaped[Tensor, "b h w c"] | Shaped[Tensor, "b c h w"]:
        """Decode the frames at *indexes* and stack them along a new first axis."""
        frames = list(map(self._read, indexes))

        return torch.stack(frames)

    @override
    def get_available_indexes(self) -> Sequence[int]:
        """Return the range of frame indexes, as reported by the decoder."""
        num_frames = self._decoder.NumFrames

        return range(num_frames)
Loading

0 comments on commit 8a1d8bb

Please sign in to comment.