Skip to content

Commit

Permalink
Support dense segmentation and oriented-point annotations (#21)
Browse files Browse the repository at this point in the history
This adds support for segmentation-mask and oriented-point annotation
types.

![Screenshot 2024-05-15 at 16 42
33](https://github.com/chanzuckerberg/napari-cryoet-data-portal/assets/2608297/dfeee38d-5d58-41c3-ba78-9f87afd589a0)

Segmentation masks are supported by loading the corresponding OME-Zarrs
in a similar way to the tomogram (i.e. according to resolution).

Oriented points are currently read and displayed as regular points. We
can follow up with another PR to augment that with the orientation
information (e.g. with a vectors layer).

As some datasets contain tomograms with lots of different annotations
(e.g. dataset 10000), performance can suffer both while waiting for data to
arrive and when interacting with napari once it has.

This also fixes an alignment issue caused by the highest resolution
tomogram volume having a non-unit scale. I originally separated that
into #31, but there was some overlap with the changes here, so I brought
it into this PR.

Closes #22 and #23
  • Loading branch information
andy-sweet committed May 21, 2024
1 parent 8f5c5c5 commit ce940d2
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 65 deletions.
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ project_urls =
[options]
packages = find:
install_requires =
cmap
cryoet_data_portal ~= 3.0
fsspec[http,s3]
npe2
numpy
napari>=0.4.19
napari_ome_zarr
ndjson
qtpy
Expand Down
95 changes: 51 additions & 44 deletions src/napari_cryoet_data_portal/_open_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from napari_cryoet_data_portal._logging import logger
from napari_cryoet_data_portal._progress_widget import ProgressWidget
from napari_cryoet_data_portal._reader import (
read_annotation,
read_annotation_files,
read_tomogram,
)

Expand All @@ -34,16 +34,6 @@ class Resolution:
indices: Tuple[int, ...]
scale: float

@property
def offset(self) -> float:
"""The offset due to a larger first pixel for lower resolutions.
When visualized in napari, this ensures that the different multi-scale
layers opened separately share the same visual extent in the canvas that
starts at (-0.5, -0.5, -0.5).
"""
return (self.scale - 1) / 2


MULTI_RESOLUTION = Resolution(name="Multi", indices=(0, 1, 2), scale=1)
HIGH_RESOLUTION = Resolution(name="High", indices=(0,), scale=1)
Expand Down Expand Up @@ -138,24 +128,11 @@ def _loadTomogram(
resolution: Resolution,
) -> Generator[FullLayerData, None, None]:
logger.debug("OpenWidget._loadTomogram: %s", tomogram.name)
image_data, image_attrs, _ = read_tomogram(tomogram)
# Skip indexing for multi-resolution to avoid adding any
# unnecessary nodes to the dask compute graph.
if resolution is not MULTI_RESOLUTION:
image_data = image_data[resolution.indices[0]]
# Materialize low resolution immediately on this thread to prevent napari blocking.
if resolution is LOW_RESOLUTION:
image_data = np.asarray(image_data)
# Get scale before resolution scaling for annotations.
image_scale = image_attrs["scale"]
image_attrs["scale"] = tuple(
resolution.scale * s for s in image_scale
)
image_translate = image_attrs.get("translate", (0,) * len(image_attrs["scale"]))
image_attrs["translate"] = tuple(
resolution.offset + t for t in image_translate
)
yield image_data, image_attrs, "image"
image_layer = read_tomogram(tomogram)
# Extract image_scale before the resolution is taken into account,
# so we can use it to align other annotations later.
image_scale = image_layer[1]["scale"]
yield _handle_image_at_resolution(image_layer, resolution)

# Looking up tomogram.tomogram_voxel_spacing.annotations triggers a query
# using the client from where the tomogram was found.
Expand All @@ -167,21 +144,12 @@ def _loadTomogram(
)

for annotation in annotations:
point_paths = tuple(
f.https_path
for f in annotation.files
if f.shape_type == "Point"
)
if len(point_paths) > 0:
anno_data, anno_attrs, anno_type = read_annotation(annotation, tomogram=tomogram)
# Inherit scale from full resolution image so that we can pick up
# that scale when it changes.
anno_attrs["scale"] = image_scale
# Scaling points also changes the size, so adjust accordingly.
anno_attrs["size"] /= np.mean(image_scale)
yield anno_data, anno_attrs, anno_type
else:
logger.warn("Found no points annotations. Skipping.")
for layer in read_annotation_files(annotation, tomogram=tomogram):
if layer[2] == "labels":
layer = _handle_image_at_resolution(layer, resolution)
elif layer[2] == "points":
layer = _handle_points_at_scale(layer, image_scale)
yield layer

def _onLayerLoaded(self, layer_data: FullLayerData) -> None:
logger.debug("OpenWidget._onLayerLoaded")
Expand All @@ -190,5 +158,44 @@ def _onLayerLoaded(self, layer_data: FullLayerData) -> None:
self._viewer.add_image(data, **attrs)
elif layer_type == "points":
self._viewer.add_points(data, **attrs)
elif layer_type == "labels":
self._viewer.add_labels(data, **attrs)
else:
raise AssertionError(f"Unexpected {layer_type=}")


def _handle_image_at_resolution(layer_data: FullLayerData, resolution: Resolution) -> FullLayerData:
    """Adapts an image-like layer to the requested resolution.

    Parameters
    ----------
    layer_data : FullLayerData
        The data, attributes, and type name of the layer. The attributes
        must contain "scale" and may contain "translate"; both entries are
        overwritten in place.
    resolution : Resolution
        The resolution at which the layer should be shown.

    Returns
    -------
    FullLayerData
        The selected data, the updated attributes (the same mapping that
        was passed in), and the unchanged layer type name.
    """
    pyramid, layer_attrs, kind = layer_data

    # Multi-resolution data is passed through without indexing, so that no
    # unnecessary nodes are added to the dask compute graph.
    if resolution is not MULTI_RESOLUTION:
        pyramid = pyramid[resolution.indices[0]]

    # Low resolution data is materialized right away on this thread to
    # prevent napari blocking. Once async loading is working on a stable
    # napari release, this could be removed.
    if resolution is LOW_RESOLUTION:
        pyramid = np.asarray(pyramid)

    # Adjust the scale and translation based on the resolution.
    factor = resolution.scale
    base_scale = layer_attrs["scale"]
    layer_attrs["scale"] = tuple(factor * s for s in base_scale)

    # Lower resolutions have a larger first pixel, so the translation is
    # offset per axis. When visualized in napari, this ensures that the
    # different multi-scale layers opened separately share the same visual
    # extent in the canvas that starts at some scaled version of
    # (-0.5, -0.5, -0.5).
    base_translate = layer_attrs.get("translate", (0,) * len(base_scale))
    shifted = []
    for s, t in zip(base_scale, base_translate):
        shifted.append((s * (factor - 1) / 2) + t)
    layer_attrs["translate"] = tuple(shifted)

    return pyramid, layer_attrs, kind


def _handle_points_at_scale(layer_data: FullLayerData, image_scale: Tuple[float, float, float]) -> FullLayerData:
data, attrs, layer_type = layer_data
# Inherit scale from full resolution image, so that points are visually
# aligned with the image.
attrs["scale"] = image_scale
# Scaling points also changes the size, so adjust accordingly.
attrs["size"] /= np.mean(image_scale)
return data, attrs, layer_type
110 changes: 96 additions & 14 deletions src/napari_cryoet_data_portal/_reader.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,34 @@
"""Functions to read data from the portal into napari types."""

from typing import Any, Dict, List, Optional, Tuple
import warnings
from typing import Any, Dict, Generator, List, Optional, Tuple
import fsspec

import numpy as np
import ndjson
from napari_ome_zarr import napari_get_reader
from npe2.types import FullLayerData, PathOrPaths, ReaderFunction
from cryoet_data_portal import Annotation, Tomogram
from cryoet_data_portal import Annotation, AnnotationFile, Tomogram
from cmap import Colormap
from napari.utils.colormaps import direct_colormap

from napari_cryoet_data_portal._logging import logger

# Maps integer value of Annotation.object_id to a color.
OBJECT_COLORMAP = Colormap("colorbrewer:set1_8")
# Fallback color when ID cannot be parsed.
DEFAULT_OBJECT_COLOR = np.array(OBJECT_COLORMAP(0).rgba)

OBJECT_COLOR = {
"ribosome": "red",
"ribosome, 80 s": "red",
"fatty acid synthase": "darkblue",
}
DEFAULT_OBJECT_COLOR = "red"

def _annotation_color(annotation: Annotation) -> np.ndarray:
    """Maps an annotation to a color based on its object_id.

    The text after the last ":" in ``annotation.object_id`` is parsed as an
    integer, which is wrapped by the number of color stops in
    ``OBJECT_COLORMAP`` and used to look up a color.

    Parameters
    ----------
    annotation : Annotation
        The annotation whose ``object_id`` selects the color.

    Returns
    -------
    np.ndarray
        The RGBA values of the selected color, or ``DEFAULT_OBJECT_COLOR``
        when no integer can be parsed from the object_id.
    """
    try:
        object_id = int(annotation.object_id.split(":")[-1])
    except (AttributeError, ValueError, RuntimeError) as e:
        # int() raises ValueError for non-numeric text, and a missing (None)
        # object_id raises AttributeError on .split. RuntimeError is kept so
        # that anything handled before is still handled.
        logger.error("Failed to parse integer from object_id: %s\n%s", annotation.object_id, e)
        return DEFAULT_OBJECT_COLOR
    color = OBJECT_COLORMAP(object_id % len(OBJECT_COLORMAP.color_stops))
    return np.array(color.rgba)


def tomogram_ome_zarr_reader(path: PathOrPaths) -> Optional[ReaderFunction]:
Expand Down Expand Up @@ -128,8 +140,8 @@ def points_annotations_reader(path: PathOrPaths) -> Optional[ReaderFunction]:
--------
>>> annotation_dir = 's3://cryoet-data-portal-public/10000/TS_026/Tomograms/VoxelSpacing13.48/Annotations'
>>> path = (
f'{annotation_dir}/sara_goetz-ribosome-1.0.json',
f'{annotation_dir}/sara_goetz-fatty_acid_synthase-1.0.json',
f'{annotation_dir}/sara_goetz-ribosome-1.0.ndjson',
f'{annotation_dir}/sara_goetz-fatty_acid_synthase-1.0.ndjson',
)
>>> reader = points_annotations_reader(path)
>>> layers = reader(path)
Expand Down Expand Up @@ -167,9 +179,11 @@ def read_points_annotations_ndjson(path: str) -> FullLayerData:
attributes = {
"name": "annotations",
"size": 14,
"face_color": "red",
"face_color": DEFAULT_OBJECT_COLOR,
"opacity": 0.5,
"out_of_slice_display": True,
# Disable out-of-slice display because of:
# https://github.com/napari/napari/issues/6914
"out_of_slice_display": False,
}
return data, attributes, "points"

Expand Down Expand Up @@ -197,6 +211,10 @@ def read_annotation(annotation: Annotation, *, tomogram: Optional[Tomogram] = No
>>> data, attrs, _ = read_annotation(annotation)
>>> points = Points(data, **attrs)
"""
warnings.warn(
"read_annotation is deprecated from v0.4.0 because of Annotation schema changes. "
"Use read_annotation_files instead.",
category=DeprecationWarning)
point_paths = tuple(
f.https_path
for f in annotation.files
Expand All @@ -211,10 +229,74 @@ def read_annotation(annotation: Annotation, *, tomogram: Optional[Tomogram] = No
else:
attributes["name"] = f"{tomogram.name}-{name}"
attributes["metadata"] = annotation.to_dict()
attributes["face_color"] = OBJECT_COLOR.get(name.lower(), DEFAULT_OBJECT_COLOR)
attributes["face_color"] = _annotation_color(annotation)
return data, attributes, layer_type


def read_annotation_files(annotation: Annotation, *, tomogram: Optional[Tomogram] = None) -> Generator[FullLayerData, None, None]:
    """Reads the supported files of an annotation as napari layers.

    Files with shape type "Point" or "OrientedPoint" in "ndjson" format are
    read as points layers, and files with shape type "SegmentationMask" in
    "zarr" format are read as labels layers. Any other file is skipped with
    a logged warning.

    Parameters
    ----------
    annotation : Annotation
        The tomogram annotation.
    tomogram : Tomogram, optional
        The associated tomogram, which may be used for other metadata.

    Yields
    ------
    napari layer data tuple
        The data, attributes, and type name of the layer that would be
        returned by `Points.as_layer_data_tuple` or `Labels.as_layer_data_tuple`.

    Examples
    --------
    >>> client = Client()
    >>> annotation = client.find_one(Annotation)
    >>> for data, attrs, typ in read_annotation_files(annotation):
    ...     layer = Layer.create(data, attrs, typ)
    """
    for f in annotation.files:
        if (f.shape_type in ("Point", "OrientedPoint")) and (f.format == "ndjson"):
            yield _read_points_annotation_file(f, anno=annotation, tomogram=tomogram)
        elif (f.shape_type == "SegmentationMask") and (f.format == "zarr"):
            yield _read_labels_annotation_file(f, anno=annotation, tomogram=tomogram)
        else:
            # Logger.warn is a deprecated alias; Logger.warning is the
            # supported spelling.
            logger.warning("Found unsupported annotation file: %s, %s. Skipping.", f.shape_type, f.format)


def _read_points_annotation_file(anno_file: AnnotationFile, *, anno: Annotation, tomogram: Optional[Tomogram]) -> FullLayerData:
    """Reads one points-like annotation file as a napari points layer.

    Parameters
    ----------
    anno_file : AnnotationFile
        The file to read, which must have shape type "Point" or
        "OrientedPoint" and format "ndjson".
    anno : Annotation
        The annotation that provides the object name and the color.
    tomogram : Tomogram, optional
        When given, its name prefixes the layer name.

    Returns
    -------
    napari layer data tuple
        The data, attributes, and type name of the points layer.
    """
    assert anno_file.shape_type in ("Point", "OrientedPoint")
    assert anno_file.format == "ndjson"

    points, layer_attrs, kind = read_points_annotations_ndjson(anno_file.https_path)

    object_name = anno.object_name
    layer_attrs["name"] = (
        object_name if tomogram is None else f"{tomogram.name}-{object_name}"
    )
    # The metadata describes this file, while the color comes from the
    # annotation that the file belongs to.
    layer_attrs["metadata"] = anno_file.to_dict()
    layer_attrs["face_color"] = _annotation_color(anno)
    return points, layer_attrs, kind


def _read_labels_annotation_file(anno_file: AnnotationFile, *, anno: Annotation, tomogram: Optional[Tomogram]) -> FullLayerData:
    """Reads one segmentation mask annotation file as a napari labels layer.

    Parameters
    ----------
    anno_file : AnnotationFile
        The file to read, which must have shape type "SegmentationMask"
        and format "zarr".
    anno : Annotation
        The annotation that provides the object name and the color.
    tomogram : Tomogram, optional
        When given, its name prefixes the layer name.

    Returns
    -------
    napari layer data tuple
        The data and attributes of the layer, with the type name "labels".
    """
    assert anno_file.shape_type == "SegmentationMask"
    assert anno_file.format == "zarr"

    # The layer type reported by the OME-Zarr reader is discarded, because
    # the mask is always returned as a labels layer.
    mask, layer_attrs, _ = read_tomogram_ome_zarr(anno_file.https_path)

    object_name = anno.object_name
    layer_attrs["name"] = (
        object_name if tomogram is None else f"{tomogram.name}-{object_name}"
    )
    layer_attrs["metadata"] = anno_file.to_dict()
    layer_attrs["opacity"] = 0.5

    # Label 1 takes the annotation's color; the None entry maps to all zeros.
    label_colors = {
        None: np.zeros(4),
        1: _annotation_color(anno),
    }
    layer_attrs["colormap"] = direct_colormap(label_colors)
    return mask, layer_attrs, "labels"


def _read_points_data(
path: str,
) -> List[Tuple[float, float, float]]:
Expand All @@ -223,7 +305,7 @@ def _read_points_data(
sub_data = [
_annotation_to_point(annotation)
for annotation in ndjson.load(f)
if annotation["type"] == "point"
if annotation["type"] in ("point", "orientedPoint")
]
data.extend(sub_data)
return data
Expand Down
4 changes: 2 additions & 2 deletions src/napari_cryoet_data_portal/_sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ def _read_tomogram_from_10000(name: str) -> List[FullLayerData]:
ribosome_annotations = [
item
for item in annotations
if item.object_name.lower() == "cytosolic ribosome"
if "cytosolic ribosome" in item.object_name.lower()
].pop()
fas_annotations = [
item
for item in annotations
if item.object_name.lower() == "fatty acid synthase"
if "fatty acid synthase" in item.object_name.lower()
].pop()
ribosome_points = read_annotation(ribosome_annotations, tomogram=tomogram)
fatty_acid_points = read_annotation(fas_annotations, tomogram=tomogram)
Expand Down
9 changes: 8 additions & 1 deletion src/napari_cryoet_data_portal/_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from cryoet_data_portal import Client, Dataset, Tomogram

from cryoet_data_portal import Annotation, AnnotationFile, Client, Dataset, Tomogram


@pytest.fixture()
Expand All @@ -21,3 +22,9 @@ def tomogram(client: Client) -> Tomogram:
Tomogram.https_omezarr_dir.like("%13.480%"),
],
).pop()


@pytest.fixture()
def annotation_with_points(client: Client) -> Annotation:
    """Returns the annotation that owns a "Point" annotation file."""
    point_files = AnnotationFile.find(
        client,
        [AnnotationFile.shape_type == "Point"],
    )
    # Any match will do, so take the last one found.
    return point_files.pop().annotation
4 changes: 2 additions & 2 deletions src/napari_cryoet_data_portal/_tests/test_open_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_set_tomogram_adds_layers_to_viewer(widget: OpenWidget, tomogram: Tomogr
with qtbot.waitSignal(widget._progress.finished, timeout=30000):
widget.setTomogram(tomogram)

assert len(widget._viewer.layers) == 3
assert len(widget._viewer.layers) > 1


def test_set_tomogram_adds_layers_to_viewer_without_clearing_existing(widget: OpenWidget, tomogram: Tomogram, qtbot: QtBot):
Expand All @@ -48,5 +48,5 @@ def test_set_tomogram_adds_layers_to_viewer_without_clearing_existing(widget: Op
with qtbot.waitSignal(widget._progress.finished, timeout=30000):
widget.setTomogram(tomogram)

assert len(widget._viewer.layers) == 4
assert len(widget._viewer.layers) > 1

Loading

0 comments on commit ce940d2

Please sign in to comment.