diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index a1b2432..d0fc271 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -48,20 +48,9 @@ jobs: enable-cache: true cache-dependency-glob: "**/pyproject.toml" - - name: sync - run: just sync - - - name: build protos - run: just build-protos - - - name: format - run: just format --check - - - name: lint - run: just lint - - - name: typecheck - run: just typecheck - - - name: test - run: just test + - run: just sync + - run: just build + - run: just format --check + - run: just lint + - run: just typecheck + - run: just test diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5838eb3..0487e01 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,14 +11,14 @@ repos: - id: pyupgrade - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.5 + rev: v0.9.9 hooks: - id: ruff args: [--fix] - id: ruff-format - repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror - rev: 1.26.0 + rev: 1.28.1 hooks: - id: basedpyright diff --git a/config/_templates/dataset/carla.yaml b/config/_templates/dataset/carla.yaml index 4245c12..56f0ec1 100644 --- a/config/_templates/dataset/carla.yaml +++ b/config/_templates/dataset/carla.yaml @@ -56,75 +56,85 @@ _target_: rbyte.Dataset _recursive_: false _convert_: all -inputs: +sources: #@ for input_id in drives: (@=input_id@): - sources: - #@ for source_id in cameras: - (@=source_id@): - index_column: _idx_ - source: - _target_: rbyte.io.PathTensorSource - path: "${data_dir}/(@=input_id@)/frames/(@=source_id@).defish.mp4/576x324/{:09d}.jpg" - decoder: - _target_: simplejpeg.decode_jpeg - _partial_: true - colorspace: rgb - fastdct: true - fastupsample: true + #@ for source_id in cameras: + (@=source_id@): + index_column: _idx_ + source: + _target_: rbyte.io.PathTensorSource + path: "${data_dir}/(@=input_id@)/frames/(@=source_id@).defish.mp4/576x324/{:09d}.jpg" + decoder: + _target_: simplejpeg.decode_jpeg + _partial_: true + colorspace: rgb + fastdct: true + fastupsample: true - #@ end + #@ end + #@ end - samples: - pipeline: - _target_: pipefunc.Pipeline - validate_type_annotations: false - functions: - - _target_: pipefunc.PipeFunc - bound: - path: ${data_dir}/(@=input_id@)/ego_logs.json - output_name: ego_logs - func: - _target_: rbyte.io.JsonDataFrameBuilder - fields: - records: - control.brake: - control.throttle: - control.steer: - state.velocity.value: - state.acceleration.value: +samples: + inputs: + #@ for input_id in drives: + (@=input_id@): + ego_logs_path: ${data_dir}/(@=input_id@)/ego_logs.json + #@ end - - _target_: pipefunc.PipeFunc - renames: - input: ego_logs - output_name: data - func: - _target_: rbyte.io.DataFrameConcater - method: vertical + pipeline: + _target_: pipefunc.Pipeline + validate_type_annotations: false + functions: + - _target_: pipefunc.PipeFunc + renames: + path: ego_logs_path + output_name: ego_logs + mapspec: "ego_logs_path[i] -> ego_logs[i]" + func: + _target_: rbyte.io.JsonDataFrameBuilder + fields: + records: + control.brake: + control.throttle: + control.steer: + state.velocity.value: + state.acceleration.value: - - _target_: pipefunc.PipeFunc - renames: - input: data - output_name: data_resampled - func: - _target_: rbyte.io.DataFrameFpsResampler - fps_in: 20 - fps_out: 30 + - _target_: pipefunc.PipeFunc + renames: + input: ego_logs + output_name: data + mapspec: "ego_logs[i] -> data[i]" + func: + _target_: rbyte.io.DataFrameConcater + method: vertical - - _target_: pipefunc.PipeFunc - 
renames: - input: data_resampled - output_name: data_indexed - func: - _target_: rbyte.io.DataFrameIndexer - name: _idx_ + - _target_: pipefunc.PipeFunc + renames: + input: data + output_name: resampled + mapspec: "data[i] -> resampled[i]" + func: + _target_: rbyte.io.DataFrameFpsResampler + fps_in: 20 + fps_out: 30 - - _target_: pipefunc.PipeFunc - renames: - input: data_indexed - output_name: data_filtered - func: - _target_: rbyte.io.DataFrameFilter - predicate: | - `control.throttle` > 0.5 - #@ end + - _target_: pipefunc.PipeFunc + renames: + input: resampled + output_name: indexed + mapspec: "resampled[i] -> indexed[i]" + func: + _target_: rbyte.io.DataFrameIndexer + name: _idx_ + + - _target_: pipefunc.PipeFunc + renames: + input: indexed + output_name: samples + mapspec: "indexed[i] -> samples[i]" + func: + _target_: rbyte.io.DataFrameFilter + predicate: | + `control.throttle` > 0.5 diff --git a/config/_templates/dataset/mimicgen.yaml b/config/_templates/dataset/mimicgen.yaml index b897072..392f6cd 100644 --- a/config/_templates/dataset/mimicgen.yaml +++ b/config/_templates/dataset/mimicgen.yaml @@ -14,49 +14,60 @@ _target_: rbyte.Dataset _recursive_: false _convert_: all -inputs: +sources: #@ for input_id, input_keys in inputs.items(): #@ for input_key in input_keys: (@=input_id@)(@=input_key@): - sources: - #@ for frame_key in frame_keys: - (@=frame_key@): - index_column: _idx_ - source: - _target_: rbyte.io.Hdf5TensorSource - path: "${data_dir}/(@=input_id@).hdf5" - key: (@=input_key@)/(@=frame_key@) - #@ end + #@ for frame_key in frame_keys: + (@=frame_key@): + index_column: _idx_ + source: + _target_: rbyte.io.Hdf5TensorSource + path: "${data_dir}/(@=input_id@).hdf5" + key: (@=input_key@)/(@=frame_key@) + #@ end + #@ end + #@ end - samples: - pipeline: - _target_: pipefunc.Pipeline - validate_type_annotations: false - functions: - - _target_: pipefunc.PipeFunc - bound: - path: "${data_dir}/(@=input_id@).hdf5" - output_name: data - func: - _target_: rbyte.io.Hdf5DataFrameBuilder - fields: - (@=input_key@): - obs/robot0_eef_pos: +samples: + inputs: + #@ for input_id, input_keys in inputs.items(): + #@ for input_key in input_keys: + (@=input_id@)(@=input_key@): + path: "${data_dir}/(@=input_id@).hdf5" + prefix: (@=input_key@) + #@ end + #@ end - - _target_: pipefunc.PipeFunc - renames: - input: data - output_name: data_indexed - func: - _target_: rbyte.io.DataFrameIndexer - name: _idx_ + executor: + _target_: concurrent.futures.ThreadPoolExecutor - - _target_: pipefunc.PipeFunc - renames: - input: data_indexed - output_name: data_concated - func: - _target_: rbyte.io.DataFrameConcater - method: vertical - #@ end - #@ end + pipeline: + _target_: pipefunc.Pipeline + validate_type_annotations: false + functions: + - _target_: pipefunc.PipeFunc + output_name: data + mapspec: "path[i], prefix[i] -> data[i]" + func: + _target_: rbyte.io.Hdf5DataFrameBuilder + fields: + obs/robot0_eef_pos: + + - _target_: pipefunc.PipeFunc + renames: + input: data + output_name: indexed + mapspec: "data[i] -> indexed[i]" + func: + _target_: rbyte.io.DataFrameIndexer + name: _idx_ + + - _target_: pipefunc.PipeFunc + renames: + input: indexed + output_name: concated + mapspec: "indexed[i] -> concated[i]" + func: + _target_: rbyte.io.DataFrameConcater + method: vertical diff --git a/config/_templates/dataset/nuscenes/mcap.yaml b/config/_templates/dataset/nuscenes/mcap.yaml index 2988694..235fa55 100644 --- a/config/_templates/dataset/nuscenes/mcap.yaml +++ b/config/_templates/dataset/nuscenes/mcap.yaml @@ 
-13,96 +13,106 @@ _target_: rbyte.Dataset _recursive_: false _convert_: all -inputs: +sources: #@ for input_id in inputs: (@=input_id@): - sources: - #@ for camera, topic in camera_topics.items(): - (@=camera@): - index_column: (@=topic@)/_idx_ - source: - _target_: rbyte.io.McapTensorSource - path: "${data_dir}/(@=input_id@).mcap" - topic: (@=topic@) - decoder_factory: mcap_protobuf.decoder.DecoderFactory - decoder: - _target_: simplejpeg.decode_jpeg - _partial_: true - colorspace: rgb - fastdct: true - fastupsample: true - #@ end + #@ for camera, topic in camera_topics.items(): + (@=camera@): + index_column: (@=topic@)/_idx_ + source: + _target_: rbyte.io.McapTensorSource + path: "${data_dir}/(@=input_id@).mcap" + topic: (@=topic@) + decoder_factory: mcap_protobuf.decoder.DecoderFactory + decoder: + _target_: simplejpeg.decode_jpeg + _partial_: true + colorspace: rgb + fastdct: true + fastupsample: true + #@ end + #@ end - samples: - pipeline: - _target_: pipefunc.Pipeline - validate_type_annotations: false - functions: - - _target_: pipefunc.PipeFunc - bound: - path: ${data_dir}/(@=input_id@).mcap - output_name: data - func: - _target_: rbyte.io.McapDataFrameBuilder - decoder_factories: - - rbyte.utils._mcap.ProtobufDecoderFactory - - rbyte.utils._mcap.JsonDecoderFactory - fields: - #@ for topic in camera_topics.values(): - (@=topic@): - log_time: - _target_: polars.Datetime - time_unit: ns - #@ end +samples: + inputs: + #@ for input_id in inputs: + (@=input_id@): + path: "${data_dir}/(@=input_id@).mcap" + #@ end - /odom: - log_time: - _target_: polars.Datetime - time_unit: ns - vel.x: + executor: + _target_: concurrent.futures.ThreadPoolExecutor + + pipeline: + _target_: pipefunc.Pipeline + validate_type_annotations: false + functions: + - _target_: pipefunc.PipeFunc + output_name: data + mapspec: "path[i] -> data[i]" + func: + _target_: rbyte.io.McapDataFrameBuilder + decoder_factories: + - rbyte.utils._mcap.ProtobufDecoderFactory + - rbyte.utils._mcap.JsonDecoderFactory + fields: + #@ for topic in camera_topics.values(): + (@=topic@): + log_time: + _target_: polars.Datetime + time_unit: ns + #@ end - - _target_: pipefunc.PipeFunc - renames: - input: data - output_name: data_indexed - func: - _target_: rbyte.io.DataFrameIndexer - name: _idx_ + /odom: + log_time: + _target_: polars.Datetime + time_unit: ns + vel.x: - - _target_: pipefunc.PipeFunc - renames: - input: data_indexed - output_name: data_aligned - func: - _target_: rbyte.io.DataFrameAligner - separator: / - fields: - #@ topic = camera_topics.values()[0] - (@=topic@): - key: log_time + - _target_: pipefunc.PipeFunc + renames: + input: data + output_name: indexed + mapspec: "data[i] -> indexed[i]" + func: + _target_: rbyte.io.DataFrameIndexer + name: _idx_ - #@ for topic in camera_topics.values()[1:]: - (@=topic@): - key: log_time - columns: - _idx_: - method: asof - tolerance: 40ms - strategy: nearest - #@ end + - _target_: pipefunc.PipeFunc + renames: + input: indexed + output_name: aligned + mapspec: "indexed[i] -> aligned[i]" + func: + _target_: rbyte.io.DataFrameAligner + separator: / + fields: + #@ topic = camera_topics.values()[0] + (@=topic@): + key: log_time - /odom: - key: log_time - columns: - vel.x: - method: interp + #@ for topic in camera_topics.values()[1:]: + (@=topic@): + key: log_time + columns: + _idx_: + method: asof + tolerance: 40ms + strategy: nearest + #@ end - - _target_: pipefunc.PipeFunc - renames: - input: data_aligned - output_name: data_filtered - func: - _target_: rbyte.io.DataFrameFilter - 
predicate: | - `/odom/vel.x` >= 8 - #@ end + /odom: + key: log_time + columns: + vel.x: + method: interp + + - _target_: pipefunc.PipeFunc + renames: + input: aligned + output_name: samples + mapspec: "aligned[i] -> samples[i]" + func: + _target_: rbyte.io.DataFrameFilter + predicate: | + `/odom/vel.x` >= 8 diff --git a/config/_templates/dataset/nuscenes/rrd.yaml b/config/_templates/dataset/nuscenes/rrd.yaml index a0af322..a46c0ee 100644 --- a/config/_templates/dataset/nuscenes/rrd.yaml +++ b/config/_templates/dataset/nuscenes/rrd.yaml @@ -13,90 +13,100 @@ _target_: rbyte.Dataset _recursive_: false _convert_: all -inputs: +sources: #@ for input_id in inputs: (@=input_id@): - sources: - #@ for camera, entity in camera_entities.items(): - (@=camera@): - index_column: (@=entity@)/_idx_ - source: - _target_: rbyte.io.RrdFrameSource - path: "${data_dir}/(@=input_id@).rrd" - index: timestamp - entity_path: (@=entity@) - decoder: - _target_: simplejpeg.decode_jpeg - _partial_: true - colorspace: rgb - fastdct: true - fastupsample: true - #@ end + #@ for camera, entity in camera_entities.items(): + (@=camera@): + index_column: (@=entity@)/_idx_ + source: + _target_: rbyte.io.RrdFrameSource + path: "${data_dir}/(@=input_id@).rrd" + index: timestamp + entity_path: (@=entity@) + decoder: + _target_: simplejpeg.decode_jpeg + _partial_: true + colorspace: rgb + fastdct: true + fastupsample: true + #@ end + #@ end + +samples: + inputs: + #@ for input_id in inputs: + (@=input_id@): + path: "${data_dir}/(@=input_id@).rrd" + #@ end - samples: - pipeline: - _target_: pipefunc.Pipeline - validate_type_annotations: false - functions: - - _target_: pipefunc.PipeFunc - bound: - path: "${data_dir}/(@=input_id@).rrd" - output_name: data - func: - _target_: rbyte.io.RrdDataFrameBuilder - index: timestamp - contents: - #@ for entity in camera_entities.values(): - (@=entity@): - #@ end + executor: + _target_: concurrent.futures.ThreadPoolExecutor - /world/ego_vehicle/LIDAR_TOP: - - Position3D + pipeline: + _target_: pipefunc.Pipeline + validate_type_annotations: false + functions: + - _target_: pipefunc.PipeFunc + output_name: data + mapspec: "path[i] -> data[i]" + func: + _target_: rbyte.io.RrdDataFrameBuilder + index: timestamp + contents: + #@ for entity in camera_entities.values(): + (@=entity@): + #@ end - - _target_: pipefunc.PipeFunc - renames: - input: data - output_name: data_indexed - func: - _target_: rbyte.io.DataFrameIndexer - name: _idx_ + /world/ego_vehicle/LIDAR_TOP: + - Position3D - - _target_: pipefunc.PipeFunc - renames: - input: data_indexed - output_name: data_aligned - func: - _target_: rbyte.io.DataFrameAligner - separator: / - fields: - #@ entity = camera_entities.values()[0] - (@=entity@): - key: timestamp + - _target_: pipefunc.PipeFunc + renames: + input: data + output_name: indexed + mapspec: "data[i] -> indexed[i]" + func: + _target_: rbyte.io.DataFrameIndexer + name: _idx_ - #@ for entity in camera_entities.values()[1:]: - (@=entity@): - key: timestamp - columns: - _idx_: - method: asof - strategy: nearest - tolerance: 60ms - #@ end + - _target_: pipefunc.PipeFunc + renames: + input: indexed + output_name: aligned + mapspec: "indexed[i] -> aligned[i]" + func: + _target_: rbyte.io.DataFrameAligner + separator: / + fields: + #@ entity = camera_entities.values()[0] + (@=entity@): + key: timestamp - /world/ego_vehicle/LIDAR_TOP: - key: timestamp - columns: - Position3D: - method: asof - strategy: nearest - tolerance: 60ms + #@ for entity in camera_entities.values()[1:]: + (@=entity@): + 
key: timestamp + columns: + _idx_: + method: asof + strategy: nearest + tolerance: 60ms + #@ end - - _target_: pipefunc.PipeFunc - renames: - input: data_aligned - output_name: data_filtered - func: - _target_: rbyte.io.DataFrameFilter - predicate: | - `/world/ego_vehicle/CAM_FRONT/timestamp` between '2018-07-24 03:28:48' and '2018-07-24 03:28:50' - #@ end + /world/ego_vehicle/LIDAR_TOP: + key: timestamp + columns: + Position3D: + method: asof + strategy: nearest + tolerance: 60ms + + - _target_: pipefunc.PipeFunc + renames: + input: aligned + output_name: samples + mapspec: "aligned[i] -> samples[i]" + func: + _target_: rbyte.io.DataFrameFilter + predicate: | + `/world/ego_vehicle/CAM_FRONT/timestamp` between '2018-07-24 03:28:48' and '2018-07-24 03:28:50' diff --git a/config/_templates/dataset/yaak.yaml b/config/_templates/dataset/yaak.yaml index c05c3d6..e1eaa70 100644 --- a/config/_templates/dataset/yaak.yaml +++ b/config/_templates/dataset/yaak.yaml @@ -13,195 +13,194 @@ _target_: rbyte.Dataset _recursive_: false _convert_: all -inputs: +sources: #@ for input_id in drives: (@=input_id@): - sources: - #@ for source_id in cameras: - (@=source_id@): - index_column: "meta/ImageMetadata.(@=source_id@)/frame_idx" - source: - _target_: rbyte.io.FfmpegFrameSource - path: "${data_dir}/(@=input_id@)/(@=source_id@).pii.mp4" - resize_shorter_side: 324 - #@ end - - samples: - pipeline: - _target_: pipefunc.Pipeline - validate_type_annotations: false - cache_type: disk - cache_kwargs: - cache_dir: /tmp/rbyte-cache - functions: - - _target_: pipefunc.PipeFunc - func: - _target_: rbyte.io.YaakMetadataDataFrameBuilder - fields: - rbyte.io.yaak.proto.sensor_pb2.ImageMetadata: - time_stamp: - _target_: polars.Datetime - time_unit: ns + #@ for source_id in cameras: + (@=source_id@): + index_column: meta/ImageMetadata.(@=source_id@)/frame_idx + source: + _target_: rbyte.io.FfmpegFrameSource + path: ${data_dir}/(@=input_id@)/(@=source_id@).pii.mp4 + resize_shorter_side: 324 + #@ end + #@ end +samples: + inputs: + #@ for input_id in drives: + (@=input_id@): + meta_path: ${data_dir}/(@=input_id@)/metadata.log + mcap_path: ${data_dir}/(@=input_id@)/ai.mcap + #@ for camera in cameras: + (@=camera@)_path: ${data_dir}/(@=input_id@)/(@=camera@).pii.mp4 + #@ end + #@ end + + executor: + _target_: concurrent.futures.ThreadPoolExecutor + + pipeline: + _target_: pipefunc.Pipeline + validate_type_annotations: false + functions: + - _target_: pipefunc.PipeFunc + renames: + path: meta_path + output_name: meta + mapspec: "meta_path[i] -> meta[i]" + func: + _target_: rbyte.io.YaakMetadataDataFrameBuilder + fields: + rbyte.io.yaak.proto.sensor_pb2.ImageMetadata: + time_stamp: + _target_: polars.Datetime + time_unit: ns + + frame_idx: + _target_: polars.Int32 + + camera_name: + _target_: polars.Enum + categories: + - cam_front_center + - cam_front_left + - cam_front_right + - cam_left_forward + - cam_right_forward + - cam_left_backward + - cam_right_backward + - cam_rear + + rbyte.io.yaak.proto.can_pb2.VehicleMotion: + time_stamp: + _target_: polars.Datetime + time_unit: ns + + speed: + _target_: polars.Float32 + + gear: + _target_: polars.Enum + categories: ["0", "1", "2", "3"] + + - _target_: pipefunc.PipeFunc + renames: + path: mcap_path + output_name: mcap + mapspec: "mcap_path[i] -> mcap[i]" + func: + _target_: rbyte.io.McapDataFrameBuilder + decoder_factories: [rbyte.utils._mcap.ProtobufDecoderFactory] + fields: + /ai/safety_score: + clip.end_timestamp: + _target_: polars.Datetime + time_unit: ns + + score: + 
_target_: polars.Float32 + + - _target_: pipefunc.PipeFunc + output_name: data + mapspec: "meta[i], mcap[i] -> data[i]" + func: + _target_: pipefunc.helpers.collect_kwargs + parameters: [meta, mcap] + function_name: aggregate + + - _target_: pipefunc.PipeFunc + renames: + input: data + output_name: aligned + mapspec: "data[i] -> aligned[i]" + func: + _target_: rbyte.io.DataFrameAligner + separator: / + fields: + meta: + ImageMetadata.(@=cameras[0]@): + key: time_stamp + + #@ for camera in cameras[1:]: + ImageMetadata.(@=camera@): + key: time_stamp + columns: frame_idx: - _target_: polars.Int32 - - camera_name: - _target_: polars.Enum - categories: - - cam_front_center - - cam_front_left - - cam_front_right - - cam_left_forward - - cam_right_forward - - cam_left_backward - - cam_right_backward - - cam_rear - - rbyte.io.yaak.proto.can_pb2.VehicleMotion: - time_stamp: - _target_: polars.Datetime - time_unit: ns + method: asof + tolerance: 20ms + strategy: nearest + #@ end + VehicleMotion: + key: time_stamp + columns: speed: - _target_: polars.Float32 - + method: interp gear: - _target_: polars.Enum - categories: ["0", "1", "2", "3"] - - output_name: output - scope: metadata - cache: true - - - _target_: pipefunc.PipeFunc - func: - _target_: rbyte.io.McapDataFrameBuilder - decoder_factories: [rbyte.utils._mcap.ProtobufDecoderFactory] - fields: - /ai/safety_score: + method: asof + tolerance: 100ms + strategy: nearest + + mcap: + /ai/safety_score: + key: clip.end_timestamp + columns: clip.end_timestamp: - _target_: polars.Datetime - time_unit: ns - + method: asof + tolerance: 500ms + strategy: nearest score: - _target_: polars.Float32 - - output_name: output - bound: - path: ${data_dir}/(@=input_id@)/ai.mcap - scope: mcap - - - _target_: pipefunc.PipeFunc - func: - _target_: pipefunc.helpers.collect_kwargs - parameters: [meta, mcap] - output_name: data - renames: - meta: metadata.output - mcap: mcap.output - - - _target_: pipefunc.PipeFunc - func: - _target_: rbyte.io.DataFrameAligner - separator: / - fields: - meta: - ImageMetadata.(@=cameras[0]@): - key: time_stamp - - #@ for camera in cameras[1:]: - ImageMetadata.(@=camera@): - key: time_stamp - columns: - frame_idx: - method: asof - tolerance: 20ms - strategy: nearest - #@ end - - VehicleMotion: - key: time_stamp - columns: - speed: - method: interp - gear: - method: asof - tolerance: 100ms - strategy: nearest - - mcap: - /ai/safety_score: - key: clip.end_timestamp - columns: - clip.end_timestamp: - method: asof - tolerance: 500ms - strategy: nearest - score: - method: asof - tolerance: 500ms - strategy: nearest - - output_name: data_aligned - renames: - input: data - - - _target_: pipefunc.PipeFunc - func: - _target_: rbyte.io.DataFrameFilter - predicate: | - `meta/VehicleMotion/speed` > 44 - output_name: data_filtered - renames: - input: data_aligned - - #@ for i, camera in enumerate(cameras): - - _target_: pipefunc.PipeFunc - func: - _target_: rbyte.io.VideoDataFrameBuilder - fields: - frame_idx: - _target_: polars.Int32 - - output_name: data_(@=camera@) - bound: - path: "${data_dir}/(@=input_id@)/(@=camera@).pii.mp4" - - - _target_: pipefunc.PipeFunc - func: - _target_: hydra.utils.get_method - path: polars.DataFrame.join - #@ if i == len(cameras) - 1: - output_name: data_joined - #@ else: - output_name: data_joined_(@=camera@) - #@ end - renames: - #@ if i == 0: - self: data_filtered - #@ else: - self: data_joined_(@=cameras[i-1]@) - #@ end - other: data_(@=camera@) - bound: - how: semi - left_on: 
meta/ImageMetadata.(@=camera@)/frame_idx - right_on: frame_idx - #@ end - - - _target_: pipefunc.PipeFunc - renames: - input: data_joined - output_name: samples - func: - _target_: rbyte.FixedWindowSampleBuilder - index_column: meta/ImageMetadata.(@=cameras[0]@)/frame_idx - every: 6i - period: 6i - length: 6 - - kwargs: - metadata: - path: ${data_dir}/(@=input_id@)/metadata.log - #@ end + tolerance: 500ms + strategy: nearest + + - _target_: pipefunc.PipeFunc + renames: + input: aligned + output_name: filtered + mapspec: "aligned[i] -> filtered[i]" + func: + _target_: rbyte.io.DataFrameFilter + predicate: | + `meta/VehicleMotion/speed` > 44 + + #@ for i, camera in enumerate(cameras): + - _target_: pipefunc.PipeFunc + renames: + path: (@=camera@)_path + output_name: (@=camera@)_meta + mapspec: "(@=camera@)_path[i] -> (@=camera@)_meta[i]" + func: + _target_: rbyte.io.VideoDataFrameBuilder + fields: + frame_idx: + _target_: polars.Int32 + + - _target_: pipefunc.PipeFunc + #@ left = "filtered" if i == 0 else "joined_{}".format('_'.join(cameras[:i])) + #@ right = "{}_meta".format(camera) + #@ joined = "joined_{}".format('_'.join(cameras[:i+1])) + renames: + left: #@ left + right: #@ right + output_name: #@ joined + mapspec: "(@=left@)[i], (@=right@)[i] -> (@=joined@)[i]" + func: + _target_: rbyte.io.DataFrameJoiner + how: semi + left_on: meta/ImageMetadata.(@=camera@)/frame_idx + right_on: frame_idx + #@ end + + - _target_: pipefunc.PipeFunc + #@ input = "joined_{}".format('_'.join(cameras)) + renames: + input: (@=input@) + output_name: samples + mapspec: "(@=input@)[i] -> samples[i]" + func: + _target_: rbyte.io.FixedWindowSampleBuilder + index_column: meta/ImageMetadata.(@=cameras[0]@)/frame_idx + every: 6i + period: 6i + length: 6 diff --git a/config/_templates/dataset/zod.yaml b/config/_templates/dataset/zod.yaml index ae6eee5..3d8a4e2 100644 --- a/config/_templates/dataset/zod.yaml +++ b/config/_templates/dataset/zod.yaml @@ -2,132 +2,147 @@ _target_: rbyte.Dataset _recursive_: false _convert_: all -inputs: +sources: 000002_short: - sources: - camera_front_blur: - index_column: camera_front_blur/timestamp - source: - _target_: rbyte.io.PathTensorSource - path: "${data_dir}/sequences/000002_short/camera_front_blur/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.jpg" - decoder: - _target_: simplejpeg.decode_jpeg - _partial_: true - colorspace: rgb - fastdct: true - fastupsample: true - - lidar_velodyne: - index_column: lidar_velodyne/timestamp - source: - _target_: rbyte.io.NumpyTensorSource - path: "${data_dir}/sequences/000002_short/lidar_velodyne/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.npy" - select: ["x", "y", "z"] - - samples: - pipeline: - _target_: pipefunc.Pipeline - validate_type_annotations: false - functions: - - _target_: pipefunc.PipeFunc - bound: - path: "${data_dir}/sequences/000002_short/camera_front_blur/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.jpg" - output_name: camera_front_blur - func: - _target_: rbyte.io.PathDataFrameBuilder - fields: - timestamp: - _target_: polars.Datetime - time_unit: ns - - - _target_: pipefunc.PipeFunc - bound: - path: "${data_dir}/sequences/000002_short/lidar_velodyne/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.npy" - output_name: lidar_velodyne - func: - _target_: rbyte.io.PathDataFrameBuilder - fields: + camera_front_blur: + index_column: camera_front_blur_meta/timestamp + source: + _target_: rbyte.io.PathTensorSource + path: "${data_dir}/sequences/000002_short/camera_front_blur/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.jpg" + decoder: + _target_: 
simplejpeg.decode_jpeg + _partial_: true + colorspace: rgb + fastdct: true + fastupsample: true + + lidar_velodyne: + index_column: lidar_velodyne_meta/timestamp + source: + _target_: rbyte.io.NumpyTensorSource + path: "${data_dir}/sequences/000002_short/lidar_velodyne/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.npy" + select: ["x", "y", "z"] + +samples: + inputs: + 000002_short: + camera_front_blur_path: "${data_dir}/sequences/000002_short/camera_front_blur/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.jpg" + lidar_velodyne_path: "${data_dir}/sequences/000002_short/lidar_velodyne/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.npy" + vehicle_data_path: "${data_dir}/sequences/000002_short/vehicle_data.hdf5" + + executor: + _target_: concurrent.futures.ThreadPoolExecutor + + pipeline: + _target_: pipefunc.Pipeline + validate_type_annotations: false + functions: + - _target_: pipefunc.PipeFunc + renames: + path: camera_front_blur_path + output_name: camera_front_blur_meta + mapspec: "camera_front_blur_path[i] -> camera_front_blur_meta[i]" + func: + _target_: rbyte.io.PathDataFrameBuilder + fields: + timestamp: + _target_: polars.Datetime + time_unit: ns + + - _target_: pipefunc.PipeFunc + renames: + path: lidar_velodyne_path + output_name: lidar_velodyne_meta + mapspec: "lidar_velodyne_path[i] -> lidar_velodyne_meta[i]" + func: + _target_: rbyte.io.PathDataFrameBuilder + fields: + timestamp: + _target_: polars.Datetime + time_unit: ns + + - _target_: pipefunc.PipeFunc + renames: + path: vehicle_data_path + output_name: vehicle_data + mapspec: "vehicle_data_path[i] -> vehicle_data[i]" + func: + _target_: rbyte.io.Hdf5DataFrameBuilder + fields: + ego_vehicle_controls: + timestamp/nanoseconds/value: + _target_: polars.Datetime + time_unit: ns + + acceleration_pedal/ratio/unitless/value: + steering_wheel_angle/angle/radians/value: + + satellite: + timestamp/nanoseconds/value: + _target_: polars.Datetime + time_unit: ns + + speed/meters_per_second/value: + + - _target_: pipefunc.PipeFunc + output_name: data + mapspec: "camera_front_blur_meta[i], lidar_velodyne_meta[i], vehicle_data[i] -> data[i]" + func: + _target_: pipefunc.helpers.collect_kwargs + parameters: + [camera_front_blur_meta, lidar_velodyne_meta, vehicle_data] + + - _target_: pipefunc.PipeFunc + renames: + input: data + output_name: aligned + mapspec: "data[i] -> aligned[i]" + func: + _target_: rbyte.io.DataFrameAligner + separator: / + fields: + camera_front_blur_meta: + key: timestamp + + lidar_velodyne_meta: + key: timestamp + columns: timestamp: - _target_: polars.Datetime - time_unit: ns - - - _target_: pipefunc.PipeFunc - bound: - path: "${data_dir}/sequences/000002_short/vehicle_data.hdf5" - output_name: vehicle_data - func: - _target_: rbyte.io.Hdf5DataFrameBuilder - fields: - ego_vehicle_controls: + method: asof + strategy: nearest + tolerance: 100ms + + vehicle_data: + ego_vehicle_controls: + key: timestamp/nanoseconds/value + columns: timestamp/nanoseconds/value: - _target_: polars.Datetime - time_unit: ns + method: asof + strategy: nearest + tolerance: 100ms acceleration_pedal/ratio/unitless/value: - steering_wheel_angle/angle/radians/value: + method: asof + strategy: nearest + tolerance: 100ms - satellite: - timestamp/nanoseconds/value: - _target_: polars.Datetime - time_unit: ns + steering_wheel_angle/angle/radians/value: + method: asof + strategy: nearest + tolerance: 100ms + satellite: + key: timestamp/nanoseconds/value + columns: speed/meters_per_second/value: + method: interp - - _target_: pipefunc.PipeFunc - func: - 
_target_: pipefunc.helpers.collect_kwargs - parameters: [camera_front_blur, lidar_velodyne, vehicle_data] - output_name: data - - - _target_: pipefunc.PipeFunc - renames: - input: data - output_name: data_aligned - func: - _target_: rbyte.io.DataFrameAligner - separator: / - fields: - camera_front_blur: - key: timestamp - - lidar_velodyne: - key: timestamp - columns: - timestamp: - method: asof - strategy: nearest - tolerance: 100ms - - vehicle_data: - ego_vehicle_controls: - key: timestamp/nanoseconds/value - columns: - timestamp/nanoseconds/value: - method: asof - strategy: nearest - tolerance: 100ms - - acceleration_pedal/ratio/unitless/value: - method: asof - strategy: nearest - tolerance: 100ms - - steering_wheel_angle/angle/radians/value: - method: asof - strategy: nearest - tolerance: 100ms - - satellite: - key: timestamp/nanoseconds/value - columns: - speed/meters_per_second/value: - method: interp - - - _target_: pipefunc.PipeFunc - renames: - input: data_aligned - output_name: samples - func: - _target_: rbyte.FixedWindowSampleBuilder - index_column: camera_front_blur/timestamp - every: 300ms + - _target_: pipefunc.PipeFunc + renames: + input: aligned + output_name: samples + mapspec: "aligned[i] -> samples[i]" + func: + _target_: rbyte.io.FixedWindowSampleBuilder + index_column: camera_front_blur_meta/timestamp + every: 300ms diff --git a/config/_templates/logger/rerun/zod.yaml b/config/_templates/logger/rerun/zod.yaml index ad21664..09736b8 100644 --- a/config/_templates/logger/rerun/zod.yaml +++ b/config/_templates/logger/rerun/zod.yaml @@ -2,12 +2,12 @@ _target_: rbyte.viz.loggers.RerunLogger spawn: true schema: - camera_front_blur/timestamp: TimeNanosColumn + camera_front_blur_meta/timestamp: TimeNanosColumn camera_front_blur: Image: color_model: RGB - lidar_velodyne/timestamp: TimeNanosColumn + lidar_velodyne_meta/timestamp: TimeNanosColumn lidar_velodyne: Points3D vehicle_data/ego_vehicle_controls/timestamp/nanoseconds/value: TimeNanosColumn diff --git a/justfile b/justfile index 68c8102..ac44ff0 100644 --- a/justfile +++ b/justfile @@ -1,6 +1,7 @@ export PYTHONOPTIMIZE := "1" export HATCH_BUILD_CLEAN := "1" export HYDRA_FULL_ERROR := "1" +export TQDM_DISABLE := "1" _default: @just --list --unsorted @@ -18,9 +19,6 @@ setup: sync install-tools git lfs pull uvx pre-commit install --install-hooks -clean: - uvx --from hatch hatch clean - build: uv build @@ -33,10 +31,7 @@ lint *ARGS: typecheck *ARGS: uvx basedpyright {{ ARGS }} -build-protos: - uvx --from hatch hatch build --clean --hooks-only --target sdist - -pre-commit *ARGS: build-protos +pre-commit *ARGS: build uvx pre-commit run --all-files --color=always {{ ARGS }} generate-config: @@ -46,7 +41,7 @@ generate-config: --output yaml \ --strict -test *ARGS: build-protos generate-config +test *ARGS: build generate-config uv run --all-extras pytest --capture=no {{ ARGS }} notebook FILE *ARGS: sync generate-config @@ -61,10 +56,6 @@ visualize *ARGS: generate-config hydra/job_logging=disabled \ {{ ARGS }} -[group('visualize')] -visualize-mimicgen *ARGS: - just visualize dataset=mimicgen logger=rerun/mimicgen ++data_dir={{ justfile_directory() }}/tests/data/mimicgen {{ ARGS }} - [group('visualize')] visualize-yaak *ARGS: just visualize dataset=yaak logger=rerun/yaak ++data_dir={{ justfile_directory() }}/tests/data/yaak {{ ARGS }} @@ -73,6 +64,10 @@ visualize-yaak *ARGS: visualize-zod *ARGS: just visualize dataset=zod logger=rerun/zod ++data_dir={{ justfile_directory() }}/tests/data/zod {{ ARGS }} +[group('visualize')] 
+visualize-mimicgen *ARGS: + just visualize dataset=mimicgen logger=rerun/mimicgen ++data_dir={{ justfile_directory() }}/tests/data/mimicgen {{ ARGS }} + [group('visualize')] visualize-nuscenes-mcap *ARGS: just visualize dataset=nuscenes/mcap logger=rerun/nuscenes/mcap ++data_dir={{ justfile_directory() }}/tests/data/nuscenes/mcap {{ ARGS }} @@ -81,6 +76,9 @@ visualize-nuscenes-mcap *ARGS: visualize-nuscenes-rrd *ARGS: just visualize dataset=nuscenes/rrd logger=rerun/nuscenes/rrd ++data_dir={{ justfile_directory() }}/tests/data/nuscenes/rrd {{ ARGS }} +[group('visualize')] +visualize-all: visualize-yaak visualize-zod visualize-mimicgen visualize-nuscenes-mcap visualize-nuscenes-rrd + # rerun server and viewer rerun bind="0.0.0.0" port="9876" ws-server-port="9877" web-viewer-port="9090": RUST_LOG=debug uv run rerun \ diff --git a/pyproject.toml b/pyproject.toml index c36caa5..6edb3ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,23 +1,23 @@ [project] name = "rbyte" -version = "0.12.1" +version = "0.13.0" description = "Multimodal PyTorch dataset library" authors = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] dependencies = [ - "tensordict>=0.7.0", + "tensordict>=0.7.2", "torch", "numpy", - "polars>=1.22.0", + "polars>=1.24.0", "pydantic>=2.10.6", "more-itertools>=10.6.0", "hydra-core>=1.3.2", - "optree>=0.14.0", + "optree>=0.14.1", "cachetools>=5.5.1", "parse>=1.20.2", "structlog>=25.1.0", "tqdm>=4.67.1", - "pipefunc>=0.53.3", + "pipefunc[autodoc]>=0.56.0", "xxhash>=3.5.0", ] readme = "README.md" @@ -39,21 +39,16 @@ repo = "https://github.com/yaak-ai/rbyte" [project.optional-dependencies] build = ["hatchling>=1.27.0"] protos = ["grpcio-tools>=1.70.0", "protoletariat>=3.3.9"] -visualize = ["rerun-sdk[notebook]>=0.22.0"] -mcap = [ - "mcap>=1.2.2", - "mcap-ros2-support>=0.5.5", - "protobuf", - "mcap-protobuf-support>=0.5.3", -] -yaak = ["protobuf", "ptars>=0.0.3"] -jpeg = ["simplejpeg>=1.8.1"] +visualize = ["rerun-sdk[notebook]>=0.22.1"] +mcap = ["mcap>=1.2.2", "mcap-protobuf-support>=0.5.3", "protobuf"] +yaak = ["ptars>=0.0.3", "protobuf"] +jpeg = ["simplejpeg>=1.8.2"] video = [ - "python-vali>=4.2.9.post1; sys_platform == 'linux'", + "python-vali>=4.2.10post1; sys_platform == 'linux'", "video-reader-rs>=0.2.3", ] -hdf5 = ["h5py>=3.12.1"] -rrd = ["rerun-sdk>=0.22.0", "pyarrow-stubs"] +hdf5 = ["h5py>=3.13.0"] +rrd = ["rerun-sdk>=0.22.1", "pyarrow-stubs"] [project.scripts] rbyte-visualize = 'rbyte.scripts.visualize:main' @@ -69,9 +64,10 @@ dev = [ "pudb>=2024.1.2", "ipython>=8.32.0", "ipython-autoimport>=0.5", - "pytest>=8.3.4", + "pytest>=8.3.5", "testbook>=0.4.2", "ipykernel>=6.29.5", + "pipefunc[plotting]", ] [tool.hatch.metadata] diff --git a/src/rbyte/__init__.py b/src/rbyte/__init__.py index dbfd726..fd81618 100644 --- a/src/rbyte/__init__.py +++ b/src/rbyte/__init__.py @@ -1,8 +1,7 @@ from importlib.metadata import version from .dataset import Dataset -from .sample import FixedWindowSampleBuilder __version__ = version(__package__ or __name__) -__all__ = ["Dataset", "FixedWindowSampleBuilder", "__version__"] +__all__ = ["Dataset", "__version__"] diff --git a/src/rbyte/config/base.py b/src/rbyte/config/base.py index c3cc0f9..6ad2b70 100644 --- a/src/rbyte/config/base.py +++ b/src/rbyte/config/base.py @@ -1,9 +1,10 @@ +from copy import deepcopy from functools import cached_property -from typing import ClassVar, Literal +from typing import Any, ClassVar, Literal, override from hydra.utils import 
instantiate from pydantic import BaseModel as _BaseModel -from pydantic import ConfigDict, Field, ImportString, model_validator +from pydantic import ConfigDict, Field, ImportString, TypeAdapter, model_validator from pydantic import RootModel as _RootModel @@ -39,12 +40,30 @@ class HydraConfig[T](BaseModel): def instantiate(self, **kwargs: object) -> T: return instantiate(self.model_dump(by_alias=True), **kwargs) + +class PickleableImportString[T](BaseModel): + model_config: ClassVar[ConfigDict] = ConfigDict( + extra="allow", frozen=True, validate_assignment=True + ) + + obj: ImportString[T] + _path: str + @model_validator(mode="before") @classmethod - def validate_model(cls, data: object) -> object: - match data: - case str(): - return {"_target_": data} + def _validate_model(cls, path: str) -> dict[str, str]: + return {"obj": path, "_path": path} + + @override + def __getstate__(self) -> dict[Any, Any]: + state = deepcopy(super().__getstate__()) + state["__dict__"].pop("obj") + + return state - case _: - return data + @override + def __setstate__(self, state: dict[Any, Any]) -> None: + state["__dict__"]["obj"] = TypeAdapter(ImportString[T]).validate_python( + state["__pydantic_extra__"]["_path"] + ) + super().__setstate__(state) diff --git a/src/rbyte/dataset.py b/src/rbyte/dataset.py index 7ffac4d..d198700 100644 --- a/src/rbyte/dataset.py +++ b/src/rbyte/dataset.py @@ -1,15 +1,18 @@ from collections.abc import Mapping, Sequence +from concurrent.futures import Executor from enum import StrEnum, unique from functools import cache -from typing import Annotated, Literal, override +from pathlib import Path +from typing import Annotated, Any, Literal, override import polars as pl import torch -from hydra.utils import instantiate +from optree import tree_map, tree_structure, tree_transpose from pipefunc import Pipeline -from pydantic import ConfigDict, Field, StringConstraints, validate_call +from pipefunc._pipeline._types import OUTPUT_TYPE, StorageType +from pipefunc.map import run_map +from pydantic import ConfigDict, StringConstraints, validate_call from structlog import get_logger -from structlog.contextvars import bound_contextvars from tensordict import TensorDict from torch.utils.data import Dataset as TorchDataset @@ -32,15 +35,24 @@ class SourceConfig(BaseModel): index_column: str -class PipelineConfig(BaseModel): - pipeline: HydraConfig[Pipeline] - output_name: str | None = None - kwargs: dict[str, object] = Field(default_factory=dict) +type SourcesConfig = Mapping[Id, Mapping[Id, SourceConfig]] -class InputConfig(BaseModel): - sources: Mapping[Id, SourceConfig] = Field(min_length=1) - samples: PipelineConfig +class PipelineConfig(BaseModel): + pipeline: HydraConfig[Pipeline] + inputs: Mapping[str, Any] + run_folder: str | Path | None = None + parallel: bool = True + executor: ( + HydraConfig[Executor] | dict[OUTPUT_TYPE, HydraConfig[Executor]] | None + ) = None + chunksizes: int | dict[OUTPUT_TYPE, int] | None = None + storage: StorageType = "dict" + persist_memory: bool = True + cleanup: bool = True + fixed_indices: dict[str, int | slice] | None = None + auto_subpipeline: bool = False + show_progress: bool = False @unique @@ -61,75 +73,20 @@ class _ALL_TYPE: # noqa: N801 class Dataset(TorchDataset[Batch]): - @validate_call(config=BaseModel.model_config) - def __init__( - self, inputs: Annotated[Mapping[Id, InputConfig], Field(min_length=1)] - ) -> None: - logger.debug("initializing dataset") + _samples: pl.DataFrame + _sources: pl.DataFrame + 
@validate_call(config=BaseModel.model_config) + def __init__(self, sources: SourcesConfig, samples: PipelineConfig) -> None: super().__init__() - samples: Mapping[str, pl.DataFrame] = {} - for input_id, input_cfg in inputs.items(): - with bound_contextvars(input_id=input_id): - samples_cfg = input_cfg.samples - pipeline = samples_cfg.pipeline.instantiate() - output_name = ( - samples_cfg.output_name or pipeline.unique_leaf_node.output_name # pyright: ignore[reportUnknownMemberType] - ) - kwargs = instantiate( - samples_cfg.kwargs, _recursive_=True, _convert_="all" - ) - samples[input_id] = pipeline.run(output_name=output_name, kwargs=kwargs) - logger.debug( - "built samples", - columns=samples[input_id].columns, - len=len(samples[input_id]), - ) + logger.debug("initializing dataset") - input_id_enum = pl.Enum(sorted(samples)) + self._samples = self._build_samples(samples) + logger.debug("built samples", length=len(self._samples)) - self._samples: pl.DataFrame = ( - pl.concat( - [ - df.select( - pl.lit(input_id).cast(input_id_enum).alias(Column.input_id), - pl.col(sorted(df.collect_schema().names())), - ) - for input_id, df in samples.items() - ], - how="vertical", - ) - .sort(Column.input_id) - .with_row_index(Column.sample_idx) - .rechunk() - ) - - self._sources: pl.DataFrame = ( - pl.DataFrame( - [ - { - Column.input_id: input_id, - (k := "__source"): [ - source_cfg.model_dump(exclude={"source"}) - | { - "id": source_id, - "config": source_cfg.source.model_dump_json( - by_alias=True - ), - } - for source_id, source_cfg in input_cfg.sources.items() - ], - } - for input_id, input_cfg in inputs.items() - ], - schema_overrides={Column.input_id: input_id_enum}, - ) - .explode(k) - .unnest(k) - .select(Column.input_id, pl.exclude(Column.input_id).name.prefix(f"{k}.")) - .rechunk() - ) + self._sources = self._build_sources(sources) + logger.debug("built sources") @property def samples(self) -> pl.DataFrame: @@ -139,10 +96,6 @@ def samples(self) -> pl.DataFrame: def sources(self) -> pl.DataFrame: return self._sources - @cache # noqa: B019 - def _get_source(self, config: str) -> TensorSource: # noqa: PLR6301 - return HydraConfig[TensorSource].model_validate_json(config).instantiate() - @validate_call( config=ConfigDict(arbitrary_types_allowed=True, validate_default=False) ) @@ -252,3 +205,84 @@ def __getitem__(self, index: int) -> Batch: def __len__(self) -> int: return len(self.samples) + + @cache # noqa: B019 + def _get_source(self, config: str) -> TensorSource: # noqa: PLR6301 + return HydraConfig[TensorSource].model_validate_json(config).instantiate() + + @classmethod + def _build_samples(cls, samples: PipelineConfig) -> pl.DataFrame: + logger.debug("building samples") + pipeline = samples.pipeline.instantiate() + pipeline.print_documentation() + + input_ids, input_values = zip(*samples.inputs.items(), strict=False) + outer = tree_structure(list(range(len(input_values)))) # pyright: ignore[reportArgumentType] + inner = tree_structure(input_values[0]) + inputs: dict[str, Any] = tree_transpose(outer, inner, input_values) # pyright: ignore[reportArgumentType, reportAssignmentType] + + executor: Executor | dict[OUTPUT_TYPE, Executor] = tree_map( # pyright: ignore[reportAssignmentType] + HydraConfig[Executor].instantiate, + samples.executor, # pyright: ignore[reportArgumentType] + ) + + results = run_map( + pipeline=pipeline, + inputs=inputs, + executor=executor, + **samples.model_dump(exclude={"pipeline", "inputs", "executor"}), + ) + output_name = pipeline.unique_leaf_node.output_name # pyright: 
ignore[reportUnknownMemberType] + output: Sequence[pl.DataFrame] = results[output_name].output # pyright: ignore[reportArgumentType] + + input_id_enum = pl.Enum(input_ids) + + return ( + pl.concat( + [ + df.select( + pl.lit(input_id).cast(input_id_enum).alias(Column.input_id), + pl.col(sorted(df.collect_schema().names())), + ) + for input_id, df in zip(input_ids, output, strict=True) + ], + how="vertical", + ) + .sort(Column.input_id) + .with_row_index(Column.sample_idx) + .rechunk() + ) + + @classmethod + def _build_sources( + cls, sources: Mapping[Id, Mapping[Id, SourceConfig]] + ) -> pl.DataFrame: + logger.debug("building sources") + + input_id_enum = pl.Enum(categories=sources.keys()) + + return ( + pl.DataFrame( + [ + { + Column.input_id: input_id, + (k := "__source"): [ + source_cfg.model_dump(exclude={"source"}) + | { + "id": source_id, + "config": source_cfg.source.model_dump_json( + by_alias=True + ), + } + for source_id, source_cfg in input_cfg.items() + ], + } + for input_id, input_cfg in sources.items() + ], + schema_overrides={Column.input_id: input_id_enum}, + ) + .explode(k) + .unnest(k) + .select(Column.input_id, pl.exclude(Column.input_id).name.prefix(f"{k}.")) + .rechunk() + ) diff --git a/src/rbyte/io/__init__.py b/src/rbyte/io/__init__.py index 8af4d9f..1035bd2 100644 --- a/src/rbyte/io/__init__.py +++ b/src/rbyte/io/__init__.py @@ -6,6 +6,8 @@ DataFrameFilter, DataFrameFpsResampler, DataFrameIndexer, + DataFrameJoiner, + FixedWindowSampleBuilder, ) from .path import PathDataFrameBuilder, PathTensorSource @@ -15,6 +17,8 @@ "DataFrameFilter", "DataFrameFpsResampler", "DataFrameIndexer", + "DataFrameJoiner", + "FixedWindowSampleBuilder", "JsonDataFrameBuilder", "NumpyTensorSource", "PathDataFrameBuilder", diff --git a/src/rbyte/io/_json/dataframe_builder.py b/src/rbyte/io/_json/dataframe_builder.py index 4c000b3..611966f 100644 --- a/src/rbyte/io/_json/dataframe_builder.py +++ b/src/rbyte/io/_json/dataframe_builder.py @@ -4,26 +4,40 @@ from typing import final import polars as pl -from optree import PyTree +from optree import PyTree, tree_map from polars._typing import PolarsDataType # noqa: PLC2701 from polars.datatypes import ( DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 ) from pydantic import ConfigDict, validate_call +from structlog import get_logger +from structlog.contextvars import bound_contextvars from rbyte.utils.dataframe import unnest_all +logger = get_logger(__name__) + + type Fields = Mapping[str, Mapping[str, PolarsDataType | None]] @final class JsonDataFrameBuilder: + __name__ = __qualname__ + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__(self, fields: Fields) -> None: self._fields = fields def __call__(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: + with bound_contextvars(path=path): + result = self._build(path) + logger.debug("built dataframes", length=tree_map(len, result)) + + return result + + def _build(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: dfs: Mapping[str, pl.DataFrame] = {} for k, series in ( diff --git a/src/rbyte/io/_mcap/dataframe_builder.py b/src/rbyte/io/_mcap/dataframe_builder.py index d095528..bf12049 100644 --- a/src/rbyte/io/_mcap/dataframe_builder.py +++ b/src/rbyte/io/_mcap/dataframe_builder.py @@ -45,6 +45,8 @@ class SpecialField(StrEnum): @final class McapDataFrameBuilder: + __name__ = __qualname__ + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__( self, @@ -58,6 +60,15 @@ 
def __init__( self._validate_crcs = validate_crcs def __call__(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: + with bound_contextvars(path=path): + result = self._build(path) + logger.debug( + "built dataframes", length={k: len(v) for k, v in result.items()} + ) + + return result + + def _build(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: with ( bound_contextvars(path=str(path)), Path(path).open("rb") as _f, diff --git a/src/rbyte/io/dataframe/__init__.py b/src/rbyte/io/dataframe/__init__.py index 59e7b9f..6cd7ef4 100644 --- a/src/rbyte/io/dataframe/__init__.py +++ b/src/rbyte/io/dataframe/__init__.py @@ -3,6 +3,8 @@ from .filter import DataFrameFilter from .fps_resampler import DataFrameFpsResampler from .indexer import DataFrameIndexer +from .joiner import DataFrameJoiner +from .sample_builder import FixedWindowSampleBuilder __all__ = [ "DataFrameAligner", @@ -10,4 +12,6 @@ "DataFrameFilter", "DataFrameFpsResampler", "DataFrameIndexer", + "DataFrameJoiner", + "FixedWindowSampleBuilder", ] diff --git a/src/rbyte/io/dataframe/aligner.py b/src/rbyte/io/dataframe/aligner.py index eafd482..21fa4e1 100644 --- a/src/rbyte/io/dataframe/aligner.py +++ b/src/rbyte/io/dataframe/aligner.py @@ -9,63 +9,77 @@ PyTree, PyTreeAccessor, tree_accessors, + tree_map, tree_map_with_accessor, tree_map_with_path, ) from polars._typing import AsofJoinStrategy from pydantic import Field, validate_call from structlog import get_logger -from structlog.contextvars import bound_contextvars from rbyte.config.base import BaseModel logger = get_logger(__name__) -class InterpColumnMergeConfig(BaseModel): +class InterpColumnAlignConfig(BaseModel): method: Literal["interp"] = "interp" -class AsofColumnMergeConfig(BaseModel): +class AsofColumnAlignConfig(BaseModel): method: Literal["asof"] = "asof" strategy: AsofJoinStrategy = "backward" tolerance: str | int | float | timedelta | None = None -type ColumnMergeConfig = InterpColumnMergeConfig | AsofColumnMergeConfig +type ColumnAlignConfig = InterpColumnAlignConfig | AsofColumnAlignConfig -class MergeConfig(BaseModel): +class AlignConfig(BaseModel): key: str - columns: OrderedDict[str, ColumnMergeConfig] = Field(default_factory=OrderedDict) + columns: OrderedDict[str, ColumnAlignConfig] = Field(default_factory=OrderedDict) -type Fields = MergeConfig | OrderedDict[str, Fields] +type Fields = AlignConfig | OrderedDict[str, Fields] @final class DataFrameAligner: + __name__ = __qualname__ + @validate_call def __init__(self, *, fields: Fields, separator: str = "/") -> None: self._fields = fields self._separator = separator @cached_property - def _fully_qualified_fields(self) -> PyTree[MergeConfig]: - def fqn(path: tuple[str, ...], cfg: MergeConfig) -> MergeConfig: + def _fully_qualified_fields(self) -> PyTree[AlignConfig]: + def fqn(path: tuple[str, ...], cfg: AlignConfig) -> AlignConfig: key = self._separator.join((*path, cfg.key)) columns = OrderedDict({ self._separator.join((*path, k)): v for k, v in cfg.columns.items() }) - return MergeConfig(key=key, columns=columns) + return AlignConfig(key=key, columns=columns) return tree_map_with_path(fqn, self._fields) # pyright: ignore[reportArgumentType] def __call__(self, input: PyTree[pl.DataFrame]) -> pl.DataFrame: + result = self._build(input) + logger.debug( + "aligned dataframes", + length={"input": tree_map(len, input), "result": len(result)}, + ) + + return result + + def _build(self, input: PyTree[pl.DataFrame]) -> pl.DataFrame: fields = self._fully_qualified_fields + accessors = tree_accessors(fields) 
+ accessor, *accessors_rest = accessors + left_on = accessor(fields).key - def get_df(accessor: PyTreeAccessor, cfg: MergeConfig) -> pl.DataFrame: + def get_df(accessor: PyTreeAccessor, cfg: AlignConfig) -> pl.DataFrame: return ( accessor(input) .rename(lambda col: self._separator.join((*accessor.path, col))) # pyright: ignore[reportUnknownLambdaType, reportUnknownArgumentType] @@ -73,65 +87,54 @@ def get_df(accessor: PyTreeAccessor, cfg: MergeConfig) -> pl.DataFrame: ) dfs = tree_map_with_accessor(get_df, fields) - accessor, *accessors_rest = tree_accessors(fields) df: pl.DataFrame = accessor(dfs) - left_on = accessor(fields).key for accessor in accessors_rest: other: pl.DataFrame = accessor(dfs) - merge_config: MergeConfig = accessor(fields) - key = merge_config.key - - for column, config in merge_config.columns.items(): - df_height_pre = df.height - - with bound_contextvars(key=key, column=column, config=config): - match config: - case AsofColumnMergeConfig( - strategy=strategy, tolerance=tolerance - ): - right_on = key if key == column else uuid4().hex - - df = ( - df.join_asof( - other=other.select({key, column}).rename({ - key: right_on - }), - left_on=left_on, - right_on=right_on, - strategy=strategy, - tolerance=tolerance, - ) - .drop_nulls(column) - .drop({right_on} - {key}) + align_config: AlignConfig = accessor(fields) + key = align_config.key + + for column, config in align_config.columns.items(): + match config: + case AsofColumnAlignConfig(strategy=strategy, tolerance=tolerance): + right_on = key if key == column else uuid4().hex + + df = ( + df.join_asof( + other=other.select({key, column}).rename({ + key: right_on + }), + left_on=left_on, + right_on=right_on, + strategy=strategy, + tolerance=tolerance, ) - - case InterpColumnMergeConfig(): - if key == column: - logger.error(msg := "cannot interpolate key") - - raise ValueError(msg) - - right_on = key - - df = ( - # take a union of timestamps - df.join( - other.select(right_on, column), - how="full", - left_on=left_on, - right_on=right_on, - coalesce=True, - ) - # interpolate - .with_columns(pl.col(column).interpolate_by(left_on)) - # narrow back to original ref col - .join(df.select(left_on), on=left_on, how="semi") - .sort(left_on) - ).drop_nulls(column) - - logger.debug( - "merged", column=column, height=f"{df_height_pre}->{df.height}" - ) + .drop_nulls(column) + .drop({right_on} - {key}) + ) + + case InterpColumnAlignConfig(): + if key == column: + logger.error(msg := "cannot interpolate key") + + raise ValueError(msg) + + right_on = key + + df = ( + # take a union of timestamps + df.join( + other.select(right_on, column), + how="full", + left_on=left_on, + right_on=right_on, + coalesce=True, + ) + # interpolate + .with_columns(pl.col(column).interpolate_by(left_on)) + # narrow back to original ref col + .join(df.select(left_on), on=left_on, how="semi") + .sort(left_on) + ).drop_nulls(column) return df diff --git a/src/rbyte/io/dataframe/concater.py b/src/rbyte/io/dataframe/concater.py index 47078fd..709040d 100644 --- a/src/rbyte/io/dataframe/concater.py +++ b/src/rbyte/io/dataframe/concater.py @@ -8,6 +8,8 @@ @final class DataFrameConcater: + __name__ = __qualname__ + @validate_call def __init__( self, method: ConcatMethod = "horizontal", separator: str | None = None diff --git a/src/rbyte/io/dataframe/filter.py b/src/rbyte/io/dataframe/filter.py index df00251..258b29d 100644 --- a/src/rbyte/io/dataframe/filter.py +++ b/src/rbyte/io/dataframe/filter.py @@ -5,6 +5,8 @@ @final class DataFrameFilter: + __name__ = 
__qualname__ + def __init__(self, predicate: str) -> None: self._query = f"select * from self where {predicate}" # noqa: S608 diff --git a/src/rbyte/io/dataframe/fps_resampler.py b/src/rbyte/io/dataframe/fps_resampler.py index 0056f80..93df618 100644 --- a/src/rbyte/io/dataframe/fps_resampler.py +++ b/src/rbyte/io/dataframe/fps_resampler.py @@ -8,6 +8,8 @@ @final class DataFrameFpsResampler: + __name__ = __qualname__ + IDX_COL = uuid4().hex @validate_call diff --git a/src/rbyte/io/dataframe/indexer.py b/src/rbyte/io/dataframe/indexer.py index 5c34937..1797a73 100644 --- a/src/rbyte/io/dataframe/indexer.py +++ b/src/rbyte/io/dataframe/indexer.py @@ -8,6 +8,8 @@ @final class DataFrameIndexer: + __name__ = __qualname__ + @validate_call def __init__(self, name: str) -> None: self._fn = partial(pl.DataFrame.with_row_index, name=name) diff --git a/src/rbyte/io/dataframe/joiner.py b/src/rbyte/io/dataframe/joiner.py new file mode 100644 index 0000000..6f44e26 --- /dev/null +++ b/src/rbyte/io/dataframe/joiner.py @@ -0,0 +1,31 @@ +from collections.abc import Sequence +from typing import TypedDict, Unpack, final + +import polars as pl +from polars import Expr +from polars._typing import JoinStrategy, JoinValidation, MaintainOrderJoin +from pydantic import ConfigDict, validate_call + + +class _Kwargs(TypedDict, total=False): + on: str | Expr | Sequence[str | Expr] | None + how: JoinStrategy + left_on: str | Expr | Sequence[str | Expr] | None + right_on: str | Expr | Sequence[str | Expr] | None + suffix: str + validate: JoinValidation + nulls_equal: bool + coalesce: bool | None + maintain_order: MaintainOrderJoin | None + + +@final +class DataFrameJoiner: + __name__ = __qualname__ + + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def __init__(self, **kwargs: Unpack[_Kwargs]) -> None: + self._kwargs = kwargs + + def __call__(self, left: pl.DataFrame, right: pl.DataFrame) -> pl.DataFrame: + return left.join(other=right, **self._kwargs) diff --git a/src/rbyte/sample/fixed_window.py b/src/rbyte/io/dataframe/sample_builder.py similarity index 78% rename from src/rbyte/sample/fixed_window.py rename to src/rbyte/io/dataframe/sample_builder.py index 305c489..f90e914 100644 --- a/src/rbyte/sample/fixed_window.py +++ b/src/rbyte/io/dataframe/sample_builder.py @@ -5,17 +5,19 @@ import polars as pl from polars._typing import ClosedInterval from pydantic import PositiveInt, validate_call +from structlog import get_logger + +logger = get_logger(__name__) @final class FixedWindowSampleBuilder: """ - Build samples using fixed (potentially overlapping) windows based on a temporal or - integer column. - - https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.group_by_dynamic + Build samples using `polars.DataFrame.group_by_dynamic`. 
""" + __name__ = __qualname__ + @validate_call def __init__( # noqa: PLR0913 self, @@ -27,6 +29,7 @@ def __init__( # noqa: PLR0913 gather_every: PositiveInt = 1, length: PositiveInt | None = None, ) -> None: + self._index_column_name = index_column self._index_column = pl.col(index_column) self._every = every self._period = period @@ -39,6 +42,14 @@ def __init__( # noqa: PLR0913 ) def __call__(self, input: pl.DataFrame) -> pl.DataFrame: + result = self._build(input) + logger.debug( + "built samples", index_column=self._index_column_name, length=len(result) + ) + + return result + + def _build(self, input: pl.DataFrame) -> pl.DataFrame: return ( input.sort(self._index_column) .with_columns(self._index_column.alias(_index_column := uuid4().hex)) diff --git a/src/rbyte/io/hdf5/dataframe_builder.py b/src/rbyte/io/hdf5/dataframe_builder.py index b8b29c9..85ed536 100644 --- a/src/rbyte/io/hdf5/dataframe_builder.py +++ b/src/rbyte/io/hdf5/dataframe_builder.py @@ -1,8 +1,7 @@ from collections.abc import Mapping, Sequence from os import PathLike -from typing import cast, final +from typing import final -import numpy.typing as npt import polars as pl from h5py import Dataset, File from optree import PyTree, tree_map, tree_map_with_path @@ -12,27 +11,40 @@ DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 ) from pydantic import ConfigDict, validate_call +from structlog import get_logger +from structlog.contextvars import bound_contextvars + +logger = get_logger(__name__) + type Fields = Mapping[str, PolarsDataType | None] | Mapping[str, Fields] @final class Hdf5DataFrameBuilder: + __name__ = __qualname__ + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__(self, fields: Fields) -> None: self._fields = fields - def __call__(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: + def __call__(self, path: PathLike[str], prefix: str = "/") -> PyTree[pl.DataFrame]: + with bound_contextvars(path=path, prefix=prefix): + result = self._build(path, prefix) + logger.debug("built dataframes", length=tree_map(len, result)) + + return result + + def _build(self, path: PathLike[str], prefix: str) -> PyTree[pl.DataFrame]: with File(path) as f: def build_series( path: Sequence[str], dtype: PolarsDataType | None ) -> pl.Series | None: - key = "/".join(path) - match obj := f.get(key): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + name = "/".join((prefix, *path)) + match obj := f.get(name): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] case Dataset(): - values = cast(npt.ArrayLike, obj[:]) - return pl.Series(values=values, dtype=dtype) + return pl.Series(values=obj[:], dtype=dtype) # pyright: ignore[reportUnknownArgumentType] case None: return None diff --git a/src/rbyte/io/path/dataframe_builder.py b/src/rbyte/io/path/dataframe_builder.py index d289d0f..3b20c29 100644 --- a/src/rbyte/io/path/dataframe_builder.py +++ b/src/rbyte/io/path/dataframe_builder.py @@ -13,6 +13,7 @@ ) from pydantic import ConfigDict, validate_call from structlog import get_logger +from structlog.contextvars import bound_contextvars logger = get_logger(__name__) @@ -22,11 +23,20 @@ @final class PathDataFrameBuilder: + __name__ = __qualname__ + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__(self, fields: Fields) -> None: self._fields = fields def __call__(self, path: PathLike[str]) -> pl.DataFrame: + with bound_contextvars(path=path): + result = self._build(path) + logger.debug("built dataframe", length=len(result)) + + return 
diff --git a/src/rbyte/io/path/dataframe_builder.py b/src/rbyte/io/path/dataframe_builder.py
index d289d0f..3b20c29 100644
--- a/src/rbyte/io/path/dataframe_builder.py
+++ b/src/rbyte/io/path/dataframe_builder.py
@@ -13,6 +13,7 @@
 )
 from pydantic import ConfigDict, validate_call
 from structlog import get_logger
+from structlog.contextvars import bound_contextvars
 
 logger = get_logger(__name__)
 
@@ -22,11 +23,20 @@
 
 @final
 class PathDataFrameBuilder:
+    __name__ = __qualname__
+
     @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
     def __init__(self, fields: Fields) -> None:
         self._fields = fields
 
     def __call__(self, path: PathLike[str]) -> pl.DataFrame:
+        with bound_contextvars(path=path):
+            result = self._build(path)
+            logger.debug("built dataframe", length=len(result))
+
+        return result
+
+    def _build(self, path: PathLike[str]) -> pl.DataFrame:
         parser = parse.compile(Path(path).resolve().as_posix())  # pyright: ignore[reportUnknownMemberType]
         match parser.named_fields, parser.fixed_fields:  # pyright: ignore[reportUnknownMemberType]
             case ([_, *_], []):  # pyright: ignore[reportUnknownVariableType]
@@ -45,6 +55,6 @@ def __call__(self, path: PathLike[str]) -> pl.DataFrame:
         )
         paths = map(Path.as_posix, parent.rglob("*"))
         parsed = map(parser.parse, paths)  # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType]
-        data = (x.named for x in parsed if isinstance(x, parse.Result))  # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
+        data = [x.named for x in parsed if isinstance(x, parse.Result)]  # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
 
-        return pl.DataFrame(data=data, schema=self._fields)
+        return pl.DataFrame(data=data, schema=self._fields)  # pyright: ignore[reportUnknownArgumentType]
diff --git a/src/rbyte/io/rrd/dataframe_builder.py b/src/rbyte/io/rrd/dataframe_builder.py
index 1691b3f..2d4f14e 100644
--- a/src/rbyte/io/rrd/dataframe_builder.py
+++ b/src/rbyte/io/rrd/dataframe_builder.py
@@ -7,6 +7,10 @@
 import polars as pl
 import rerun.dataframe as rrd
 from pydantic import validate_call
+from structlog import get_logger
+from structlog.contextvars import bound_contextvars
+
+logger = get_logger(__name__)
 
 
 @unique
@@ -17,6 +21,8 @@ class Column(StrEnum):
 
 @final
 class RrdDataFrameBuilder:
+    __name__ = __qualname__
+
     @validate_call
     def __init__(
         self, index: str, contents: Mapping[str, Sequence[str] | None]
@@ -25,6 +31,15 @@ def __init__(
         self._contents = contents
 
     def __call__(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]:
+        with bound_contextvars(path=path):
+            result = self._build(path)
+            logger.debug(
+                "built dataframes", length={k: len(v) for k, v in result.items()}
+            )
+
+        return result
+
+    def _build(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]:
         recording = rrd.load_recording(path)  # pyright: ignore[reportUnknownMemberType]
         schema = recording.schema()
diff --git a/src/rbyte/io/video/dataframe_builder.py b/src/rbyte/io/video/dataframe_builder.py
index 32adf74..550dc85 100644
--- a/src/rbyte/io/video/dataframe_builder.py
+++ b/src/rbyte/io/video/dataframe_builder.py
@@ -9,20 +9,34 @@
     IntegerType,  # pyright: ignore[reportUnusedImport]  # noqa: F401
 )
 from pydantic import ConfigDict, validate_call
+from structlog import get_logger
+from structlog.contextvars import bound_contextvars
 from video_reader import (
     PyVideoReader,  # pyright: ignore[reportAttributeAccessIssue, reportUnknownVariableType]
 )
 
+logger = get_logger(__name__)
+
+
 type Fields = Mapping[Literal["frame_idx"], PolarsIntegerType]
 
 
 @final
 class VideoDataFrameBuilder:
+    __name__ = __qualname__
+
     @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
     def __init__(self, fields: Fields) -> None:
         self._fields = fields
 
     def __call__(self, path: PathLike[str]) -> pl.DataFrame:
+        with bound_contextvars(path=path):
+            result = self._build(path)
+            logger.debug("built dataframe", length=len(result))
+
+        return result
+
+    def _build(self, path: PathLike[str]) -> pl.DataFrame:
         vr = PyVideoReader(Path(path).resolve().as_posix())  # pyright: ignore[reportUnknownVariableType]
         info = vr.get_info()  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
         frame_count = int(info["frame_count"])  # pyright: ignore[reportUnknownArgumentType]
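
Note: each builder also gains "__name__ = __qualname__". Instances of plain classes carry no __name__ attribute, so callers that read a function-style name off a callable (pipefunc.PipeFunc is the likely consumer, given the pipeline configs) would otherwise hit an attribute error; the class attribute fills that gap. A quick illustration with a hypothetical class name:

class ExampleBuilder:
    # __qualname__ is already bound inside the class body at definition
    # time, so this stores the string "ExampleBuilder" as a class attribute
    __name__ = __qualname__

    def __call__(self) -> None: ...


assert ExampleBuilder().__name__ == "ExampleBuilder"
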
diff --git a/src/rbyte/io/yaak/dataframe_builder.py b/src/rbyte/io/yaak/dataframe_builder.py
index d7e1430..281199c 100644
--- a/src/rbyte/io/yaak/dataframe_builder.py
+++ b/src/rbyte/io/yaak/dataframe_builder.py
@@ -14,11 +14,14 @@
     DataTypeClass,  # pyright: ignore[reportUnusedImport]  # noqa: F401
 )
 from ptars import HandlerPool
-from pydantic import ConfigDict, ImportString, validate_call
+from pydantic import ConfigDict, validate_call
 from structlog import get_logger
+from structlog.contextvars import bound_contextvars
 from tqdm import tqdm
 from xxhash import xxh3_64_hexdigest as digest
 
+from rbyte.config.base import PickleableImportString
+
 from .message_iterator import YaakMetadataMessageIterator
 from .proto import sensor_pb2
 
@@ -26,12 +29,14 @@
 
 type Fields = Mapping[
-    type[Message] | ImportString[type[Message]], Mapping[str, PolarsDataType | None]
+    PickleableImportString[type[Message]], Mapping[str, PolarsDataType | None]
 ]
 
 
 @final
 class YaakMetadataDataFrameBuilder:
+    __name__ = __qualname__
+
     @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
     def __init__(self, *, fields: Fields) -> None:
         super().__init__()
@@ -42,24 +47,34 @@ def __pipefunc_hash__(self) -> str:  # noqa: PLW3201
         return digest(str(self._fields))
 
     def __call__(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]:
+        with bound_contextvars(path=path):
+            result = self._build(path)
+            logger.debug(
+                "built dataframes", length={k: len(v) for k, v in result.items()}
+            )
+
+        return result
+
+    def _build(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]:
         with Path(path).open("rb") as _f, mmap(_f.fileno(), 0, access=ACCESS_READ) as f:
             handler_pool = HandlerPool()
+            message_types = {k.obj for k in self._fields}
             messages = mit.bucket(
-                YaakMetadataMessageIterator(f, message_types=self._fields),
+                YaakMetadataMessageIterator(f, message_types=message_types),
                 key=itemgetter(0),
-                validator=self._fields.__contains__,
+                validator=message_types.__contains__,
             )
 
             dfs = {
-                msg_type.__name__: cast(
+                msg.obj.__name__: cast(
                     pl.DataFrame,
                     pl.from_arrow(  # pyright: ignore[reportUnknownMemberType]
-                        data=handler_pool.get_for_message(msg_type.DESCRIPTOR)  # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
+                        data=handler_pool.get_for_message(msg.obj.DESCRIPTOR)  # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
                         .list_to_record_batch([
                             msg_data
                             for (_, msg_data) in tqdm(
-                                messages[msg_type], postfix={"msg_type": msg_type}
+                                messages[msg.obj], postfix={"msg": msg.obj}
                             )
                         ])
                         .select(schema),
@@ -67,7 +82,7 @@
                         rechunk=True,
                     ),
                 )
-                for msg_type, schema in self._fields.items()
+                for msg, schema in self._fields.items()
             }
 
         if (df := dfs.pop((k := sensor_pb2.ImageMetadata.__name__), None)) is not None:
diff --git a/src/rbyte/io/yaak/idl-repo b/src/rbyte/io/yaak/idl-repo
index c9e3d52..1f8f30e 160000
--- a/src/rbyte/io/yaak/idl-repo
+++ b/src/rbyte/io/yaak/idl-repo
@@ -1 +1 @@
-Subproject commit c9e3d5284935ef5a070bb3961a7ee5fd1cdd588c
+Subproject commit 1f8f30e5d730f31818411ebbed9c6eab2f7f2a4e
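
Note: the iterator below narrows message_types from Iterable to collections.abc.Set, which turns the unknown-type check into a plain set difference instead of materializing the iterable into a set first. A reduced sketch of that check (the string type names here are made up; the real code compares protobuf message classes):

from collections.abc import Set as AbstractSet

KNOWN: frozenset[str] = frozenset({"Gnss", "ImageMetadata"})


def validate(message_types: AbstractSet[str]) -> None:
    # any AbstractSet (set, frozenset, dict keys view) supports "-" directly
    if unknown := message_types - KNOWN:
        msg = f"unknown message types: {unknown}"
        raise ValueError(msg)


validate(frozenset({"Gnss"}))  # ok; frozenset({"Radar"}) would raise
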
diff --git a/src/rbyte/io/yaak/message_iterator.py b/src/rbyte/io/yaak/message_iterator.py
index 1dc5a4c..c540ace 100644
--- a/src/rbyte/io/yaak/message_iterator.py
+++ b/src/rbyte/io/yaak/message_iterator.py
@@ -1,9 +1,11 @@
 import struct
-from collections.abc import Iterable, Iterator, Mapping
+from collections.abc import Iterator, Mapping
+from collections.abc import Set as AbstractSet
 from mmap import mmap
 from typing import BinaryIO, Self, override
 
 from google.protobuf.message import Message
+from pydantic import ConfigDict, validate_call
 from structlog import get_logger
 from structlog.contextvars import bound_contextvars
 
@@ -30,10 +32,12 @@ class YaakMetadataMessageIterator(Iterator[tuple[type[Message], bytes]]):
     FILE_HEADER_LEN: int = 12
     MESSAGE_HEADER_LEN: int = 8
 
+    @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
     def __init__(
         self,
         file: BinaryIO | mmap,
-        message_types: Iterable[type[Message]] | None = None,
+        *,
+        message_types: AbstractSet[type[Message]] | None = None,
     ) -> None:
         super().__init__()
@@ -51,8 +55,8 @@ def __init__(
         if message_types is None:
             self._message_types: Mapping[int, type[Message]] = self.MESSAGE_TYPES
         else:
-            if unknown_message_types := set(message_types) - set(
-                self.MESSAGE_TYPES.values()
+            if unknown_message_types := (
+                message_types - set(self.MESSAGE_TYPES.values())
             ):
                 with bound_contextvars(unknown_message_types=unknown_message_types):
                     logger.error(msg := "unknown message types")
diff --git a/src/rbyte/sample/__init__.py b/src/rbyte/sample/__init__.py
deleted file mode 100644
index a1dcb35..0000000
--- a/src/rbyte/sample/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .fixed_window import FixedWindowSampleBuilder
-
-__all__ = ["FixedWindowSampleBuilder"]
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index d5f8193..d4503db 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -42,7 +42,7 @@ def test_mimicgen() -> None:
                 **meta_rest,
             },
             **batch_rest,
-        } if set(input_id).issubset(cfg.dataloader.dataset.inputs) and not any((
+        } if set(input_id).issubset(cfg.dataloader.dataset.sources) and not any((
             batch_rest,
             data_rest,
             meta_rest,
@@ -93,7 +93,7 @@ def test_nuscenes_mcap() -> None:
                 **meta_rest,
             },
             **batch_rest,
-        } if set(input_id).issubset(cfg.dataloader.dataset.inputs) and not any((
+        } if set(input_id).issubset(cfg.dataloader.dataset.sources) and not any((
             batch_rest,
             data_rest,
             meta_rest,
@@ -144,7 +144,7 @@ def test_nuscenes_rrd() -> None:
                 **meta_rest,
             },
             **batch_rest,
-        } if set(input_id).issubset(cfg.dataloader.dataset.inputs) and not any((
+        } if set(input_id).issubset(cfg.dataloader.dataset.sources) and not any((
             batch_rest,
             data_rest,
             meta_rest,
@@ -202,7 +202,7 @@ def test_yaak() -> None:
                 **meta_rest,
             },
             **batch_rest,
-        } if set(input_id).issubset(cfg.dataloader.dataset.inputs) and not any((
+        } if set(input_id).issubset(cfg.dataloader.dataset.sources) and not any((
             batch_rest,
             data_rest,
             meta_rest,
@@ -234,9 +234,9 @@ def test_zod() -> None:
         case {
             "data": {
                 "camera_front_blur": Tensor(shape=[c.B, *_]),
-                "camera_front_blur/timestamp": Tensor(shape=[c.B, *_]),
+                "camera_front_blur_meta/timestamp": Tensor(shape=[c.B, *_]),
                 "lidar_velodyne": Tensor(shape=[c.B, *_]),
-                "lidar_velodyne/timestamp": Tensor(shape=[c.B, *_]),
+                "lidar_velodyne_meta/timestamp": Tensor(shape=[c.B, *_]),
                 "vehicle_data/ego_vehicle_controls/acceleration_pedal/ratio/unitless/value": Tensor(  # noqa: E501
                     shape=[c.B, *_]
                 ),
@@ -257,7 +257,7 @@ def test_zod() -> None:
                 **meta_rest,
             },
             **batch_rest,
-        } if set(input_id).issubset(cfg.dataloader.dataset.inputs) and not any((
+        } if set(input_id).issubset(cfg.dataloader.dataset.sources) and not any((
             batch_rest,
             data_rest,
             meta_rest,
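
Note: the test assertions above lean on mapping patterns in match/case: the **rest captures must come up empty for a batch to count as an exact structural match, hence the "not any((batch_rest, data_rest, meta_rest))" guards. A reduced sketch of the idiom with made-up keys:

def is_exact(batch: dict[str, object]) -> bool:
    match batch:
        # **data_rest / **batch_rest capture any keys not named in the
        # pattern; an exact match requires both captures to be empty
        case {"data": {"frame_idx": _, **data_rest}, **batch_rest} if not any((
            data_rest,
            batch_rest,
        )):
            return True
        case _:
            return False


assert is_exact({"data": {"frame_idx": 0}})
assert not is_exact({"data": {"frame_idx": 0}, "extra": 1})
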