diff --git a/zetta_utils/training/datasets/collection_dataset.py b/zetta_utils/training/datasets/collection_dataset.py index 71f8b822c..058418ba1 100644 --- a/zetta_utils/training/datasets/collection_dataset.py +++ b/zetta_utils/training/datasets/collection_dataset.py @@ -1,5 +1,6 @@ import json import os +from collections import defaultdict from typing import Sequence import fsspec @@ -28,8 +29,8 @@ def build_collection_dataset( chunk_size: Sequence[int], chunk_stride: Sequence[int], layer_rename_map: dict[str, str], - shared_read_procs: Sequence[DataProcessor], - per_layer_read_procs: dict[str, Sequence[DataProcessor]], + per_layer_read_procs: dict[str, Sequence[DataProcessor]] | None = None, + shared_read_procs: Sequence[DataProcessor] = tuple(), tags: list[str] | None = None, ) -> JointDataset: datasets = {} @@ -37,6 +38,10 @@ def build_collection_dataset( collection_ids=[collection_name], tags=tags, union=False ) layer_group_map: dict[str, dict[str, Layer]] = {} + if per_layer_read_procs is None: + per_layer_read_procs_dict = defaultdict(tuple) + else: + per_layer_read_procs_dict = per_layer_read_procs for i, annotation in enumerate(annotations.values()): if annotation.layer_group not in layer_group_map: @@ -47,17 +52,18 @@ def build_collection_dataset( name = layer.name if name in layer_rename_map: name = layer_rename_map[name] - read_procs = per_layer_read_procs[name] + read_procs = per_layer_read_procs_dict[name] layers[name] = build_cv_layer(path=layer.source, read_procs=read_procs) layer_group_map[annotation.layer_group] = layers else: layers = layer_group_map[annotation.layer_group] + z_resolution = resolution[-1] for layer in layers.values(): info_path = os.path.join(layer.backend.name.strip("precomputed://"), "info") with fsspec.open(info_path) as f: info = json.loads(f.read()) - z_resolutions = set([e["resolution"][-1] for e in info["scales"]]) + z_resolutions = {e["resolution"][-1] for e in info["scales"]} if len(z_resolutions) != 1: raise RuntimeError("Only layers with single z resolution are supported") z_resolution = list(z_resolutions)[0] @@ -68,15 +74,13 @@ def build_collection_dataset( point_b_nm = Vec3D(*ng_bbox.pointB).int() start_coord = [ - round((min(point_a_nm[i], point_b_nm[i]) / this_resolution[i])) - * this_resolution[i] + int((min(point_a_nm[i], point_b_nm[i]) / this_resolution[i])) * this_resolution[i] for i in range(3) ] end_coord = [ - round(max(point_a_nm[i], point_b_nm[i]) / this_resolution[i]) * this_resolution[i] + int(max(point_a_nm[i], point_b_nm[i]) / this_resolution[i]) * this_resolution[i] for i in range(3) ] - bbox = BBox3D.from_coords(start_coord=start_coord, end_coord=end_coord) datasets[str(i)] = LayerDataset( layer=build_layer_set(layers=layers, read_procs=shared_read_procs),