Skip to content

Commit

Permalink
collection dataset fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
supersergiy committed Jan 20, 2025
1 parent 369e022 commit 8cc74e7
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions zetta_utils/training/datasets/collection_dataset.py
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
import json
import os
from collections import defaultdict
from typing import Sequence

import fsspec
Expand Down Expand Up @@ -28,15 +29,19 @@ def build_collection_dataset(
chunk_size: Sequence[int],
chunk_stride: Sequence[int],
layer_rename_map: dict[str, str],
shared_read_procs: Sequence[DataProcessor],
per_layer_read_procs: dict[str, Sequence[DataProcessor]],
per_layer_read_procs: dict[str, Sequence[DataProcessor]] | None = None,
shared_read_procs: Sequence[DataProcessor] = tuple(),
tags: list[str] | None = None,
) -> JointDataset:
datasets = {}
annotations = db_annotations.read_annotations(

Check warning on line 37 in zetta_utils/training/datasets/collection_dataset.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/training/datasets/collection_dataset.py#L36-L37

Added lines #L36 - L37 were not covered by tests
collection_ids=[collection_name], tags=tags, union=False
)
layer_group_map: dict[str, dict[str, Layer]] = {}
if per_layer_read_procs is None:
per_layer_read_procs_dict = defaultdict(tuple)

Check warning on line 42 in zetta_utils/training/datasets/collection_dataset.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/training/datasets/collection_dataset.py#L40-L42

Added lines #L40 - L42 were not covered by tests
else:
per_layer_read_procs_dict = per_layer_read_procs

Check warning on line 44 in zetta_utils/training/datasets/collection_dataset.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/training/datasets/collection_dataset.py#L44

Added line #L44 was not covered by tests

for i, annotation in enumerate(annotations.values()):
if annotation.layer_group not in layer_group_map:
Expand All @@ -47,17 +52,18 @@ def build_collection_dataset(
name = layer.name
if name in layer_rename_map:
name = layer_rename_map[name]
read_procs = per_layer_read_procs[name]
read_procs = per_layer_read_procs_dict[name]
layers[name] = build_cv_layer(path=layer.source, read_procs=read_procs)
layer_group_map[annotation.layer_group] = layers

Check warning on line 57 in zetta_utils/training/datasets/collection_dataset.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/training/datasets/collection_dataset.py#L46-L57

Added lines #L46 - L57 were not covered by tests
else:
layers = layer_group_map[annotation.layer_group]

Check warning on line 59 in zetta_utils/training/datasets/collection_dataset.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/training/datasets/collection_dataset.py#L59

Added line #L59 was not covered by tests

z_resolution = resolution[-1]
for layer in layers.values():
info_path = os.path.join(layer.backend.name.strip("precomputed://"), "info")
with fsspec.open(info_path) as f:
info = json.loads(f.read())
z_resolutions = set([e["resolution"][-1] for e in info["scales"]])
z_resolutions = {e["resolution"][-1] for e in info["scales"]}
if len(z_resolutions) != 1:
raise RuntimeError("Only layers with single z resolution are supported")
z_resolution = list(z_resolutions)[0]
Expand All @@ -68,15 +74,13 @@ def build_collection_dataset(
point_b_nm = Vec3D(*ng_bbox.pointB).int()

Check warning on line 74 in zetta_utils/training/datasets/collection_dataset.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/training/datasets/collection_dataset.py#L61-L74

Added lines #L61 - L74 were not covered by tests

start_coord = [

Check warning on line 76 in zetta_utils/training/datasets/collection_dataset.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/training/datasets/collection_dataset.py#L76

Added line #L76 was not covered by tests
round((min(point_a_nm[i], point_b_nm[i]) / this_resolution[i]))
* this_resolution[i]
int((min(point_a_nm[i], point_b_nm[i]) / this_resolution[i])) * this_resolution[i]
for i in range(3)
]
end_coord = [

Check warning on line 80 in zetta_utils/training/datasets/collection_dataset.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/training/datasets/collection_dataset.py#L80

Added line #L80 was not covered by tests
round(max(point_a_nm[i], point_b_nm[i]) / this_resolution[i]) * this_resolution[i]
int(max(point_a_nm[i], point_b_nm[i]) / this_resolution[i]) * this_resolution[i]
for i in range(3)
]

bbox = BBox3D.from_coords(start_coord=start_coord, end_coord=end_coord)
datasets[str(i)] = LayerDataset(

Check warning on line 85 in zetta_utils/training/datasets/collection_dataset.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/training/datasets/collection_dataset.py#L84-L85

Added lines #L84 - L85 were not covered by tests
layer=build_layer_set(layers=layers, read_procs=shared_read_procs),
Expand Down

0 comments on commit 8cc74e7

Please sign in to comment.