Skip to content

Commit

Permalink
feat: image format config for rerun logger (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
egorchakov authored Oct 2, 2024
1 parent f2c0ff8 commit c1b6835
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 71 deletions.
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,18 @@ logger:
_target_: rbyte.viz.loggers.RerunLogger
schema:
frame:
/CAM_FRONT/image_rect_compressed: rerun.components.ImageBufferBatch
/CAM_FRONT_LEFT/image_rect_compressed: rerun.components.ImageBufferBatch
/CAM_FRONT_RIGHT/image_rect_compressed: rerun.components.ImageBufferBatch
/CAM_FRONT/image_rect_compressed:
rerun.components.ImageBufferBatch:
color_model: RGB

/CAM_FRONT_LEFT/image_rect_compressed:
rerun.components.ImageBufferBatch:
color_model: RGB

/CAM_FRONT_RIGHT/image_rect_compressed:
rerun.components.ImageBufferBatch:
color_model: RGB

table:
/CAM_FRONT/image_rect_compressed/log_time: rerun.TimeNanosColumn
/CAM_FRONT/image_rect_compressed/idx: rerun.TimeSequenceColumn
Expand Down
4 changes: 3 additions & 1 deletion examples/config_templates/logger/rerun/carla.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ _convert_: all
schema:
frame:
#@ for camera in cameras:
(@=camera@): rerun.components.ImageBufferBatch
(@=camera@):
rerun.components.ImageBufferBatch:
color_model: RGB
#@ end

table:
Expand Down
4 changes: 3 additions & 1 deletion examples/config_templates/logger/rerun/mcap.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ _target_: rbyte.viz.loggers.RerunLogger
schema:
frame:
#@ for topic in camera_topics:
(@=topic@): rerun.components.ImageBufferBatch
(@=topic@):
rerun.components.ImageBufferBatch:
color_model: RGB
#@ end

table:
Expand Down
4 changes: 3 additions & 1 deletion examples/config_templates/logger/rerun/yaak.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ _target_: rbyte.viz.loggers.RerunLogger
schema:
frame:
#@ for camera in cameras:
(@=camera@): rerun.components.ImageBufferBatch
(@=camera@):
rerun.components.ImageBufferBatch:
color_model: RGB
#@ end

table:
Expand Down
3 changes: 3 additions & 0 deletions src/rbyte/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .base import BaseModel, HydraConfig

__all__ = ["BaseModel", "HydraConfig"]
2 changes: 1 addition & 1 deletion src/rbyte/io/table/yaak/idl-repo
185 changes: 121 additions & 64 deletions src/rbyte/viz/loggers/rerun_logger.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from functools import cache, cached_property
from typing import Any, Literal, Protocol, cast, override, runtime_checkable

from typing import (
Annotated,
Any,
Literal,
Protocol,
Self,
cast,
override,
runtime_checkable,
)

import numpy.typing as npt
import rerun as rr
import torch
from optree import tree_flatten_with_path
from pydantic import ImportString, validate_call
from pydantic import (
BeforeValidator,
ConfigDict,
ImportString,
model_validator,
validate_call,
)
from rerun._baseclasses import ComponentBatchMixin # noqa: PLC2701
from rerun._send_columns import TimeColumnLike # noqa: PLC2701
from torch import Tensor

from rbyte.batch import Batch
from rbyte.config import BaseModel

from .base import Logger

Expand All @@ -19,21 +33,76 @@
class TimeColumn(TimeColumnLike, Protocol): ...


class ImageFormatConfig(BaseModel):
pixel_format: (
Annotated[rr.PixelFormat, BeforeValidator(rr.PixelFormat.auto)] | None
) = None

color_model: (
Annotated[rr.ColorModel, BeforeValidator(rr.ColorModel.auto)] | None
) = None

@model_validator(mode="after")
def validate_model(self: Self) -> Self:
if not (bool(self.pixel_format) ^ bool(self.color_model)):
msg = "either pixel_format or color_model must be specified"
raise ValueError(msg)

return self


TableSchema = (
ImportString[type[TimeColumn]] | ImportString[type[rr.components.ScalarBatch]]
)
FrameSchema = Mapping[
ImportString[type[rr.components.ImageBufferBatch]], ImageFormatConfig
]


class Schema(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

frame: Mapping[str, FrameSchema]
table: Mapping[str, TableSchema]

@cached_property
def times(self) -> Mapping[tuple[Literal["table"], str], TimeColumn]:
return {
("table", k): v for k, v in self.table.items() if isinstance(v, TimeColumn)
}

@cached_property
def components(
self,
) -> Mapping[tuple[str, str], FrameSchema | type[ComponentBatchMixin]]:
return {("frame", k): v for k, v in self.frame.items()} | {
("table", k): v
for k, v in self.table.items()
if issubclass(v, ComponentBatchMixin)
}


class RerunLogger(Logger[Batch]):
@validate_call
def __init__(
self,
schema: Mapping[Literal["frame", "table"], Mapping[str, ImportString[Any]]],
) -> None:
def __init__(self, schema: Schema) -> None:
super().__init__()

self._schema = schema

@cache # noqa: B019
def _get_recording(self, *, application_id: str) -> rr.RecordingStream: # noqa: PLR6301
return rr.new_recording(
def _get_recording(self, *, application_id: str) -> rr.RecordingStream:
with rr.new_recording(
application_id=application_id, spawn=True, make_default=True
)
) as recording:
for k in self._schema.frame:
rr.log(
entity_path=f"frame/{k}",
entity=[rr.Image.indicator()],
static=True,
strict=True,
)

return recording

@override
def log(self, batch_idx: int, batch: Batch) -> None:
Expand All @@ -44,68 +113,56 @@ def log(self, batch_idx: int, batch: Batch) -> None:
strict=True,
):
with self._get_recording(application_id=input_id): # pyright: ignore[reportUnknownArgumentType]
times = [
fn(timeline="/".join(k), times=sample.get(k).numpy()) # pyright: ignore[reportUnknownMemberType, reportCallIssue]
for k, fn in self.times
times: Sequence[TimeColumn] = [
v(timeline="/".join(k), times=sample.get(k).numpy()) # pyright: ignore[reportUnknownMemberType, reportCallIssue]
for k, v in self._schema.times.items()
]

for k, fn in self.components:
path = "/".join(k)
tensor = cast(Tensor, sample.get(k)) # pyright: ignore[reportUnknownMemberType]
match fn:
case rr.components.ImageBufferBatch:
match tensor.shape, tensor.dtype:
case ((_, height, width, 3), torch.uint8):
# TODO: make this configurable? # noqa: FIX002
rr.log(
path,
[
rr.components.ImageFormat(
height=height,
width=width,
color_model="RGB",
channel_datatype="U8",
),
rr.Image.indicator(),
],
static=True,
strict=True,
for k, v in self._schema.components.items():
tensor = cast(npt.NDArray[Any], sample.get(k).cpu().numpy()) # pyright: ignore[reportUnknownMemberType]
match v:
case rr.components.ScalarBatch:
components = [v(tensor)]

case {
rr.components.ImageBufferBatch: ImageFormatConfig(
pixel_format=pixel_format, color_model=color_model
)
}:
match (pixel_format, color_model, tensor.shape):
case None, rr.ColorModel.RGB, (
(b, h, w, 3) | (b, 3, h, w)
):
image_format = rr.components.ImageFormat(
width=w,
height=h,
color_model=color_model,
channel_datatype=rr.ChannelDatatype.from_np_dtype(
tensor.dtype
),
)

# https://github.com/rerun-io/rerun/blob/46a7035bca81f4ff158e0975a5a78746fc2c730c/docs/snippets/all/archetypes/image_send_columns.py#L26
tensor = tensor.flatten(start_dim=1, end_dim=-1)
case rr.PixelFormat.NV12, None, (b, dim, w):
image_format = rr.components.ImageFormat(
width=w,
height=int(dim / 1.5),
pixel_format=pixel_format,
)

case _:
raise NotImplementedError

components = [
rr.components.ImageFormatBatch([image_format] * b),
rr.components.ImageBufferBatch(tensor.reshape(b, -1)),
]

case _:
pass
raise NotImplementedError

rr.send_columns(
entity_path="/".join(k),
times=times,
components=[fn(tensor.numpy())], # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportCallIssue]
components=components,
strict=True,
)

@cached_property
def times(self) -> tuple[tuple[tuple[str, ...], type[TimeColumn]], ...]:
paths, leaves, _ = tree_flatten_with_path(self._schema) # pyright: ignore[reportArgumentType, reportUnknownVariableType]

return tuple(
(path, leaf)
for path, leaf in zip(paths, leaves, strict=True) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
if issubclass(leaf, TimeColumn)
)

@cached_property
def components(
self,
) -> tuple[tuple[tuple[str, ...], type[ComponentBatchMixin]], ...]:
paths, leaves, _ = tree_flatten_with_path(self._schema) # pyright: ignore[reportArgumentType, reportUnknownVariableType]

return tuple(
(path, leaf)
for path, leaf in zip(paths, leaves, strict=True) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
if issubclass(leaf, ComponentBatchMixin)
)

0 comments on commit c1b6835

Please sign in to comment.