Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add VALI-based (GPU) video reader #12

Merged
merged 1 commit into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
_target_: rbyte.io.frame.VideoFrameReader
_target_: rbyte.io.frame.FfmpegFrameReader
path: ???
threads: !!null
resize_shorter_side: !!null
Expand Down
5 changes: 5 additions & 0 deletions examples/config_templates/frame_reader/video/vali.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
_target_: rbyte.io.frame.ValiGpuFrameReader
_convert_: all
path: ???
pixel_format_chain: [NV12]
7 changes: 4 additions & 3 deletions examples/config_templates/read_frames.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ defaults:
- _self_

batch_size: 1
application_id: rbyte-read-frames
entity_path: ???
application_id: rbyte
entity_path: frames
frame_config:
Image:
color_model: RGB
pixel_format: !!null
color_model: !!null

hydra:
output_subdir: !!null
Expand Down
2 changes: 1 addition & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ read-frames *ARGS: generate-example-config

# rerun server and viewer
rerun bind="0.0.0.0" port="9876" ws-server-port="9877" web-viewer-port="9090":
uv run rerun \
RUST_LOG=debug uv run rerun \
--bind {{ bind }} \
--port {{ port }} \
--ws-server-port {{ ws-server-port }} \
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "rbyte"
version = "0.3.0"
version = "0.4.0"
description = "Multimodal dataset library"
authors = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
maintainers = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
Expand Down Expand Up @@ -44,7 +44,7 @@ mcap = [
]
yaak = ["protobuf", "ptars>=0.0.2"]
jpeg = ["simplejpeg>=1.7.6"]
video = ["video-reader-rs>=0.1.4"]
video = ["python-vali>=4.2.0.post0", "video-reader-rs>=0.1.5"]
hdf5 = ["h5py>=3.12.1"]

[project.scripts]
Expand Down
20 changes: 14 additions & 6 deletions src/rbyte/io/frame/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,6 @@

__all__ = ["DirectoryFrameReader"]

try:
from .video import VideoFrameReader
except ImportError:
pass
else:
__all__ += ["VideoFrameReader"]

try:
from .mcap import McapFrameReader
Expand All @@ -22,3 +16,17 @@
pass
else:
__all__ += ["Hdf5FrameReader"]

try:
from .video.ffmpeg_reader import FfmpegFrameReader
except ImportError:
pass
else:
__all__ += ["FfmpegFrameReader"]

try:
from .video.vali_reader import ValiGpuFrameReader
except ImportError:
pass
else:
__all__ += ["ValiGpuFrameReader"]
4 changes: 3 additions & 1 deletion src/rbyte/io/frame/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@

@runtime_checkable
class FrameReader(Protocol):
def read(self, indexes: Iterable[int]) -> Shaped[Tensor, "b h w c"]: ...
def read(
self, indexes: Iterable[int]
) -> Shaped[Tensor, "b h w c"] | Shaped[Tensor, "b c h w"]: ...
def get_available_indexes(self) -> Sequence[int]: ...
18 changes: 16 additions & 2 deletions src/rbyte/io/frame/video/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
from .reader import VideoFrameReader
__all__: list[str] = []

__all__ = ["VideoFrameReader"]
try:
from .ffmpeg_reader import FfmpegFrameReader
except ImportError:
pass

else:
__all__ += ["FfmpegFrameReader"]

try:
from .vali_reader import ValiGpuFrameReader
except ImportError:
pass

else:
__all__ += ["ValiGpuFrameReader"]
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
from collections.abc import Callable, Iterable, Sequence
from functools import partial
from os import PathLike
from pathlib import Path
from typing import override

import torch
import video_reader as vr
from jaxtyping import UInt8
from pydantic import NonNegativeInt, validate_call
from pydantic import FilePath, NonNegativeInt, validate_call
from torch import Tensor

from rbyte.io.frame.base import FrameReader


