diff --git a/tests/conftest.py b/tests/conftest.py
index 311815055..dddcc4b01 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -68,7 +68,7 @@ def firestore_emulator():
     """Ensure that the Firestore service is up and responsive."""
     client = docker.from_env()  # type: ignore
-    project = "test-project"
+    project = "test-project-x0"
     port = "8080"
     container = client.containers.run(
         "mtlynch/firestore-emulator:latest",
diff --git a/zetta_utils/db_annotations/operations/copy_annotated.py b/zetta_utils/db_annotations/operations/copy_annotated.py
new file mode 100644
index 000000000..356c2e90f
--- /dev/null
+++ b/zetta_utils/db_annotations/operations/copy_annotated.py
@@ -0,0 +1,19 @@
+from typing import Sequence
+
+from zetta_utils import builder
+from zetta_utils.db_annotations.annotation import read_annotations
+from zetta_utils.layer.volumetric.layer import VolumetricLayer
+
+
+@builder.register("copy_annotated_data")
+def CopyAnnotatedFlow(
+    src: VolumetricLayer,
+    dst: VolumetricLayer,
+    collection_name: str,
+    layer_group_name: str,
+    resolution: Sequence[float],
+):
+    annotations = read_annotations(
+        collection_ids=[collection_name], layer_group_ids=[layer_group_name]
+    )
+    breakpoint()
diff --git a/zetta_utils/geometry/bbox.py b/zetta_utils/geometry/bbox.py
index e81a7aa54..ba945b345 100644
--- a/zetta_utils/geometry/bbox.py
+++ b/zetta_utils/geometry/bbox.py
@@ -4,7 +4,8 @@
 from itertools import product
 from math import floor
 from typing import Literal, Optional, Sequence, Union, cast
-
+from neuroglancer.viewer_state import AxisAlignedBoundingBoxAnnotation
+
 import attrs
 from typeguard import typechecked
 
@@ -63,6 +64,22 @@ def shape(self) -> Vec3D:  # pragma: no cover
         """returns the shape coordinates."""
         return self.end - self.start
 
+    @staticmethod
+    def from_ng_bbox(
+        ng_bbox: AxisAlignedBoundingBoxAnnotation,
+        base_resolution: Sequence[float]
+    ) -> BBox3D:
+        point_a_nm = Vec3D(*ng_bbox.pointA).int() * Vec3D(*base_resolution)
+        point_b_nm = Vec3D(*ng_bbox.pointB).int() * Vec3D(*base_resolution)
+        start_coord = [min(point_a_nm[i], point_b_nm[i]) for i in range(3)]
+        end_coord = [max(point_a_nm[i], point_b_nm[i]) for i in range(3)]
+        bbox = BBox3D.from_coords(
+            start_coord=start_coord,
+            end_coord=end_coord,
+            resolution=[1, 1, 1]
+        )
+        return bbox
+
     @classmethod
     def from_slices(
         cls,
diff --git a/zetta_utils/log.py b/zetta_utils/log.py
index ca5aa8812..5578ab9ad 100644
--- a/zetta_utils/log.py
+++ b/zetta_utils/log.py
@@ -86,6 +86,7 @@ def configure_logger(level=None, third_party_level="ERROR"):
     for _ in (
         "python_jsonschema_objects",
         "pytorch_lightning",
+        "lightning_fabric",
         "urllib3",
         "urllib3.connectionpool",
         "google",
diff --git a/zetta_utils/training/datasets/__init__.py b/zetta_utils/training/datasets/__init__.py
index 944f497ec..1cad10922 100644
--- a/zetta_utils/training/datasets/__init__.py
+++ b/zetta_utils/training/datasets/__init__.py
@@ -3,3 +3,4 @@
 from .joint_dataset import JointDataset
 from .layer_dataset import LayerDataset
 from .sample_indexers import RandomIndexer, VolumetricStridedIndexer
+from .collection_dataset import build_collection_dataset
diff --git a/zetta_utils/training/datasets/collection_dataset.py b/zetta_utils/training/datasets/collection_dataset.py
new file mode 100644
index 000000000..058418ba1
--- /dev/null
+++ b/zetta_utils/training/datasets/collection_dataset.py
@@ -0,0 +1,96 @@
+import json
+import os
+from collections import defaultdict
+from typing import Sequence
+
+import fsspec
+from neuroglancer.viewer_state import AxisAlignedBoundingBoxAnnotation
+from typeguard import typechecked
+
+from zetta_utils import builder, db_annotations
+from zetta_utils.geometry.bbox import BBox3D
+from zetta_utils.geometry.vec import Vec3D
+from zetta_utils.layer.layer_base import Layer
+from zetta_utils.layer.layer_set.build import build_layer_set
+from zetta_utils.layer.tools_base import DataProcessor
+from zetta_utils.layer.volumetric.cloudvol.build import build_cv_layer
+from zetta_utils.training.datasets.joint_dataset import JointDataset
+from zetta_utils.training.datasets.layer_dataset import LayerDataset
+from zetta_utils.training.datasets.sample_indexers.volumetric_strided_indexer import (
+    VolumetricStridedIndexer,
+)
+
+
+@builder.register("build_collection_dataset")
+@typechecked
+def build_collection_dataset(
+    collection_name: str,
+    resolution: Sequence[float],
+    chunk_size: Sequence[int],
+    chunk_stride: Sequence[int],
+    layer_rename_map: dict[str, str],
+    per_layer_read_procs: dict[str, Sequence[DataProcessor]] | None = None,
+    shared_read_procs: Sequence[DataProcessor] = tuple(),
+    tags: list[str] | None = None,
+) -> JointDataset:
+    datasets = {}
+    annotations = db_annotations.read_annotations(
+        collection_ids=[collection_name], tags=tags, union=False
+    )
+    layer_group_map: dict[str, dict[str, Layer]] = {}
+    if per_layer_read_procs is None:
+        per_layer_read_procs_dict = defaultdict(tuple)
+    else:
+        per_layer_read_procs_dict = per_layer_read_procs
+
+    for i, annotation in enumerate(annotations.values()):
+        if annotation.layer_group not in layer_group_map:
+            layer_group = db_annotations.read_layer_group(annotation.layer_group)
+            db_layers = db_annotations.read_layers(layer_ids=layer_group.layers)
+            layers = {}
+            for layer in db_layers:
+                name = layer.name
+                if name in layer_rename_map:
+                    name = layer_rename_map[name]
+                read_procs = per_layer_read_procs_dict[name]
+                layers[name] = build_cv_layer(path=layer.source, read_procs=read_procs)
+            layer_group_map[annotation.layer_group] = layers
+        else:
+            layers = layer_group_map[annotation.layer_group]
+
+        z_resolution = resolution[-1]
+        for layer in layers.values():
+            info_path = os.path.join(layer.backend.name.strip("precomputed://"), "info")
+            with fsspec.open(info_path) as f:
+                info = json.loads(f.read())
+            z_resolutions = {e["resolution"][-1] for e in info["scales"]}
+            if len(z_resolutions) != 1:
+                raise RuntimeError("Only layers with single z resolution are supported")
+            z_resolution = list(z_resolutions)[0]
+        this_resolution = [resolution[0], resolution[1], z_resolution]
+        if isinstance(annotation.ng_annotation, AxisAlignedBoundingBoxAnnotation):
+            ng_bbox = annotation.ng_annotation
+            point_a_nm = Vec3D(*ng_bbox.pointA).int()
+            point_b_nm = Vec3D(*ng_bbox.pointB).int()
+
+            start_coord = [
+                int((min(point_a_nm[i], point_b_nm[i]) / this_resolution[i])) * this_resolution[i]
+                for i in range(3)
+            ]
+            end_coord = [
+                int(max(point_a_nm[i], point_b_nm[i]) / this_resolution[i]) * this_resolution[i]
+                for i in range(3)
+            ]
+            bbox = BBox3D.from_coords(start_coord=start_coord, end_coord=end_coord)
+            datasets[str(i)] = LayerDataset(
+                layer=build_layer_set(layers=layers, read_procs=shared_read_procs),
+                sample_indexer=VolumetricStridedIndexer(
+                    resolution=this_resolution,
+                    chunk_size=chunk_size,
+                    stride=chunk_stride,
+                    mode="shrink",
+                    bbox=bbox,
+                ),
+            )
+    dset = JointDataset(mode="vertical", datasets=datasets)
+    return dset
diff --git a/zetta_utils/training/lightning/train.py b/zetta_utils/training/lightning/train.py
index 5c1f671ae..fa6cee510 100644
--- a/zetta_utils/training/lightning/train.py
+++ b/zetta_utils/training/lightning/train.py
@@ -134,7 +134,7 @@ def lightning_train(
             train_args[k] = v
         else:
             arg_spec = builder.get_initial_builder_spec(v)
-            if arg_spec is None:
+            if arg_spec is None and v is not None:
                 raise RuntimeError(
                     f"No builder spec found for `{k}`. Remote training requires arguments to "
                     "be created using `builder` module."
diff --git a/zetta_utils/training/lightning/trainers/default.py b/zetta_utils/training/lightning/trainers/default.py
index fe8bf62a8..a80371fc3 100644
--- a/zetta_utils/training/lightning/trainers/default.py
+++ b/zetta_utils/training/lightning/trainers/default.py
@@ -33,19 +33,19 @@ def trace_and_save_model(
     args_packed,
 ):  # pragma: no cover # pylint: disable=broad-except, used-before-assignment
     model, trace_input, filepath, name = args_packed
-    trace = torch.jit.trace(model, trace_input)
-    filepath_jit = f"{filepath}.static-{torch.__version__}-{name}.jit"
-    with fsspec.open(filepath_jit, "wb") as f:
-        torch.jit.save(trace, f)
-    try:
-        filepath_onnx = f"{filepath}.static-{torch.__version__}-{name}.onnx"
-        with fsspec.open(filepath_onnx, "wb") as f:
-            filesystem = f.fs
-            torch.onnx.export(model, trace_input, f, opset_version=ONNX_OPSET_VERSION)
-        return None
-    except Exception as e:
-        filesystem.delete(filepath_onnx)
-        return type(e).__name__, e.args[0]
+    # trace = torch.jit.trace(model, trace_input)
+    # filepath_jit = f"{filepath}.static-{torch.__version__}-{name}.jit"
+    # with fsspec.open(filepath_jit, "wb") as f:
+    #     torch.jit.save(trace, f)
+    # try:
+    #     filepath_onnx = f"{filepath}.static-{torch.__version__}-{name}.onnx"
+    #     with fsspec.open(filepath_onnx, "wb") as f:
+    #         filesystem = f.fs
+    #         torch.onnx.export(model, trace_input, f, opset_version=ONNX_OPSET_VERSION)
+    #     return None
+    # except Exception as e:
+    #     filesystem.delete(filepath_onnx)
+    #     return type(e).__name__, e.args[0]
 
 
 @builder.register("ZettaDefaultTrainer")
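
A minimal usage sketch for the new build_collection_dataset entry point added in this patch. The collection name, chunk geometry, rename map, and tag below are illustrative assumptions, not values from the patch; per the implementation, the XY resolution comes from the resolution argument while the Z resolution is read from each layer's precomputed info.

    from zetta_utils.training.datasets import build_collection_dataset

    # All concrete values here are hypothetical placeholders.
    dset = build_collection_dataset(
        collection_name="my_annotation_collection",  # assumed annotation DB collection id
        resolution=[8, 8, 40],                       # XY in nm; Z is overridden per layer from its info file
        chunk_size=[1024, 1024, 16],
        chunk_stride=[512, 512, 8],
        layer_rename_map={"raw": "data_in"},         # optional renaming of DB layer names
        tags=["train"],                              # assumed tag filter passed to read_annotations
    )
    # Each bounding-box annotation in the collection becomes one LayerDataset,
    # and all of them are joined vertically into a single JointDataset.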