Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Training loop stuff #821

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def firestore_emulator():
"""Ensure that the Firestore service is up and responsive."""

client = docker.from_env() # type: ignore
project = "test-project"
project = "test-project-x0"
port = "8080"
container = client.containers.run(
"mtlynch/firestore-emulator:latest",
Expand Down
19 changes: 19 additions & 0 deletions zetta_utils/db_annotations/operations/copy_annotated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Sequence

from zetta_utils import builder
from zetta_utils.db_annotations.annotation import read_annotations
from zetta_utils.layer.volumetric.layer import VolumetricLayer


@builder.register("copy_annotated_data")
def CopyAnnotatedFlow(
    src: VolumetricLayer,
    dst: VolumetricLayer,
    collection_name: str,
    layer_group_name: str,
    resolution: Sequence[float],
):
    """Copy annotated regions of ``src`` into ``dst``.

    Reads the annotations for the given collection/layer-group from the
    annotation DB; the actual copy of ``src`` → ``dst`` at ``resolution``
    is not implemented yet.

    :param src: source volumetric layer to copy from.
    :param dst: destination volumetric layer to copy into.
    :param collection_name: annotation collection id to read.
    :param layer_group_name: layer group id to filter annotations by.
    :param resolution: resolution at which the copy is intended to happen
        (currently unused).
    :raises NotImplementedError: always — the copy logic is a stub.
    """
    annotations = read_annotations(
        collection_ids=[collection_name], layer_group_ids=[layer_group_name]
    )
    # BUG FIX: a bare `breakpoint()` was committed here, which drops any
    # caller into pdb. Fail loudly and explicitly instead until the copy
    # logic (consuming `annotations`, `src`, `dst`, `resolution`) is written.
    raise NotImplementedError(
        f"copy_annotated_data: copy logic not implemented "
        f"({len(annotations)} annotations read)"
    )
19 changes: 18 additions & 1 deletion zetta_utils/geometry/bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from itertools import product
from math import floor
from typing import Literal, Optional, Sequence, Union, cast

from neuroglancer.viewer_state import AxisAlignedBoundingBoxAnnotation

import attrs
from typeguard import typechecked

Expand Down Expand Up @@ -63,6 +64,22 @@ def shape(self) -> Vec3D: # pragma: no cover
"""returns the shape coordinates."""
return self.end - self.start

@classmethod
def from_ng_bbox(
    cls,
    ng_bbox: AxisAlignedBoundingBoxAnnotation,
    base_resolution: Sequence[float],
) -> BBox3D:
    """Construct a ``BBox3D`` from a neuroglancer bounding-box annotation.

    The annotation's two corner points are truncated to integers and scaled
    by ``base_resolution`` to obtain nanometer coordinates; the result spans
    the elementwise min/max of the two corners, expressed at unit resolution.

    :param ng_bbox: neuroglancer axis-aligned bounding box annotation.
    :param base_resolution: per-axis scale applied to the (integer-truncated)
        annotation points to convert them to nm.
    """
    # NOTE(review): made a classmethod for consistency with the sibling
    # constructors (`from_slices`, `from_coords`); call syntax is unchanged.
    point_a_nm = Vec3D(*ng_bbox.pointA).int() * Vec3D(*base_resolution)
    point_b_nm = Vec3D(*ng_bbox.pointB).int() * Vec3D(*base_resolution)
    return cls.from_coords(
        start_coord=[min(a, b) for a, b in zip(point_a_nm, point_b_nm)],
        end_coord=[max(a, b) for a, b in zip(point_a_nm, point_b_nm)],
        resolution=[1, 1, 1],
    )

@classmethod
def from_slices(
cls,
Expand Down
1 change: 1 addition & 0 deletions zetta_utils/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def configure_logger(level=None, third_party_level="ERROR"):
for _ in (
"python_jsonschema_objects",
"pytorch_lightning",
"lightning_fabric",
"urllib3",
"urllib3.connectionpool",
"google",
Expand Down
1 change: 1 addition & 0 deletions zetta_utils/training/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .joint_dataset import JointDataset
from .layer_dataset import LayerDataset
from .sample_indexers import RandomIndexer, VolumetricStridedIndexer
from .collection_dataset import build_collection_dataset
96 changes: 96 additions & 0 deletions zetta_utils/training/datasets/collection_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import json
import os
from collections import defaultdict
from typing import Sequence

import fsspec
from neuroglancer.viewer_state import AxisAlignedBoundingBoxAnnotation
from typeguard import typechecked

from zetta_utils import builder, db_annotations
from zetta_utils.geometry.bbox import BBox3D
from zetta_utils.geometry.vec import Vec3D
from zetta_utils.layer.layer_base import Layer
from zetta_utils.layer.layer_set.build import build_layer_set
from zetta_utils.layer.tools_base import DataProcessor
from zetta_utils.layer.volumetric.cloudvol.build import build_cv_layer
from zetta_utils.training.datasets.joint_dataset import JointDataset
from zetta_utils.training.datasets.layer_dataset import LayerDataset
from zetta_utils.training.datasets.sample_indexers.volumetric_strided_indexer import (
VolumetricStridedIndexer,
)


@builder.register("build_collection_dataset")
@typechecked
def build_collection_dataset(
    collection_name: str,
    resolution: Sequence[float],
    chunk_size: Sequence[int],
    chunk_stride: Sequence[int],
    layer_rename_map: dict[str, str],
    per_layer_read_procs: dict[str, Sequence[DataProcessor]] | None = None,
    shared_read_procs: Sequence[DataProcessor] = tuple(),
    tags: list[str] | None = None,
) -> JointDataset:
    """Build a vertically-joined training dataset from an annotation collection.

    For every bounding-box annotation in ``collection_name`` (optionally
    filtered by ``tags``), builds a ``LayerDataset`` over the annotation's
    layer group, chunked by a ``VolumetricStridedIndexer``, and joins them
    into a single ``JointDataset``.

    :param collection_name: annotation collection to read.
    :param resolution: desired XY resolution; the Z resolution is taken from
        each layer group's info files.
    :param chunk_size: per-sample chunk size for the strided indexer.
    :param chunk_stride: stride between chunks for the strided indexer.
    :param layer_rename_map: mapping from DB layer names to dataset keys.
    :param per_layer_read_procs: optional read processors keyed by (renamed)
        layer name; missing layers default to no processors.
    :param shared_read_procs: read processors applied to the whole layer set.
    :param tags: optional annotation tag filter (intersection semantics,
        ``union=False``).
    :raises RuntimeError: if any layer has more than one distinct Z resolution
        across its scales.
    """
    datasets = {}
    annotations = db_annotations.read_annotations(
        collection_ids=[collection_name], tags=tags, union=False
    )
    # Cache of layer_group id -> built layers, so each group is built once.
    layer_group_map: dict[str, dict[str, Layer]] = {}
    if per_layer_read_procs is None:
        # defaultdict(tuple) yields an empty proc tuple for unlisted layers.
        per_layer_read_procs_dict: dict[str, Sequence[DataProcessor]] = defaultdict(tuple)
    else:
        per_layer_read_procs_dict = per_layer_read_procs

    for i, annotation in enumerate(annotations.values()):
        if annotation.layer_group not in layer_group_map:
            layer_group = db_annotations.read_layer_group(annotation.layer_group)
            db_layers = db_annotations.read_layers(layer_ids=layer_group.layers)
            layers = {}
            for layer in db_layers:
                name = layer.name
                if name in layer_rename_map:
                    name = layer_rename_map[name]
                read_procs = per_layer_read_procs_dict[name]
                layers[name] = build_cv_layer(path=layer.source, read_procs=read_procs)
            layer_group_map[annotation.layer_group] = layers
        else:
            layers = layer_group_map[annotation.layer_group]

        # Z resolution comes from the layers' info files; fall back to the
        # requested resolution if the group has no layers.
        z_resolution = resolution[-1]
        for layer in layers.values():
            # BUG FIX: `str.strip("precomputed://")` strips any of those
            # *characters* from both ends (it could also eat trailing path
            # chars); `removeprefix` drops only the scheme prefix.
            info_path = os.path.join(
                layer.backend.name.removeprefix("precomputed://"), "info"
            )
            with fsspec.open(info_path) as f:
                info = json.loads(f.read())
            z_resolutions = {e["resolution"][-1] for e in info["scales"]}
            if len(z_resolutions) != 1:
                raise RuntimeError("Only layers with single z resolution are supported")
            z_resolution = next(iter(z_resolutions))
        this_resolution = [resolution[0], resolution[1], z_resolution]

        if isinstance(annotation.ng_annotation, AxisAlignedBoundingBoxAnnotation):
            ng_bbox = annotation.ng_annotation
            point_a_nm = Vec3D(*ng_bbox.pointA).int()
            point_b_nm = Vec3D(*ng_bbox.pointB).int()
            # Snap the bbox corners to the voxel grid of `this_resolution`
            # (floor for start, truncate for end), in nm coordinates.
            start_coord = [
                int(min(point_a_nm[i], point_b_nm[i]) / this_resolution[i]) * this_resolution[i]
                for i in range(3)
            ]
            end_coord = [
                int(max(point_a_nm[i], point_b_nm[i]) / this_resolution[i]) * this_resolution[i]
                for i in range(3)
            ]
            bbox = BBox3D.from_coords(start_coord=start_coord, end_coord=end_coord)
            datasets[str(i)] = LayerDataset(
                layer=build_layer_set(layers=layers, read_procs=shared_read_procs),
                sample_indexer=VolumetricStridedIndexer(
                    resolution=this_resolution,
                    chunk_size=chunk_size,
                    stride=chunk_stride,
                    mode="shrink",
                    bbox=bbox,
                ),
            )
    dset = JointDataset(mode="vertical", datasets=datasets)
    return dset

Check warning on line 96 in zetta_utils/training/datasets/collection_dataset.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/training/datasets/collection_dataset.py#L95-L96

Added lines #L95 - L96 were not covered by tests
2 changes: 1 addition & 1 deletion zetta_utils/training/lightning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def lightning_train(
train_args[k] = v
else:
arg_spec = builder.get_initial_builder_spec(v)
if arg_spec is None:
if arg_spec is None and v is not None:
raise RuntimeError(
f"No builder spec found for `{k}`. Remote training requires arguments to "
"be created using `builder` module."
Expand Down
26 changes: 13 additions & 13 deletions zetta_utils/training/lightning/trainers/default.py
Original file line number Diff line number Diff line change
def trace_and_save_model(
    args_packed,
):  # pragma: no cover
    """No-op placeholder: jit tracing and ONNX export are currently disabled.

    Unpacks ``args_packed`` as ``(model, trace_input, filepath, name)`` —
    preserving the original tuple-arity validation — and returns ``None``
    without exporting anything.

    :param args_packed: 4-tuple of (model, trace input, output path prefix,
        model name).
    :returns: ``None`` (matches the success return of the former export path).
    :raises ValueError: if ``args_packed`` does not contain exactly 4 items.
    """
    # The jit-trace / ONNX-export implementation that used to live here was
    # left as commented-out code; dead code removed. Recover it from VCS
    # history when export is re-enabled.
    _model, _trace_input, _filepath, _name = args_packed
    return None


@builder.register("ZettaDefaultTrainer")
Expand Down
Loading