Skip to content

Commit

Permalink
feat: selected segments for annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
supersergiy committed Feb 24, 2025
1 parent 7b331ca commit b02f3ac
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 16 deletions.
15 changes: 11 additions & 4 deletions tests/unit/db_annotations/test_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def test_add_update_delete_annotation(
assert _annotation.collection == collection_id
assert _annotation.layer_group == layer_group_id
assert len(cast(list, _annotation.tags)) == 2
assert len(_annotation.selected_segments) == 0

_annotation_d = _annotation.to_dict()
assert _annotation_d["collection"] == _annotation.collection
Expand All @@ -72,10 +73,12 @@ def test_add_update_delete_annotation(
layer_group_id=layer_group_id,
comment="this is also a test",
tags=["tag2"],
selected_segments={"yo": [1, 2, 3]},
)
_annotation = annotation.read_annotation(_id)
assert len(cast(list, _annotation.tags)) == 1
assert cast(list, _annotation.tags)[0] == "tag2"
assert _annotation.selected_segments == {"yo": [1, 2, 3]}

annotation.delete_annotation(_id)
with pytest.raises(KeyError):
Expand Down Expand Up @@ -122,24 +125,28 @@ def test_add_update_annotations(
layer_group_id=layer_group_id,
comment="this is a test",
tags=["tag0", "tag1"],
selected_segments=[{"yo": [0, -1]} for _ in annotations_raw],
)

_annotations = annotation.read_annotations(annotation_ids=_ids)
assert _annotations[0].ng_annotation.type == "line"
assert _annotations[1].ng_annotation.type == "point"
assert len(cast(list, _annotations[0].tags)) == 2
assert len(cast(list, _annotations[1].tags)) == 2
assert len(_annotations[0].tags) == 2
assert len(_annotations[1].tags) == 2
assert _annotations[0].selected_segments == {"yo": [0, -1]}

annotation.update_annotations(
_ids,
collection_id=collection_id,
layer_group_id=layer_group_id,
comment="this is also a test",
tags=["tag2"],
selected_segments=[{} for _ in _ids],
)
_annotations = annotation.read_annotations(annotation_ids=_ids)
assert len(cast(list, _annotations[0].tags)) == 1
assert len(cast(list, _annotations[1].tags)) == 1
assert len(_annotations[0].tags) == 1
assert len(_annotations[1].tags) == 1
assert _annotations[0].selected_segments == {}


def test_read_delete_annotations(firestore_emulator, annotations_db, collection_and_layer_group):
Expand Down
46 changes: 34 additions & 12 deletions zetta_utils/db_annotations/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
import time
import uuid
from typing import Any, Union, cast, overload
from typing import Any, Mapping, Sequence, Union, cast, overload

import attrs
from neuroglancer.viewer_state import (
Expand Down Expand Up @@ -39,6 +39,7 @@
AnnotationKeys.POINT_B.value,
AnnotationKeys.CENTER.value,
AnnotationKeys.RADII.value,
"selected_segments",
)

DB_NAME = "annotations"
Expand All @@ -57,17 +58,19 @@ class AnnotationDBEntry:
ng_annotation: NgAnnotation
comment: str
tags: list[str]
selected_segments: Mapping[str, Sequence[int]]
created_at: float | None = None
modified_at: float | None = None

@staticmethod
def from_dict(annotation_id: str, raw_dict: dict[str, Any]) -> AnnotationDBEntry:
raw_with_defaults: dict[str, Any] = {"tags": [], **raw_dict}
raw_with_defaults: dict[str, Any] = {"tags": [], "selected_segments": [], **raw_dict}
shape_dict = copy.deepcopy(raw_with_defaults)
shape_dict.pop("layer_group", None)
shape_dict.pop("collection", None)
shape_dict.pop("comment", None)
shape_dict.pop("tags", None)
shape_dict.pop("selected_segments", None)
ng_annotation = parse_ng_annotations([shape_dict])[0]

result = AnnotationDBEntry(
Expand All @@ -76,6 +79,7 @@ def from_dict(annotation_id: str, raw_dict: dict[str, Any]) -> AnnotationDBEntry
collection=raw_with_defaults.get("collection", ""),
comment=raw_with_defaults.get("comment", ""),
tags=raw_with_defaults["tags"],
selected_segments=raw_with_defaults["selected_segments"],
ng_annotation=ng_annotation,
created_at=raw_with_defaults.get("created_at"),
modified_at=raw_with_defaults.get("modified_at"),
Expand Down Expand Up @@ -144,6 +148,7 @@ def _make_annotation_row(
collection_id: str,
layer_group_id: str,
comment: str | None,
selected_segments: Mapping[str, Sequence[int]],
tags: list[str] | None,
) -> tuple[str, dict]:
row = annotation.to_json()
Expand All @@ -156,6 +161,7 @@ def _make_annotation_row(
row["created_at"] = time.time()
if tags:
row["tags"] = list(set(tags))
row["selected_segments"] = selected_segments
return annotation_id, row


Expand All @@ -165,13 +171,15 @@ def add_annotation(
collection_id: str,
layer_group_id: str,
comment: str | None = None,
selected_segments: Mapping[str, Sequence[int]] | None = None,
tags: list[str] | None = None,
) -> str:
annotation_id, row = _make_annotation_row(
annotation,
collection_id=collection_id,
layer_group_id=layer_group_id,
comment=comment,
selected_segments=selected_segments if selected_segments else {},
tags=tags,
)
col_keys = INDEXED_COLS + NON_INDEXED_COLS
Expand All @@ -185,16 +193,21 @@ def add_annotations(
collection_id: str,
layer_group_id: str,
comment: str | None = None,
selected_segments: Sequence[Mapping[str, Sequence[int]]] | None = None,
tags: list[str] | None = None,
) -> list[str]:
rows = []
annotation_ids = []
for annotation in annotations:
for i, annotation in enumerate(annotations):
segments: Mapping[str, Sequence[int]] = {}
if selected_segments is not None:
segments = selected_segments[i]
annotation_id, row = _make_annotation_row(
annotation,
collection_id=collection_id,
layer_group_id=layer_group_id,
comment=comment,
selected_segments=segments,
tags=tags,
)
rows.append(row)
Expand All @@ -209,19 +222,23 @@ def update_annotation(
*,
collection_id: str | None = None,
layer_group_id: str | None = None,
selected_segments: Mapping[str, Sequence[int]] | None = None,
comment: str | None = None,
tags: list[str] | None = None,
):
col_keys = INDEXED_COLS + NON_INDEXED_COLS
row: DBRowDataT = {"modified_at": time.time()}
if collection_id:
if collection_id is not None:
row["collection"] = collection_id
if layer_group_id:
if layer_group_id is not None:
row["layer_group"] = layer_group_id
if comment:
if comment is not None:
row["comment"] = comment
if tags:
if tags is not None:
row["tags"] = list(set(tags))

if selected_segments is not None:
row["selected_segments"] = {k: list(v) for k, v in selected_segments.items()}
ANNOTATIONS_DB[(annotation_id, col_keys)] = row


Expand All @@ -231,20 +248,23 @@ def update_annotations(
collection_id: str | None = None,
layer_group_id: str | None = None,
comment: str | None = None,
selected_segments: Sequence[Mapping[str, Sequence[int]]] | None = None,
tags: list[str] | None = None,
):
col_keys = INDEXED_COLS + NON_INDEXED_COLS
rows = []
for _ in range(len(annotation_ids)):
for i in range(len(annotation_ids)):
row: DBRowDataT = {"modified_at": time.time()}
if collection_id:
if collection_id is not None:
row["collection"] = collection_id
if layer_group_id:
if layer_group_id is not None:
row["layer_group"] = layer_group_id
if comment:
if comment is not None:
row["comment"] = comment
if tags:
if tags is not None:
row["tags"] = list(set(tags))
if selected_segments is not None:
row["selected_segments"] = {k: list(v) for k, v in selected_segments[i].items()}
rows.append(row)
ANNOTATIONS_DB[(annotation_ids, col_keys)] = rows

Expand Down Expand Up @@ -282,6 +302,7 @@ def add_bbox_annotation(
collection_id: str,
layer_group_id: str,
comment: str | None = None,
selected_segments: Mapping[str, Sequence[int]] | None = None,
tags: list[str] | None = None,
) -> str: # pragma: no cover # delegation
return add_annotation(
Expand All @@ -291,6 +312,7 @@ def add_bbox_annotation(
),
collection_id=collection_id,
layer_group_id=layer_group_id,
selected_segments=selected_segments,
comment=comment,
tags=tags,
)

0 comments on commit b02f3ac

Please sign in to comment.