class VideoFrameReader(FrameReader):
class FfmpegFrameReader(FrameReader):
@validate_call
def __init__(
self,
path: PathLike[str],
path: FilePath,
threads: NonNegativeInt | None = None,
resize_shorter_side: NonNegativeInt | None = None,
with_fallback: bool | None = None, # noqa: FBT001
Expand All @@ -42,6 +41,6 @@ def read(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]:

@override
def get_available_indexes(self) -> Sequence[int]:
num_frames, *_ = vr.get_shape(self._path) # pyright: ignore[reportAttributeAccessIssue, reportUnknownVariableType, reportUnknownMemberType]
num_frames = int(vr.get_info(self._path)["frame_count"]) # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportUnknownMemberType]

return range(num_frames) # pyright: ignore[reportUnknownArgumentType]
return range(num_frames)
116 changes: 116 additions & 0 deletions src/rbyte/io/frame/video/vali_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from itertools import pairwise
from typing import Annotated, override

import more_itertools as mit
import python_vali as vali
import torch
from jaxtyping import Shaped
from pydantic import (
BeforeValidator,
ConfigDict,
FilePath,
NonNegativeInt,
validate_call,
)
from structlog import get_logger
from torch import Tensor

from rbyte.io.frame.base import FrameReader

logger = get_logger(__name__)


def _coerce_pixel_format(value: object) -> vali.PixelFormat:
    """Turn a `vali.PixelFormat` member or its name into the enum member.

    Runs as a pydantic "before" validator, so that pixel formats can be given
    as plain strings (e.g. `NV12` in a YAML config).

    Raises:
        ValueError: if `value` is neither a `vali.PixelFormat` member nor the
            name of one. `ValueError` is used so that pydantic reports the
            failure as a regular validation error instead of letting a bare
            `AttributeError`/`TypeError` from `getattr` escape.
    """
    if isinstance(value, vali.PixelFormat):
        return value

    # Only string names are looked up; the `isinstance` check below also rejects
    # names that resolve to non-member attributes of the enum class.
    member = getattr(vali.PixelFormat, value, None) if isinstance(value, str) else None
    if not isinstance(member, vali.PixelFormat):
        msg = f"unknown pixel format: {value!r}"
        raise ValueError(msg)

    return member


# `vali.PixelFormat` that may also be specified by member name.
PixelFormat = Annotated[vali.PixelFormat, BeforeValidator(_coerce_pixel_format)]


class ValiGpuFrameReader(FrameReader):
    """Frame reader that decodes video frames through `python_vali`.

    A requested frame is decoded into a surface of the decoder's own pixel
    format and then pushed through a chain of surface converters; the surface
    of the last pixel format in the chain is returned as a tensor.
    """

    @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
    def __init__(
        self,
        path: FilePath,
        gpu_id: NonNegativeInt = 0,
        pixel_format_chain: tuple[PixelFormat, ...] = (
            vali.PixelFormat.RGB,
            vali.PixelFormat.RGB_PLANAR,
        ),
    ) -> None:
        """Open a decoder for `path`.

        Args:
            path: existing video file to decode.
            gpu_id: passed to the decoder, the surface converters and the
                surfaces as their `gpu_id`.
            pixel_format_chain: pixel formats to convert through, in order.
                The decoder's own format is prepended unless the chain already
                starts with it.
        """
        super().__init__()

        self._gpu_id = gpu_id

        video_path = path.resolve().as_posix()
        self._decoder = vali.PyDecoder(input=video_path, opts={}, gpu_id=self._gpu_id)

        # Frames are decoded in the decoder's format, so the conversion chain
        # has to begin there.
        decoder_format = self._decoder.Format
        if mit.first(pixel_format_chain, default=None) != decoder_format:
            self._pixel_format_chain = (decoder_format, *pixel_format_chain)
        else:
            self._pixel_format_chain = pixel_format_chain

    @cached_property
    def _surface_converters(
        self,
    ) -> Mapping[tuple[vali.PixelFormat, vali.PixelFormat], vali.PySurfaceConverter]:
        """One converter per consecutive (source, destination) format pair.

        Built on first access and cached on the instance.
        """
        converters: dict[
            tuple[vali.PixelFormat, vali.PixelFormat], vali.PySurfaceConverter
        ] = {}
        for step in pairwise(self._pixel_format_chain):
            source_format, target_format = step
            converters[step] = vali.PySurfaceConverter(
                src_format=source_format,
                dst_format=target_format,
                gpu_id=self._gpu_id,
            )

        return converters

    @cached_property
    def _surfaces(self) -> Mapping[vali.PixelFormat, vali.Surface]:
        """One surface per pixel format in the chain, sized like the video.

        Built on first access and cached on the instance, so the same surfaces
        are reused by every read.
        """
        frame_width = self._decoder.Width
        frame_height = self._decoder.Height

        surfaces: dict[vali.PixelFormat, vali.Surface] = {}
        for fmt in self._pixel_format_chain:
            surfaces[fmt] = vali.Surface.Make(
                format=fmt, width=frame_width, height=frame_height, gpu_id=self._gpu_id
            )

        return surfaces

    def _read(self, index: int) -> Shaped[Tensor, "c h w"] | Shaped[Tensor, "h w c"]:
        """Decode the frame at `index` and convert it along the format chain.

        Raises:
            RuntimeError: if decoding or any conversion step reports failure.
        """
        seek_ctx = vali.SeekContext(seek_frame=index)
        decoded, details = self._decoder.DecodeSingleSurface(  # pyright: ignore[reportUnknownMemberType]
            self._surfaces[self._decoder.Format], seek_ctx
        )
        if not decoded:
            msg = "failed to decode surface"
            logger.error(msg, details=details)

            raise RuntimeError(msg)

        for (source_format, target_format), converter in self._surface_converters.items():
            src = self._surfaces[source_format]
            dst = self._surfaces[target_format]
            converted, details = converter.Run(src, dst)  # pyright: ignore[reportUnknownMemberType]
            if converted:
                continue

            msg = "failed to convert surface"
            logger.error(msg, src=src, dst=dst, details=details)

            raise RuntimeError(msg)

        final_surface = self._surfaces[self._pixel_format_chain[-1]]
        shared = torch.from_dlpack(final_surface)  # pyright: ignore[reportPrivateImportUsage]

        # The surface is cached and written again by the next read, so hand
        # out a copy rather than the tensor obtained from it.
        return shared.clone().detach()

    @override
    def read(
        self, indexes: Iterable[int]
    ) -> Shaped[Tensor, "b h w c"] | Shaped[Tensor, "b c h w"]:
        """Read the frames at `indexes` and stack them into one batch tensor."""
        frames = [self._read(frame_index) for frame_index in indexes]

        return torch.stack(frames)

    @override
    def get_available_indexes(self) -> Sequence[int]:
        """Return `range(n)` where `n` is the decoder's reported frame count."""
        frame_count = self._decoder.NumFrames

        return range(frame_count)
Loading
Loading