Skip to content

Commit

Permalink
Make it possible to display point annotations together with segments.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 660295215
  • Loading branch information
mjanusz authored and copybara-github committed Aug 8, 2024
1 parent fc0a12a commit 71b37a9
Showing 1 changed file with 128 additions and 28 deletions.
156 changes: 128 additions & 28 deletions ffn/utils/proofreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,25 @@
import copy
import itertools
import threading
from typing import Iterable, Optional
from typing import Iterable, Sequence

import networkx as nx
import neuroglancer


Point = tuple[int, int, int]
Point = tuple[int, int, int] # xyz
PointList = list[Point]

# One or more segments, or a mapping from layer name to segments.
ObjectItem = int | Iterable[int] | dict[str, Iterable[int]]

POINT_SHADER = """
void main() {
setColor(defaultColor());
setPointMarkerSize(10.);
}
"""


class Base:
"""Base class for proofreading workflows.
Expand All @@ -42,9 +52,21 @@ class Base:
def __init__(
self,
num_to_prefetch: int = 10,
locations: Optional[Iterable[Point]] = None,
objects: Optional[Iterable[ObjectItem]] = None,
locations: Sequence[Point] | None = None,
objects: Sequence[ObjectItem] | None = None,
points: Sequence[PointList] | None = None,
point_layer: str = 'points',
):
"""Constructor.
Args:
num_to_prefetch: number of object batches to prefetch
locations: XYZ locations to move the 3d cursor to
objects: segments to display
points: point annotations to display
point_layer: name of the annotation layer that will be automatically added
to the viewer if 'points' are provided
"""
self.viewer = neuroglancer.Viewer()
self.num_to_prefetch = num_to_prefetch

Expand All @@ -67,6 +89,14 @@ def __init__(

self.set_init_state()

self.points = points
self.point_layer = point_layer
if self.points is not None:
with self.viewer.txn() as s:
s.layers[point_layer] = neuroglancer.LocalAnnotationLayer(
s.dimensions, shader=POINT_SHADER
)

def _set_todo(self, objects: Iterable[ObjectItem]):
for o in objects:
if isinstance(o, collections.abc.Mapping):
Expand All @@ -84,9 +114,15 @@ def update_msg(self, msg):
with self.viewer.config_state.txn() as s:
s.status_messages['status'] = msg

def update_segments(self, segments, loc=None, layer='seg'):
s = copy.deepcopy(self.viewer.state)
l = s.layers[layer]
def update_segments(
self,
state: neuroglancer.viewer_state.ViewerState,
segments,
layer: str = 'seg',
):
"""Updates the segments selected in a layer."""

l = state.layers[layer]
l.segments = segments

if not self.apply_equivs:
Expand All @@ -97,10 +133,15 @@ def update_segments(self, segments, loc=None, layer='seg'):
a = [aa[layer] for aa in a]
l.equivalences.union(*a)

if loc is not None:
s.position = loc

self.viewer.set_state(s)
def update_points(
self, state: neuroglancer.viewer_state.ViewerState, points: PointList
):
if self.points is None:
return
l = state.layers[self.point_layer]
l.annotations = [
neuroglancer.PointAnnotation(id=repr(xyz), point=xyz) for xyz in points
]

def toggle_equiv(self):
self.apply_equivs = not self.apply_equivs
Expand All @@ -126,9 +167,11 @@ def prev_batch(self):
self.index = max(0, self.index)
self.update_batch()

def list_segments(self, index=None, layer='seg'):
if index is None:
index = self.index
def list_segments(
self, index: int | None = None, layer: str = 'seg'
) -> list[int]:
"""Returns the list of segments to display."""
index = self.index if index is None else index
return list(
set(
itertools.chain(
Expand All @@ -137,23 +180,44 @@ def list_segments(self, index=None, layer='seg'):
)
)

def list_points(self, index: int | None = None) -> PointList:
"""Returns the list of points to display."""
if self.points is None:
return []
index = self.index if index is None else index
return list(
itertools.chain(*[x for x in self.points[index : index + self.batch]])
)

def custom_msg(self):
return ''

def update_batch(self, update=True):
def update_batch(self, index: int | None = None):
"""Refreshes the currently displayed batch."""

if index is None:
index = self.index

s = copy.deepcopy(self.viewer.state)
if self.batch == 1 and self.locations is not None:
loc = self.locations[self.index]
else:
loc = None
s.position = self.locations[index]

for layer in self.managed_layers:
self.update_segments(self.list_segments(layer=layer), loc, layer=layer)
self.update_segments(
s, self.list_segments(index=index, layer=layer), layer=layer
)

self.update_points(s, self.list_points(index=index))
self.viewer.set_state(s)

self.update_msg(
'index:%d/%d batch:%d %s'
% (self.index, len(self.todo), self.batch, self.custom_msg())
)

def prefetch(self):
"""Prefetches the desired number of additional states."""

prefetch_states = []
for i in range(self.num_to_prefetch):
idx = self.index + (i + 1) * self.batch
Expand Down Expand Up @@ -184,7 +248,15 @@ class ObjectReview(Base):
batches.
"""

