Skip to content

Commit

Permalink
feat: deterministic annotatino ids
Browse files Browse the repository at this point in the history
  • Loading branch information
supersergiy committed Feb 22, 2025
1 parent 88324e5 commit bfabc04
Showing 1 changed file with 60 additions and 16 deletions.
76 changes: 60 additions & 16 deletions zetta_utils/db_annotations/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import copy
import json
import time
import uuid
from typing import Any, Union, cast, overload
Expand All @@ -14,7 +15,10 @@
LineAnnotation,
PointAnnotation,
)
from typeguard import typechecked

from zetta_utils import builder
from zetta_utils.geometry.bbox import BBox3D
from zetta_utils.layer.db_layer import DBRowDataT
from zetta_utils.layer.db_layer.firestore import build_firestore_layer
from zetta_utils.parsing.ngl_state import AnnotationKeys
Expand Down Expand Up @@ -135,23 +139,41 @@ def read_annotations(
return {k: AnnotationDBEntry.from_dict(k, cast(dict, v)) for k, v in result_raw.items()}


def add_annotation(
def _make_annotation_row(
annotation: NgAnnotation,
*,
collection_id: str,
layer_group_id: str,
comment: str | None = None,
tags: list[str] | None = None,
) -> str:
comment: str | None,
tags: list[str] | None,
) -> tuple[str, dict]:
row = annotation.to_json()
row["collection"] = collection_id
row["layer_group"] = layer_group_id
row["comment"] = comment
row["created_at"] = time.time()

row_str = json.dumps(row, sort_keys=True)
annotation_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, row_str))
row["created_at"] = time.time()
if tags:
row["tags"] = list(set(tags))
annotation_id = str(uuid.uuid4())
return annotation_id, row


def add_annotation(
annotation: NgAnnotation,
*,
collection_id: str,
layer_group_id: str,
comment: str | None = None,
tags: list[str] | None = None,
) -> str:
annotation_id, row = _make_annotation_row(
annotation,
collection_id=collection_id,
layer_group_id=layer_group_id,
comment=comment,
tags=tags,
)
col_keys = INDEXED_COLS + NON_INDEXED_COLS
ANNOTATIONS_DB[(annotation_id, col_keys)] = row
return annotation_id
Expand All @@ -167,16 +189,16 @@ def add_annotations(
) -> list[str]:
rows = []
annotation_ids = []
for ann in annotations:
row = ann.to_json()
row["collection"] = collection_id
row["layer_group"] = layer_group_id
row["comment"] = comment
row["created_at"] = time.time()
if tags:
row["tags"] = list(set(tags))
for annotation in annotations:
annotation_id, row = _make_annotation_row(
annotation,
collection_id=collection_id,
layer_group_id=layer_group_id,
comment=comment,
tags=tags,
)
rows.append(row)
annotation_ids.append(str(uuid.uuid4()))
annotation_ids.append(annotation_id)
col_keys = INDEXED_COLS + NON_INDEXED_COLS
ANNOTATIONS_DB[(annotation_ids, col_keys)] = rows
return annotation_ids
Expand Down Expand Up @@ -250,3 +272,25 @@ def parse_ng_annotations(annotations_raw: list[dict]) -> list[NgAnnotation]:
else:
annotations.append(AxisAlignedBoundingBoxAnnotation(**ann))
return annotations


@typechecked
@builder.register("add_bbox_annotation")
def add_bbox_annotation(
bbox: BBox3D,
*,
collection_id: str,
layer_group_id: str,
comment: str | None = None,
tags: list[str] | None = None,
) -> str: # pragma: no cover # delegation
return add_annotation(
AxisAlignedBoundingBoxAnnotation(
point_a=bbox.start,
point_b=bbox.end,
),
collection_id=collection_id,
layer_group_id=layer_group_id,
comment=comment,
tags=tags,
)

0 comments on commit bfabc04

Please sign in to comment.