Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: image format config for rerun logger #13

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,18 @@ logger:
_target_: rbyte.viz.loggers.RerunLogger
schema:
frame:
/CAM_FRONT/image_rect_compressed: rerun.components.ImageBufferBatch
/CAM_FRONT_LEFT/image_rect_compressed: rerun.components.ImageBufferBatch
/CAM_FRONT_RIGHT/image_rect_compressed: rerun.components.ImageBufferBatch
/CAM_FRONT/image_rect_compressed:
rerun.components.ImageBufferBatch:
color_model: RGB

/CAM_FRONT_LEFT/image_rect_compressed:
rerun.components.ImageBufferBatch:
color_model: RGB

/CAM_FRONT_RIGHT/image_rect_compressed:
rerun.components.ImageBufferBatch:
color_model: RGB

table:
/CAM_FRONT/image_rect_compressed/log_time: rerun.TimeNanosColumn
/CAM_FRONT/image_rect_compressed/idx: rerun.TimeSequenceColumn
Expand Down
4 changes: 3 additions & 1 deletion examples/config_templates/logger/rerun/carla.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ _convert_: all
schema:
frame:
#@ for camera in cameras:
(@=camera@): rerun.components.ImageBufferBatch
(@=camera@):
rerun.components.ImageBufferBatch:
color_model: RGB
#@ end

table:
Expand Down
4 changes: 3 additions & 1 deletion examples/config_templates/logger/rerun/mcap.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ _target_: rbyte.viz.loggers.RerunLogger
schema:
frame:
#@ for topic in camera_topics:
(@=topic@): rerun.components.ImageBufferBatch
(@=topic@):
rerun.components.ImageBufferBatch:
color_model: RGB
#@ end

table:
Expand Down
4 changes: 3 additions & 1 deletion examples/config_templates/logger/rerun/yaak.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ _target_: rbyte.viz.loggers.RerunLogger
schema:
frame:
#@ for camera in cameras:
(@=camera@): rerun.components.ImageBufferBatch
(@=camera@):
rerun.components.ImageBufferBatch:
color_model: RGB
#@ end

table:
Expand Down
3 changes: 3 additions & 0 deletions src/rbyte/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .base import BaseModel, HydraConfig

__all__ = ["BaseModel", "HydraConfig"]
2 changes: 1 addition & 1 deletion src/rbyte/io/table/yaak/idl-repo
185 changes: 121 additions & 64 deletions src/rbyte/viz/loggers/rerun_logger.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from functools import cache, cached_property
from typing import Any, Literal, Protocol, cast, override, runtime_checkable

from typing import (
Annotated,
Any,
Literal,
Protocol,
Self,
cast,
override,
runtime_checkable,
)

import numpy.typing as npt
import rerun as rr
import torch
from optree import tree_flatten_with_path
from pydantic import ImportString, validate_call
from pydantic import (
BeforeValidator,
ConfigDict,
ImportString,
model_validator,
validate_call,
)
from rerun._baseclasses import ComponentBatchMixin # noqa: PLC2701
from rerun._send_columns import TimeColumnLike # noqa: PLC2701
from torch import Tensor

from rbyte.batch import Batch
from rbyte.config import BaseModel

from .base import Logger

Expand All @@ -19,21 +33,76 @@
class TimeColumn(TimeColumnLike, Protocol): ...


class ImageFormatConfig(BaseModel):
pixel_format: (
Annotated[rr.PixelFormat, BeforeValidator(rr.PixelFormat.auto)] | None
) = None

color_model: (
Annotated[rr.ColorModel, BeforeValidator(rr.ColorModel.auto)] | None
) = None

@model_validator(mode="after")
def validate_model(self: Self) -> Self:
if not (bool(self.pixel_format) ^ bool(self.color_model)):
msg = "either pixel_format or color_model must be specified"
raise ValueError(msg)

return self


TableSchema = (
ImportString[type[TimeColumn]] | ImportString[type[rr.components.ScalarBatch]]
)
FrameSchema = Mapping[
ImportString[type[rr.components.ImageBufferBatch]], ImageFormatConfig
]