def __init__(self, objects, bad, num_to_prefetch=10, locations=None):
def __init__(
self,
objects: Sequence[ObjectItem],
bad,
num_to_prefetch: int = 10,
locations: Sequence[Point] | None = None,
points: Sequence[PointList] | None = None,
**kwargs
):
"""Constructor.
Args:
Expand All @@ -196,9 +268,15 @@ def __init__(self, objects, bad, num_to_prefetch=10, locations=None):
locations: iterable of xyz tuples of length len(objects). If specified,
the cursor will be automaticaly moved to the location corresponding to
the current object if batch == 1.
points: point annotations to display
**kwargs: passed to 'Base'
"""
super().__init__(
num_to_prefetch=num_to_prefetch, locations=locations, objects=objects
num_to_prefetch=num_to_prefetch,
locations=locations,
objects=objects,
points=points,
**kwargs
)
self.bad = bad

Expand Down Expand Up @@ -253,20 +331,29 @@ class ObjectClassification(Base):

def __init__(
self,
objects: Iterable[ObjectItem],
key_to_class,
objects: Sequence[ObjectItem],
key_to_class: dict[str, str],
num_to_prefetch: int = 10,
locations=None,
locations: Sequence[Point] | None = None,
points: Sequence[PointList] | None = None,
**kwargs
):
"""Constructor.
Args:
objects: iterable of object IDs
key_to_class: dict mapping keys to class labels
num_to_prefetch: number of `objects` to prefetch
locations: XYZ locations to move the 3d cursor to
points: point annotations to display
**kwargs: passed to Base
"""
super().__init__(
num_to_prefetch=num_to_prefetch, locations=locations, objects=objects
num_to_prefetch=num_to_prefetch,
locations=locations,
objects=objects,
points=points,
**kwargs
)

self.results = defaultdict(set) # class -> ids
Expand Down Expand Up @@ -328,9 +415,20 @@ class GraphUpdater(Base):
"""

def __init__(
self, graph, objects: Iterable[ObjectItem], bad, num_to_prefetch: int = 0
self,
graph: nx.Graph,
objects: Sequence[ObjectItem],
bad,
num_to_prefetch: int = 0,
points: Sequence[PointList] | None = None,
**kwargs
):
super().__init__(objects=objects, num_to_prefetch=num_to_prefetch)
super().__init__(
objects=objects,
num_to_prefetch=num_to_prefetch,
points=points,
**kwargs
)
self.graph = graph
self.split_objects = []
self.split_path = []
Expand Down Expand Up @@ -391,7 +489,9 @@ def add_ccs(self):
if sid in self.graph:
curr |= set(nx.node_connected_component(self.graph, sid))

self.update_segments(curr)
s = copy.deepcopy(self.viewer.state)
self.update_segments(s, curr)
self.viewer.set_state(s)
self.sem.release()

def accept_split(self):
Expand Down

0 comments on commit 71b37a9

Please sign in to comment.