From 26b2ae9f855a5ab95132496c925d028f63671d56 Mon Sep 17 00:00:00 2001 From: Sergiy Date: Fri, 30 Aug 2024 20:28:09 +0000 Subject: [PATCH 1/5] wip --- .../operations/copy_annotated.py | 19 +++++ zetta_utils/geometry/bbox.py | 19 ++++- zetta_utils/training/datasets/__init__.py | 1 + .../training/datasets/collection_dataset.py | 75 +++++++++++++++++++ zetta_utils/training/lightning/train.py | 2 +- .../training/lightning/trainers/default.py | 26 +++---- 6 files changed, 127 insertions(+), 15 deletions(-) create mode 100644 zetta_utils/db_annotations/operations/copy_annotated.py create mode 100644 zetta_utils/training/datasets/collection_dataset.py diff --git a/zetta_utils/db_annotations/operations/copy_annotated.py b/zetta_utils/db_annotations/operations/copy_annotated.py new file mode 100644 index 000000000..356c2e90f --- /dev/null +++ b/zetta_utils/db_annotations/operations/copy_annotated.py @@ -0,0 +1,19 @@ +from typing import Sequence + +from zetta_utils import builder +from zetta_utils.db_annotations.annotation import read_annotations +from zetta_utils.layer.volumetric.layer import VolumetricLayer + + +@builder.register("copy_annotated_data") +def CopyAnnotatedFlow( + src: VolumetricLayer, + dst: VolumetricLayer, + collection_name: str, + layer_group_name: str, + resolution: Sequence[float], +): + annotations = read_annotations( + collection_ids=[collection_name], layer_group_ids=[layer_group_name] + ) + breakpoint() diff --git a/zetta_utils/geometry/bbox.py b/zetta_utils/geometry/bbox.py index e81a7aa54..ba945b345 100644 --- a/zetta_utils/geometry/bbox.py +++ b/zetta_utils/geometry/bbox.py @@ -4,7 +4,8 @@ from itertools import product from math import floor from typing import Literal, Optional, Sequence, Union, cast - +from neuroglancer.viewer_state import AxisAlignedBoundingBoxAnnotation + import attrs from typeguard import typechecked @@ -63,6 +64,22 @@ def shape(self) -> Vec3D: # pragma: no cover """returns the shape coordinates.""" return self.end - self.start + @staticmethod + def from_ng_bbox( + ng_bbox: AxisAlignedBoundingBoxAnnotation, + base_resolution: Sequence[float] + ) -> BBox3D: + point_a_nm = Vec3D(*ng_bbox.pointA).int() * Vec3D(*base_resolution) + point_b_nm = Vec3D(*ng_bbox.pointB).int() * Vec3D(*base_resolution) + start_coord = [min(point_a_nm[i], point_b_nm[i]) for i in range(3)] + end_coord = [max(point_a_nm[i], point_b_nm[i]) for i in range(3)] + bbox = BBox3D.from_coords( + start_coord=start_coord, + end_coord=end_coord, + resolution=[1, 1, 1] + ) + return bbox + @classmethod def from_slices( cls, diff --git a/zetta_utils/training/datasets/__init__.py b/zetta_utils/training/datasets/__init__.py index 944f497ec..1cad10922 100644 --- a/zetta_utils/training/datasets/__init__.py +++ b/zetta_utils/training/datasets/__init__.py @@ -3,3 +3,4 @@ from .joint_dataset import JointDataset from .layer_dataset import LayerDataset from .sample_indexers import RandomIndexer, VolumetricStridedIndexer +from .collection_dataset import build_collection_dataset diff --git a/zetta_utils/training/datasets/collection_dataset.py b/zetta_utils/training/datasets/collection_dataset.py new file mode 100644 index 000000000..6cb6bd1aa --- /dev/null +++ b/zetta_utils/training/datasets/collection_dataset.py @@ -0,0 +1,75 @@ +from typing import Sequence + +from neuroglancer.viewer_state import AxisAlignedBoundingBoxAnnotation +from typeguard import typechecked + +from zetta_utils import builder, db_annotations +from zetta_utils.geometry.bbox import BBox3D +from zetta_utils.geometry.vec import Vec3D +from zetta_utils.layer.layer_base import Layer +from zetta_utils.layer.layer_set.build import build_layer_set +from zetta_utils.layer.tools_base import DataProcessor +from zetta_utils.layer.volumetric.cloudvol.build import build_cv_layer +from zetta_utils.layer.volumetric.layer import VolumetricLayer +from zetta_utils.training.datasets.joint_dataset import JointDataset +from zetta_utils.training.datasets.layer_dataset import LayerDataset +from zetta_utils.training.datasets.sample_indexers.volumetric_strided_indexer import ( + VolumetricStridedIndexer, +) + + +@builder.register("build_collection_dataset") +@typechecked +def build_collection_dataset( + collection_name: str, + resolution: Sequence[float], + chunk_size: Sequence[int], + chunk_stride: Sequence[int], + base_resolution: Sequence[float], + layer_rename_map: dict[str, str], + shared_read_procs: Sequence[DataProcessor], + per_layer_read_procs: dict[str, Sequence[DataProcessor]] , +) -> JointDataset: + + datasets = {} + annotations = db_annotations.read_annotations(collection_ids=[collection_name]) + layer_group_map: dict[str, dict[str, Layer]]= {} + + for i, annotation in enumerate(annotations.values()): + if annotation.layer_group_id not in layer_group_map: + layer_group = db_annotations.read_layer_group(annotation.layer_group_id) + db_layers = db_annotations.read_layers(layer_ids=layer_group.layers) + layers = {} + for layer in db_layers: + name = layer.name + if name in layer_rename_map: + name = layer_rename_map[name] + read_procs = per_layer_read_procs[name] + layers[name] = build_cv_layer(path=layer.source, read_procs=read_procs) + layer_group_map[annotation.layer_group_id] = layers + else: + layers = layer_group_map[annotation.layer_group_id] + + if isinstance(annotation.ng_annotation, AxisAlignedBoundingBoxAnnotation): + bbox = BBox3D.from_ng_bbox( + ng_bbox=annotation.ng_annotation, + base_resolution=base_resolution + ).snapped([0, 0, 0], resolution, "shrink") + datasets[str(i)] = ( + LayerDataset( + layer=build_layer_set( + layers=layers, + read_procs=shared_read_procs + ), + sample_indexer=VolumetricStridedIndexer( + resolution=resolution, + chunk_size=chunk_size, + stride=chunk_stride, + mode="shrink", + bbox=bbox, + ) + + ) + ) + dset = JointDataset(mode="vertical", datasets=datasets) + return dset diff --git a/zetta_utils/training/lightning/train.py b/zetta_utils/training/lightning/train.py index 5c1f671ae..fa6cee510 100644 --- a/zetta_utils/training/lightning/train.py +++ b/zetta_utils/training/lightning/train.py @@ -134,7 +134,7 @@ def lightning_train( train_args[k] = v else: arg_spec = builder.get_initial_builder_spec(v) - if arg_spec is None: + if arg_spec is None and v is not None: raise RuntimeError( f"No builder spec found for `{k}`. Remote training requires arguments to " "be created using `builder` module." diff --git a/zetta_utils/training/lightning/trainers/default.py b/zetta_utils/training/lightning/trainers/default.py index fe8bf62a8..a80371fc3 100644 --- a/zetta_utils/training/lightning/trainers/default.py +++ b/zetta_utils/training/lightning/trainers/default.py @@ -33,19 +33,19 @@ def trace_and_save_model( args_packed, ): # pragma: no cover # pylint: disable=broad-except, used-before-assignment model, trace_input, filepath, name = args_packed - trace = torch.jit.trace(model, trace_input) - filepath_jit = f"{filepath}.static-{torch.__version__}-{name}.jit" - with fsspec.open(filepath_jit, "wb") as f: - torch.jit.save(trace, f) - try: - filepath_onnx = f"{filepath}.static-{torch.__version__}-{name}.onnx" - with fsspec.open(filepath_onnx, "wb") as f: - filesystem = f.fs - torch.onnx.export(model, trace_input, f, opset_version=ONNX_OPSET_VERSION) - return None - except Exception as e: - filesystem.delete(filepath_onnx) - return type(e).__name__, e.args[0] + # trace = torch.jit.trace(model, trace_input) + # filepath_jit = f"{filepath}.static-{torch.__version__}-{name}.jit" + # with fsspec.open(filepath_jit, "wb") as f: + # torch.jit.save(trace, f) + # try: + # filepath_onnx = f"{filepath}.static-{torch.__version__}-{name}.onnx" + # with fsspec.open(filepath_onnx, "wb") as f: + # filesystem = f.fs + # torch.onnx.export(model, trace_input, f, opset_version=ONNX_OPSET_VERSION) + # return None + # except Exception as e: + # filesystem.delete(filepath_onnx) + # return type(e).__name__, e.args[0] @builder.register("ZettaDefaultTrainer") From c4920ad4e049e90317442510f3d04e2c10137ca5 Mon Sep 17 00:00:00 2001 From: Sergiy Date: Mon, 28 Oct 2024 00:47:05 +0000 Subject: [PATCH 2/5] feat: tags --- .../training/datasets/collection_dataset.py | 35 ++++++++----------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/zetta_utils/training/datasets/collection_dataset.py b/zetta_utils/training/datasets/collection_dataset.py index 6cb6bd1aa..18c448d9b 100644 --- a/zetta_utils/training/datasets/collection_dataset.py +++ b/zetta_utils/training/datasets/collection_dataset.py @@ -28,12 +28,12 @@ def build_collection_dataset( base_resolution: Sequence[float], layer_rename_map: dict[str, str], shared_read_procs: Sequence[DataProcessor], - per_layer_read_procs: dict[str, Sequence[DataProcessor]] , + per_layer_read_procs: dict[str, Sequence[DataProcessor]], + tags: list[str] | None = None, ) -> JointDataset: - datasets = {} - annotations = db_annotations.read_annotations(collection_ids=[collection_name]) - layer_group_map: dict[str, dict[str, Layer]]= {} + annotations = db_annotations.read_annotations(collection_ids=[collection_name], tags=tags) + layer_group_map: dict[str, dict[str, Layer]] = {} for i, annotation in enumerate(annotations.values()): if annotation.layer_group_id not in layer_group_map: @@ -52,24 +52,17 @@ def build_collection_dataset( if isinstance(annotation.ng_annotation, AxisAlignedBoundingBoxAnnotation): bbox = BBox3D.from_ng_bbox( - ng_bbox=annotation.ng_annotation, - base_resolution=base_resolution + ng_bbox=annotation.ng_annotation, base_resolution=base_resolution ).snapped([0, 0, 0], resolution, "shrink") - datasets[str(i)] = ( - LayerDataset( - layer=build_layer_set( - layers=layers, - read_procs=shared_read_procs - ), - sample_indexer=VolumetricStridedIndexer( - resolution=resolution, - chunk_size=chunk_size, - stride=chunk_stride, - mode="shrink", - bbox=bbox, - ) - - ) + datasets[str(i)] = LayerDataset( + layer=build_layer_set(layers=layers, read_procs=shared_read_procs), + sample_indexer=VolumetricStridedIndexer( + resolution=resolution, + chunk_size=chunk_size, + stride=chunk_stride, + mode="shrink", + bbox=bbox, + ), ) dset = JointDataset(mode="vertical", datasets=datasets) return dset From cd587e7bc4ca10a89a4f513cd81304de7dbebf06 Mon Sep 17 00:00:00 2001 From: Sergiy Date: Mon, 28 Oct 2024 01:43:27 +0000 Subject: [PATCH 3/5] chore: minor cleanup --- tests/conftest.py | 2 +- zetta_utils/log.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 311815055..dddcc4b01 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -68,7 +68,7 @@ def firestore_emulator(): """Ensure that the Firestore service is up and responsive.""" client = docker.from_env() # type: ignore - project = "test-project" + project = "test-project-x0" port = "8080" container = client.containers.run( "mtlynch/firestore-emulator:latest", diff --git a/zetta_utils/log.py b/zetta_utils/log.py index ca5aa8812..5578ab9ad 100644 --- a/zetta_utils/log.py +++ b/zetta_utils/log.py @@ -86,6 +86,7 @@ def configure_logger(level=None, third_party_level="ERROR"): for _ in ( "python_jsonschema_objects", "pytorch_lightning", + "lightning_fabric", "urllib3", "urllib3.connectionpool", "google", From 369e022a88f962647d7bb341390a0f954a5eb5bd Mon Sep 17 00:00:00 2001 From: Sergiy Date: Thu, 19 Dec 2024 17:18:37 +0000 Subject: [PATCH 4/5] collection dset updates --- .../training/datasets/collection_dataset.py | 48 ++++++++++++++----- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/zetta_utils/training/datasets/collection_dataset.py b/zetta_utils/training/datasets/collection_dataset.py index 18c448d9b..71f8b822c 100644 --- a/zetta_utils/training/datasets/collection_dataset.py +++ b/zetta_utils/training/datasets/collection_dataset.py @@ -1,5 +1,8 @@ +import json +import os from typing import Sequence +import fsspec from neuroglancer.viewer_state import AxisAlignedBoundingBoxAnnotation from typeguard import typechecked @@ -10,7 +13,6 @@ from zetta_utils.layer.layer_set.build import build_layer_set from zetta_utils.layer.tools_base import DataProcessor from zetta_utils.layer.volumetric.cloudvol.build import build_cv_layer -from zetta_utils.layer.volumetric.layer import VolumetricLayer from zetta_utils.training.datasets.joint_dataset import JointDataset from zetta_utils.training.datasets.layer_dataset import LayerDataset from zetta_utils.training.datasets.sample_indexers.volumetric_strided_indexer import ( @@ -25,19 +27,20 @@ def build_collection_dataset( resolution: Sequence[float], chunk_size: Sequence[int], chunk_stride: Sequence[int], - base_resolution: Sequence[float], layer_rename_map: dict[str, str], shared_read_procs: Sequence[DataProcessor], per_layer_read_procs: dict[str, Sequence[DataProcessor]], tags: list[str] | None = None, ) -> JointDataset: datasets = {} - annotations = db_annotations.read_annotations(collection_ids=[collection_name], tags=tags) + annotations = db_annotations.read_annotations( + collection_ids=[collection_name], tags=tags, union=False + ) layer_group_map: dict[str, dict[str, Layer]] = {} for i, annotation in enumerate(annotations.values()): - if annotation.layer_group_id not in layer_group_map: - layer_group = db_annotations.read_layer_group(annotation.layer_group_id) + if annotation.layer_group not in layer_group_map: + layer_group = db_annotations.read_layer_group(annotation.layer_group) db_layers = db_annotations.read_layers(layer_ids=layer_group.layers) layers = {} for layer in db_layers: @@ -46,18 +49,39 @@ def build_collection_dataset( name = layer_rename_map[name] read_procs = per_layer_read_procs[name] layers[name] = build_cv_layer(path=layer.source, read_procs=read_procs) - layer_group_map[annotation.layer_group_id] = layers + layer_group_map[annotation.layer_group] = layers else: - layers = layer_group_map[annotation.layer_group_id] - + layers = layer_group_map[annotation.layer_group] + z_resolution = resolution[-1] + for layer in layers.values(): + info_path = os.path.join(layer.backend.name.strip("precomputed://"), "info") + with fsspec.open(info_path) as f: + info = json.loads(f.read()) + z_resolutions = set([e["resolution"][-1] for e in info["scales"]]) + if len(z_resolutions) != 1: + raise RuntimeError("Only layers with single z resolution are supported") + z_resolution = list(z_resolutions)[0] + this_resolution = [resolution[0], resolution[1], z_resolution] if isinstance(annotation.ng_annotation, AxisAlignedBoundingBoxAnnotation): - bbox = BBox3D.from_ng_bbox( - ng_bbox=annotation.ng_annotation, base_resolution=base_resolution - ).snapped([0, 0, 0], resolution, "shrink") + ng_bbox = annotation.ng_annotation + point_a_nm = Vec3D(*ng_bbox.pointA).int() + point_b_nm = Vec3D(*ng_bbox.pointB).int() + + start_coord = [ + round((min(point_a_nm[i], point_b_nm[i]) / this_resolution[i])) + * this_resolution[i] + for i in range(3) + ] + end_coord = [ + round(max(point_a_nm[i], point_b_nm[i]) / this_resolution[i]) * this_resolution[i] + for i in range(3) + ] + + bbox = BBox3D.from_coords(start_coord=start_coord, end_coord=end_coord) datasets[str(i)] = LayerDataset( layer=build_layer_set(layers=layers, read_procs=shared_read_procs), sample_indexer=VolumetricStridedIndexer( - resolution=resolution, + resolution=this_resolution, chunk_size=chunk_size, stride=chunk_stride, mode="shrink", From 8cc74e7c6764506ac0bfb9c4c12a7017eceecd56 Mon Sep 17 00:00:00 2001 From: Sergiy Date: Fri, 3 Jan 2025 02:39:39 +0000 Subject: [PATCH 5/5] collection dataset fix --- .../training/datasets/collection_dataset.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/zetta_utils/training/datasets/collection_dataset.py b/zetta_utils/training/datasets/collection_dataset.py index 71f8b822c..058418ba1 100644 --- a/zetta_utils/training/datasets/collection_dataset.py +++ b/zetta_utils/training/datasets/collection_dataset.py @@ -1,5 +1,6 @@ import json import os +from collections import defaultdict from typing import Sequence import fsspec @@ -28,8 +29,8 @@ def build_collection_dataset( chunk_size: Sequence[int], chunk_stride: Sequence[int], layer_rename_map: dict[str, str], - shared_read_procs: Sequence[DataProcessor], - per_layer_read_procs: dict[str, Sequence[DataProcessor]], + per_layer_read_procs: dict[str, Sequence[DataProcessor]] | None = None, + shared_read_procs: Sequence[DataProcessor] = tuple(), tags: list[str] | None = None, ) -> JointDataset: datasets = {} @@ -37,6 +38,10 @@ def build_collection_dataset( collection_ids=[collection_name], tags=tags, union=False ) layer_group_map: dict[str, dict[str, Layer]] = {} + if per_layer_read_procs is None: + per_layer_read_procs_dict = defaultdict(tuple) + else: + per_layer_read_procs_dict = per_layer_read_procs for i, annotation in enumerate(annotations.values()): if annotation.layer_group not in layer_group_map: @@ -47,17 +52,18 @@ def build_collection_dataset( name = layer.name if name in layer_rename_map: name = layer_rename_map[name] - read_procs = per_layer_read_procs[name] + read_procs = per_layer_read_procs_dict[name] layers[name] = build_cv_layer(path=layer.source, read_procs=read_procs) layer_group_map[annotation.layer_group] = layers else: layers = layer_group_map[annotation.layer_group] + z_resolution = resolution[-1] for layer in layers.values(): info_path = os.path.join(layer.backend.name.strip("precomputed://"), "info") with fsspec.open(info_path) as f: info = json.loads(f.read()) - z_resolutions = set([e["resolution"][-1] for e in info["scales"]]) + z_resolutions = {e["resolution"][-1] for e in info["scales"]} if len(z_resolutions) != 1: raise RuntimeError("Only layers with single z resolution are supported") z_resolution = list(z_resolutions)[0] @@ -68,15 +74,13 @@ def build_collection_dataset( point_b_nm = Vec3D(*ng_bbox.pointB).int() start_coord = [ - round((min(point_a_nm[i], point_b_nm[i]) / this_resolution[i])) - * this_resolution[i] + int((min(point_a_nm[i], point_b_nm[i]) / this_resolution[i])) * this_resolution[i] for i in range(3) ] end_coord = [ - round(max(point_a_nm[i], point_b_nm[i]) / this_resolution[i]) * this_resolution[i] + int(max(point_a_nm[i], point_b_nm[i]) / this_resolution[i]) * this_resolution[i] for i in range(3) ] - bbox = BBox3D.from_coords(start_coord=start_coord, end_coord=end_coord) datasets[str(i)] = LayerDataset( layer=build_layer_set(layers=layers, read_procs=shared_read_procs),