Skip to content

Commit

Permalink
feat: parallel sample processing
Browse files: browse the repository at this point in the history
  • Loading branch information
egorchakov committed Mar 3, 2025
1 parent 6eb6844 commit 8dd0367
Show file tree
Hide file tree
Showing 34 changed files with 1,055 additions and 812 deletions.
23 changes: 6 additions & 17 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,9 @@ jobs:
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"

- name: sync
run: just sync

- name: build protos
run: just build-protos

- name: format
run: just format --check

- name: lint
run: just lint

- name: typecheck
run: just typecheck

- name: test
run: just test
- run: just sync
- run: just build
- run: just format --check
- run: just lint
- run: just typecheck
- run: just test
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ repos:
- id: pyupgrade

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.5
rev: v0.9.9
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

- repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror
rev: 1.26.0
rev: 1.28.1
hooks:
- id: basedpyright

Expand Down
138 changes: 74 additions & 64 deletions config/_templates/dataset/carla.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,75 +56,85 @@
_target_: rbyte.Dataset
_recursive_: false
_convert_: all
inputs:
sources:
#@ for input_id in drives:
(@=input_id@):
sources:
#@ for source_id in cameras:
(@=source_id@):
index_column: _idx_
source:
_target_: rbyte.io.PathTensorSource
path: "${data_dir}/(@=input_id@)/frames/(@=source_id@).defish.mp4/576x324/{:09d}.jpg"
decoder:
_target_: simplejpeg.decode_jpeg
_partial_: true
colorspace: rgb
fastdct: true
fastupsample: true
#@ for source_id in cameras:
(@=source_id@):
index_column: _idx_
source:
_target_: rbyte.io.PathTensorSource
path: "${data_dir}/(@=input_id@)/frames/(@=source_id@).defish.mp4/576x324/{:09d}.jpg"
decoder:
_target_: simplejpeg.decode_jpeg
_partial_: true
colorspace: rgb
fastdct: true
fastupsample: true

#@ end
#@ end
#@ end

samples:
pipeline:
_target_: pipefunc.Pipeline
validate_type_annotations: false
functions:
- _target_: pipefunc.PipeFunc
bound:
path: ${data_dir}/(@=input_id@)/ego_logs.json
output_name: ego_logs
func:
_target_: rbyte.io.JsonDataFrameBuilder
fields:
records:
control.brake:
control.throttle:
control.steer:
state.velocity.value:
state.acceleration.value:
samples:
inputs:
#@ for input_id in drives:
(@=input_id@):
ego_logs_path: ${data_dir}/(@=input_id@)/ego_logs.json
#@ end

- _target_: pipefunc.PipeFunc
renames:
input: ego_logs
output_name: data
func:
_target_: rbyte.io.DataFrameConcater
method: vertical
pipeline:
_target_: pipefunc.Pipeline
validate_type_annotations: false
functions:
- _target_: pipefunc.PipeFunc
renames:
path: ego_logs_path
output_name: ego_logs
mapspec: "ego_logs_path[i] -> ego_logs[i]"
func:
_target_: rbyte.io.JsonDataFrameBuilder
fields:
records:
control.brake:
control.throttle:
control.steer:
state.velocity.value:
state.acceleration.value:

- _target_: pipefunc.PipeFunc
renames:
input: data
output_name: data_resampled
func:
_target_: rbyte.io.DataFrameFpsResampler
fps_in: 20
fps_out: 30
- _target_: pipefunc.PipeFunc
renames:
input: ego_logs
output_name: data
mapspec: "ego_logs[i] -> data[i]"
func:
_target_: rbyte.io.DataFrameConcater
method: vertical

- _target_: pipefunc.PipeFunc
renames:
input: data_resampled
output_name: data_indexed
func:
_target_: rbyte.io.DataFrameIndexer
name: _idx_
- _target_: pipefunc.PipeFunc
renames:
input: data
output_name: resampled
mapspec: "data[i] -> resampled[i]"
func:
_target_: rbyte.io.DataFrameFpsResampler
fps_in: 20
fps_out: 30

- _target_: pipefunc.PipeFunc
renames:
input: data_indexed
output_name: data_filtered
func:
_target_: rbyte.io.DataFrameFilter
predicate: |
`control.throttle` > 0.5
#@ end
- _target_: pipefunc.PipeFunc
renames:
input: resampled
output_name: indexed
mapspec: "resampled[i] -> indexed[i]"
func:
_target_: rbyte.io.DataFrameIndexer
name: _idx_

- _target_: pipefunc.PipeFunc
renames:
input: indexed
output_name: samples
mapspec: "indexed[i] -> samples[i]"
func:
_target_: rbyte.io.DataFrameFilter
predicate: |
`control.throttle` > 0.5
91 changes: 51 additions & 40 deletions config/_templates/dataset/mimicgen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,49 +14,60 @@
_target_: rbyte.Dataset
_recursive_: false
_convert_: all
inputs:
sources:
#@ for input_id, input_keys in inputs.items():
#@ for input_key in input_keys:
(@=input_id@)(@=input_key@):
sources:
#@ for frame_key in frame_keys:
(@=frame_key@):
index_column: _idx_
source:
_target_: rbyte.io.Hdf5TensorSource
path: "${data_dir}/(@=input_id@).hdf5"
key: (@=input_key@)/(@=frame_key@)
#@ end
#@ for frame_key in frame_keys:
(@=frame_key@):
index_column: _idx_
source:
_target_: rbyte.io.Hdf5TensorSource
path: "${data_dir}/(@=input_id@).hdf5"
key: (@=input_key@)/(@=frame_key@)
#@ end
#@ end
#@ end

samples:
pipeline:
_target_: pipefunc.Pipeline
validate_type_annotations: false
functions:
- _target_: pipefunc.PipeFunc
bound:
path: "${data_dir}/(@=input_id@).hdf5"
output_name: data
func:
_target_: rbyte.io.Hdf5DataFrameBuilder
fields:
(@=input_key@):
obs/robot0_eef_pos:
samples:
inputs:
#@ for input_id, input_keys in inputs.items():
#@ for input_key in input_keys:
(@=input_id@)(@=input_key@):
path: "${data_dir}/(@=input_id@).hdf5"
prefix: (@=input_key@)
#@ end
#@ end

- _target_: pipefunc.PipeFunc
renames:
input: data
output_name: data_indexed
func:
_target_: rbyte.io.DataFrameIndexer
name: _idx_
executor:
_target_: concurrent.futures.ThreadPoolExecutor

- _target_: pipefunc.PipeFunc
renames:
input: data_indexed
output_name: data_concated
func:
_target_: rbyte.io.DataFrameConcater
method: vertical
#@ end
#@ end
pipeline:
_target_: pipefunc.Pipeline
validate_type_annotations: false
functions:
- _target_: pipefunc.PipeFunc
output_name: data
mapspec: "path[i], prefix[i] -> data[i]"
func:
_target_: rbyte.io.Hdf5DataFrameBuilder
fields:
obs/robot0_eef_pos:

- _target_: pipefunc.PipeFunc
renames:
input: data
output_name: indexed
mapspec: "data[i] -> indexed[i]"
func:
_target_: rbyte.io.DataFrameIndexer
name: _idx_

- _target_: pipefunc.PipeFunc
renames:
input: indexed
output_name: concated
mapspec: "indexed[i] -> concated[i]"
func:
_target_: rbyte.io.DataFrameConcater
method: vertical
Loading

0 comments on commit 8dd0367

Please sign in to comment.