class Schema(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

frame: Mapping[str, FrameSchema]
table: Mapping[str, TableSchema]

@cached_property
def times(self) -> Mapping[tuple[Literal["table"], str], TimeColumn]:
return {
("table", k): v for k, v in self.table.items() if isinstance(v, TimeColumn)
}

@cached_property
def components(
self,
) -> Mapping[tuple[str, str], FrameSchema | type[ComponentBatchMixin]]:
return {("frame", k): v for k, v in self.frame.items()} | {
("table", k): v
for k, v in self.table.items()
if issubclass(v, ComponentBatchMixin)
}


class RerunLogger(Logger[Batch]):
@validate_call
def __init__(
self,
schema: Mapping[Literal["frame", "table"], Mapping[str, ImportString[Any]]],
) -> None:
def __init__(self, schema: Schema) -> None:
super().__init__()

self._schema = schema

@cache # noqa: B019
def _get_recording(self, *, application_id: str) -> rr.RecordingStream: # noqa: PLR6301
return rr.new_recording(
def _get_recording(self, *, application_id: str) -> rr.RecordingStream:
with rr.new_recording(
application_id=application_id, spawn=True, make_default=True
)
) as recording:
for k in self._schema.frame:
rr.log(
entity_path=f"frame/{k}",
entity=[rr.Image.indicator()],
static=True,
strict=True,
)

return recording

@override
def log(self, batch_idx: int, batch: Batch) -> None:
Expand All @@ -44,68 +113,56 @@ def log(self, batch_idx: int, batch: Batch) -> None:
strict=True,
):
with self._get_recording(application_id=input_id): # pyright: ignore[reportUnknownArgumentType]
times = [
fn(timeline="/".join(k), times=sample.get(k).numpy()) # pyright: ignore[reportUnknownMemberType, reportCallIssue]
for k, fn in self.times
times: Sequence[TimeColumn] = [
v(timeline="/".join(k), times=sample.get(k).numpy()) # pyright: ignore[reportUnknownMemberType, reportCallIssue]
for k, v in self._schema.times.items()
]

for k, fn in self.components:
path = "/".join(k)
tensor = cast(Tensor, sample.get(k)) # pyright: ignore[reportUnknownMemberType]
match fn:
case rr.components.ImageBufferBatch:
match tensor.shape, tensor.dtype:
case ((_, height, width, 3), torch.uint8):
# TODO: make this configurable? # noqa: FIX002
rr.log(
path,
[
rr.components.ImageFormat(
height=height,
width=width,
color_model="RGB",
channel_datatype="U8",
),
rr.Image.indicator(),
],
static=True,
strict=True,
for k, v in self._schema.components.items():
tensor = cast(npt.NDArray[Any], sample.get(k).cpu().numpy()) # pyright: ignore[reportUnknownMemberType]
match v:
case rr.components.ScalarBatch:
components = [v(tensor)]

case {
rr.components.ImageBufferBatch: ImageFormatConfig(
pixel_format=pixel_format, color_model=color_model
)
}:
match (pixel_format, color_model, tensor.shape):
case None, rr.ColorModel.RGB, (
(b, h, w, 3) | (b, 3, h, w)
):
image_format = rr.components.ImageFormat(
width=w,
height=h,
color_model=color_model,
channel_datatype=rr.ChannelDatatype.from_np_dtype(
tensor.dtype
),
)

# https://github.com/rerun-io/rerun/blob/46a7035bca81f4ff158e0975a5a78746fc2c730c/docs/snippets/all/archetypes/image_send_columns.py#L26
tensor = tensor.flatten(start_dim=1, end_dim=-1)
case rr.PixelFormat.NV12, None, (b, dim, w):
image_format = rr.components.ImageFormat(
width=w,
height=int(dim / 1.5),
pixel_format=pixel_format,
)

case _:
raise NotImplementedError

components = [
rr.components.ImageFormatBatch([image_format] * b),
rr.components.ImageBufferBatch(tensor.reshape(b, -1)),
]

case _:
pass
raise NotImplementedError

rr.send_columns(
entity_path="/".join(k),
times=times,
components=[fn(tensor.numpy())], # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportCallIssue]
components=components,
strict=True,
)

@cached_property
def times(self) -> tuple[tuple[tuple[str, ...], type[TimeColumn]], ...]:
paths, leaves, _ = tree_flatten_with_path(self._schema) # pyright: ignore[reportArgumentType, reportUnknownVariableType]

return tuple(
(path, leaf)
for path, leaf in zip(paths, leaves, strict=True) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
if issubclass(leaf, TimeColumn)
)

@cached_property
def components(
self,
) -> tuple[tuple[tuple[str, ...], type[ComponentBatchMixin]], ...]:
paths, leaves, _ = tree_flatten_with_path(self._schema) # pyright: ignore[reportArgumentType, reportUnknownVariableType]

return tuple(
(path, leaf)
for path, leaf in zip(paths, leaves, strict=True) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
if issubclass(leaf, ComponentBatchMixin)
)
Loading