diff --git a/tests/unit/db_annotations/test_annotation.py b/tests/unit/db_annotations/test_annotation.py index 7560b17e5..9ae7ce957 100644 --- a/tests/unit/db_annotations/test_annotation.py +++ b/tests/unit/db_annotations/test_annotation.py @@ -62,6 +62,7 @@ def test_add_update_delete_annotation( assert _annotation.collection == collection_id assert _annotation.layer_group == layer_group_id assert len(cast(list, _annotation.tags)) == 2 + assert len(_annotation.selected_segments) == 0 _annotation_d = _annotation.to_dict() assert _annotation_d["collection"] == _annotation.collection @@ -72,10 +73,12 @@ def test_add_update_delete_annotation( layer_group_id=layer_group_id, comment="this is also a test", tags=["tag2"], + selected_segments={"yo": [1, 2, 3]}, ) _annotation = annotation.read_annotation(_id) assert len(cast(list, _annotation.tags)) == 1 assert cast(list, _annotation.tags)[0] == "tag2" + assert _annotation.selected_segments == {"yo": [1, 2, 3]} annotation.delete_annotation(_id) with pytest.raises(KeyError): @@ -122,13 +125,15 @@ def test_add_update_annotations( layer_group_id=layer_group_id, comment="this is a test", tags=["tag0", "tag1"], + selected_segments=[{"yo": [0, -1]} for _ in annotations_raw], ) _annotations = annotation.read_annotations(annotation_ids=_ids) assert _annotations[0].ng_annotation.type == "line" assert _annotations[1].ng_annotation.type == "point" - assert len(cast(list, _annotations[0].tags)) == 2 - assert len(cast(list, _annotations[1].tags)) == 2 + assert len(_annotations[0].tags) == 2 + assert len(_annotations[1].tags) == 2 + assert _annotations[0].selected_segments == {"yo": [0, -1]} annotation.update_annotations( _ids, @@ -136,10 +141,12 @@ def test_add_update_annotations( layer_group_id=layer_group_id, comment="this is also a test", tags=["tag2"], + selected_segments=[{} for _ in _ids], ) _annotations = annotation.read_annotations(annotation_ids=_ids) - assert len(cast(list, _annotations[0].tags)) == 1 - assert len(cast(list, _annotations[1].tags)) == 1 + assert len(_annotations[0].tags) == 1 + assert len(_annotations[1].tags) == 1 + assert _annotations[0].selected_segments == {} def test_read_delete_annotations(firestore_emulator, annotations_db, collection_and_layer_group): diff --git a/zetta_utils/db_annotations/annotation.py b/zetta_utils/db_annotations/annotation.py index 7d56d2194..e941e8ce9 100644 --- a/zetta_utils/db_annotations/annotation.py +++ b/zetta_utils/db_annotations/annotation.py @@ -6,7 +6,7 @@ import json import time import uuid -from typing import Any, Union, cast, overload +from typing import Any, Mapping, Sequence, Union, cast, overload import attrs from neuroglancer.viewer_state import ( @@ -39,6 +39,7 @@ AnnotationKeys.POINT_B.value, AnnotationKeys.CENTER.value, AnnotationKeys.RADII.value, + "selected_segments", ) DB_NAME = "annotations" @@ -57,17 +58,19 @@ class AnnotationDBEntry: ng_annotation: NgAnnotation comment: str tags: list[str] + selected_segments: Mapping[str, Sequence[int]] created_at: float | None = None modified_at: float | None = None @staticmethod def from_dict(annotation_id: str, raw_dict: dict[str, Any]) -> AnnotationDBEntry: - raw_with_defaults: dict[str, Any] = {"tags": [], **raw_dict} + raw_with_defaults: dict[str, Any] = {"tags": [], "selected_segments": [], **raw_dict} shape_dict = copy.deepcopy(raw_with_defaults) shape_dict.pop("layer_group", None) shape_dict.pop("collection", None) shape_dict.pop("comment", None) shape_dict.pop("tags", None) + shape_dict.pop("selected_segments", None) ng_annotation = parse_ng_annotations([shape_dict])[0] result = AnnotationDBEntry( @@ -76,6 +79,7 @@ def from_dict(annotation_id: str, raw_dict: dict[str, Any]) -> AnnotationDBEntry collection=raw_with_defaults.get("collection", ""), comment=raw_with_defaults.get("comment", ""), tags=raw_with_defaults["tags"], + selected_segments=raw_with_defaults["selected_segments"], ng_annotation=ng_annotation, created_at=raw_with_defaults.get("created_at"), modified_at=raw_with_defaults.get("modified_at"), @@ -144,6 +148,7 @@ def _make_annotation_row( collection_id: str, layer_group_id: str, comment: str | None, + selected_segments: Mapping[str, Sequence[int]], tags: list[str] | None, ) -> tuple[str, dict]: row = annotation.to_json() @@ -156,6 +161,7 @@ def _make_annotation_row( row["created_at"] = time.time() if tags: row["tags"] = list(set(tags)) + row["selected_segments"] = selected_segments return annotation_id, row @@ -165,6 +171,7 @@ def add_annotation( collection_id: str, layer_group_id: str, comment: str | None = None, + selected_segments: Mapping[str, Sequence[int]] | None = None, tags: list[str] | None = None, ) -> str: annotation_id, row = _make_annotation_row( @@ -172,6 +179,7 @@ def add_annotation( collection_id=collection_id, layer_group_id=layer_group_id, comment=comment, + selected_segments=selected_segments if selected_segments else {}, tags=tags, ) col_keys = INDEXED_COLS + NON_INDEXED_COLS @@ -185,16 +193,21 @@ def add_annotations( collection_id: str, layer_group_id: str, comment: str | None = None, + selected_segments: Sequence[Mapping[str, Sequence[int]]] | None = None, tags: list[str] | None = None, ) -> list[str]: rows = [] annotation_ids = [] - for annotation in annotations: + for i, annotation in enumerate(annotations): + segments: Mapping[str, Sequence[int]] = {} + if selected_segments is not None: + segments = selected_segments[i] annotation_id, row = _make_annotation_row( annotation, collection_id=collection_id, layer_group_id=layer_group_id, comment=comment, + selected_segments=segments, tags=tags, ) rows.append(row) @@ -209,19 +222,23 @@ def update_annotation( *, collection_id: str | None = None, layer_group_id: str | None = None, + selected_segments: Mapping[str, Sequence[int]] | None = None, comment: str | None = None, tags: list[str] | None = None, ): col_keys = INDEXED_COLS + NON_INDEXED_COLS row: DBRowDataT = {"modified_at": time.time()} - if collection_id: + if collection_id is not None: row["collection"] = collection_id - if layer_group_id: + if layer_group_id is not None: row["layer_group"] = layer_group_id - if comment: + if comment is not None: row["comment"] = comment - if tags: + if tags is not None: row["tags"] = list(set(tags)) + + if selected_segments is not None: + row["selected_segments"] = {k: list(v) for k, v in selected_segments.items()} ANNOTATIONS_DB[(annotation_id, col_keys)] = row @@ -231,20 +248,23 @@ def update_annotations( collection_id: str | None = None, layer_group_id: str | None = None, comment: str | None = None, + selected_segments: Sequence[Mapping[str, Sequence[int]]] | None = None, tags: list[str] | None = None, ): col_keys = INDEXED_COLS + NON_INDEXED_COLS rows = [] - for _ in range(len(annotation_ids)): + for i in range(len(annotation_ids)): row: DBRowDataT = {"modified_at": time.time()} - if collection_id: + if collection_id is not None: row["collection"] = collection_id - if layer_group_id: + if layer_group_id is not None: row["layer_group"] = layer_group_id - if comment: + if comment is not None: row["comment"] = comment - if tags: + if tags is not None: row["tags"] = list(set(tags)) + if selected_segments is not None: + row["selected_segments"] = {k: list(v) for k, v in selected_segments[i].items()} rows.append(row) ANNOTATIONS_DB[(annotation_ids, col_keys)] = rows @@ -282,6 +302,7 @@ def add_bbox_annotation( collection_id: str, layer_group_id: str, comment: str | None = None, + selected_segments: Mapping[str, Sequence[int]] | None = None, tags: list[str] | None = None, ) -> str: # pragma: no cover # delegation return add_annotation( @@ -291,6 +312,7 @@ def add_bbox_annotation( ), collection_id=collection_id, layer_group_id=layer_group_id, + selected_segments=selected_segments, comment=comment, tags=tags, )