From ce940d25599e0f03d6d3bd365da79b4af77cb2dd Mon Sep 17 00:00:00 2001 From: Andy Sweet Date: Tue, 21 May 2024 11:34:29 -0700 Subject: [PATCH] Support dense segmentation and oriented-point annotations (#21) This adds support for segmentation-mask and oriented-points annotations types. ![Screenshot 2024-05-15 at 16 42 33](https://github.com/chanzuckerberg/napari-cryoet-data-portal/assets/2608297/dfeee38d-5d58-41c3-ba78-9f87afd589a0) Segmentation masks are supported by loading the corresponding OME-Zarrs in a similar way to the tomogram (i.e. according to resolution). Oriented points are currently read and displayed as regular points. We can follow up with another PR to augment that with the orientation information (e.g. with a vectors layer). As some datasets contain tomograms with lots of different annotations (e.g. 10000), performance can suffer both while waiting for data to arrive and when interacting with napari once it has. This also fixes an alignment issue caused by the highest resolution tomogram volume having a non-unit volume. I originally separated that into #31, but there was some overlap with the changes here, so I brought it into this PR. Closes #22 and #23 --- setup.cfg | 2 + src/napari_cryoet_data_portal/_open_widget.py | 95 ++++++++------- src/napari_cryoet_data_portal/_reader.py | 110 +++++++++++++++--- src/napari_cryoet_data_portal/_sample_data.py | 4 +- .../_tests/conftest.py | 9 +- .../_tests/test_open_widget.py | 4 +- .../_tests/test_reader.py | 28 ++++- 7 files changed, 187 insertions(+), 65 deletions(-) diff --git a/setup.cfg b/setup.cfg index 55657c6..7b572df 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,10 +31,12 @@ project_urls = [options] packages = find: install_requires = + cmap cryoet_data_portal ~= 3.0 fsspec[http,s3] npe2 numpy + napari>=0.4.19 napari_ome_zarr ndjson qtpy diff --git a/src/napari_cryoet_data_portal/_open_widget.py b/src/napari_cryoet_data_portal/_open_widget.py index 40db883..3b15f78 100644 --- a/src/napari_cryoet_data_portal/_open_widget.py +++ b/src/napari_cryoet_data_portal/_open_widget.py @@ -19,7 +19,7 @@ from napari_cryoet_data_portal._logging import logger from napari_cryoet_data_portal._progress_widget import ProgressWidget from napari_cryoet_data_portal._reader import ( - read_annotation, + read_annotation_files, read_tomogram, ) @@ -34,16 +34,6 @@ class Resolution: indices: Tuple[int, ...] scale: float - @property - def offset(self) -> float: - """The offset due to a larger first pixel for lower resolutions. - - When visualized in napari, this ensures that the different multi-scale - layers opened separately share the same visual extent in the canvas that - starts at (-0.5, -0.5, -0.5). - """ - return (self.scale - 1) / 2 - MULTI_RESOLUTION = Resolution(name="Multi", indices=(0, 1, 2), scale=1) HIGH_RESOLUTION = Resolution(name="High", indices=(0,), scale=1) @@ -138,24 +128,11 @@ def _loadTomogram( resolution: Resolution, ) -> Generator[FullLayerData, None, None]: logger.debug("OpenWidget._loadTomogram: %s", tomogram.name) - image_data, image_attrs, _ = read_tomogram(tomogram) - # Skip indexing for multi-resolution to avoid adding any - # unnecessary nodes to the dask compute graph. - if resolution is not MULTI_RESOLUTION: - image_data = image_data[resolution.indices[0]] - # Materialize low resolution immediately on this thread to prevent napari blocking. - if resolution is LOW_RESOLUTION: - image_data = np.asarray(image_data) - # Get scale before resolution scaling for annotations. - image_scale = image_attrs["scale"] - image_attrs["scale"] = tuple( - resolution.scale * s for s in image_scale - ) - image_translate = image_attrs.get("translate", (0,) * len(image_attrs["scale"])) - image_attrs["translate"] = tuple( - resolution.offset + t for t in image_translate - ) - yield image_data, image_attrs, "image" + image_layer = read_tomogram(tomogram) + # Extract image_scale before the resolution is taken into account, + # so we can use it to align other annotations later. + image_scale = image_layer[1]["scale"] + yield _handle_image_at_resolution(image_layer, resolution) # Looking up tomogram.tomogram_voxel_spacing.annotations triggers a query # using the client from where the tomogram was found. @@ -167,21 +144,12 @@ def _loadTomogram( ) for annotation in annotations: - point_paths = tuple( - f.https_path - for f in annotation.files - if f.shape_type == "Point" - ) - if len(point_paths) > 0: - anno_data, anno_attrs, anno_type = read_annotation(annotation, tomogram=tomogram) - # Inherit scale from full resolution image so that we can pick up - # that scale when it changes. - anno_attrs["scale"] = image_scale - # Scaling points also changes the size, so adjust accordingly. - anno_attrs["size"] /= np.mean(image_scale) - yield anno_data, anno_attrs, anno_type - else: - logger.warn("Found no points annotations. Skipping.") + for layer in read_annotation_files(annotation, tomogram=tomogram): + if layer[2] == "labels": + layer = _handle_image_at_resolution(layer, resolution) + elif layer[2] == "points": + layer = _handle_points_at_scale(layer, image_scale) + yield layer def _onLayerLoaded(self, layer_data: FullLayerData) -> None: logger.debug("OpenWidget._onLayerLoaded") @@ -190,5 +158,44 @@ def _onLayerLoaded(self, layer_data: FullLayerData) -> None: self._viewer.add_image(data, **attrs) elif layer_type == "points": self._viewer.add_points(data, **attrs) + elif layer_type == "labels": + self._viewer.add_labels(data, **attrs) else: raise AssertionError(f"Unexpected {layer_type=}") + + +def _handle_image_at_resolution(layer_data: FullLayerData, resolution: Resolution) -> FullLayerData: + data, attrs, layer_type = layer_data + # Skip indexing for multi-resolution to avoid adding any + # unnecessary nodes to the dask compute graph. + if resolution is not MULTI_RESOLUTION: + data = data[resolution.indices[0]] + + # Materialize low resolution immediately on this thread to prevent napari blocking. + # Once async loading is working on a stable napari release, we could remove this. + if resolution is LOW_RESOLUTION: + data = np.asarray(data) + + # Adjust the scale and and translation based on the resolution. + image_scale = attrs["scale"] + attrs["scale"] = tuple(resolution.scale * s for s in image_scale) + # Offset the translation due to a larger first pixel for lower resolutions. + # When visualized in napari, this ensures that the different multi-scale + # layers opened separately share the same visual extent in the canvas that + # starts at some scaled version of (-0.5, -0.5, -0.5). + image_translate = attrs.get("translate", (0,) * len(image_scale)) + attrs["translate"] = tuple( + (s * (resolution.scale - 1) / 2) + t + for s, t in zip(image_scale, image_translate) + ) + return data, attrs, layer_type + + +def _handle_points_at_scale(layer_data: FullLayerData, image_scale: Tuple[float, float, float]) -> FullLayerData: + data, attrs, layer_type = layer_data + # Inherit scale from full resolution image, so that points are visually + # aligned with the image. + attrs["scale"] = image_scale + # Scaling points also changes the size, so adjust accordingly. + attrs["size"] /= np.mean(image_scale) + return data, attrs, layer_type \ No newline at end of file diff --git a/src/napari_cryoet_data_portal/_reader.py b/src/napari_cryoet_data_portal/_reader.py index 86cb318..70ca7d5 100644 --- a/src/napari_cryoet_data_portal/_reader.py +++ b/src/napari_cryoet_data_portal/_reader.py @@ -1,22 +1,34 @@ """Functions to read data from the portal into napari types.""" -from typing import Any, Dict, List, Optional, Tuple +import warnings +from typing import Any, Dict, Generator, List, Optional, Tuple import fsspec +import numpy as np import ndjson from napari_ome_zarr import napari_get_reader from npe2.types import FullLayerData, PathOrPaths, ReaderFunction -from cryoet_data_portal import Annotation, Tomogram +from cryoet_data_portal import Annotation, AnnotationFile, Tomogram +from cmap import Colormap +from napari.utils.colormaps import direct_colormap from napari_cryoet_data_portal._logging import logger +# Maps integer value of Annotation.object_id to a color. +OBJECT_COLORMAP = Colormap("colorbrewer:set1_8") +# Fallback color when ID cannot be parsed. +DEFAULT_OBJECT_COLOR = np.array(OBJECT_COLORMAP(0).rgba) -OBJECT_COLOR = { - "ribosome": "red", - "ribosome, 80 s": "red", - "fatty acid synthase": "darkblue", -} -DEFAULT_OBJECT_COLOR = "red" + +def _annotation_color(annotation: Annotation) -> np.ndarray: + """Maps an annotation to a color based on its object_id.""" + try: + object_id = int(annotation.object_id.split(":")[-1]) + except RuntimeError as e: + logger.error("Failed to parse integer from object_id: %s\%s", annotation.object_id, e) + return DEFAULT_OBJECT_COLOR + color = OBJECT_COLORMAP(object_id % len(OBJECT_COLORMAP.color_stops)) + return np.array(color.rgba) def tomogram_ome_zarr_reader(path: PathOrPaths) -> Optional[ReaderFunction]: @@ -128,8 +140,8 @@ def points_annotations_reader(path: PathOrPaths) -> Optional[ReaderFunction]: -------- >>> annotation_dir = 's3://cryoet-data-portal-public/10000/TS_026/Tomograms/VoxelSpacing13.48/Annotations' >>> path = ( - f'{annotation_dir}/sara_goetz-ribosome-1.0.json', - f'{annotation_dir}/sara_goetz-fatty_acid_synthase-1.0.json', + f'{annotation_dir}/sara_goetz-ribosome-1.0.ndjson', + f'{annotation_dir}/sara_goetz-fatty_acid_synthase-1.0.ndjson', ) >>> reader = points_annotations_reader(path) >>> layers = reader(path) @@ -167,9 +179,11 @@ def read_points_annotations_ndjson(path: str) -> FullLayerData: attributes = { "name": "annotations", "size": 14, - "face_color": "red", + "face_color": DEFAULT_OBJECT_COLOR, "opacity": 0.5, - "out_of_slice_display": True, + # Disable out-of-slice display because of: + # https://github.com/napari/napari/issues/6914 + "out_of_slice_display": False, } return data, attributes, "points" @@ -197,6 +211,10 @@ def read_annotation(annotation: Annotation, *, tomogram: Optional[Tomogram] = No >>> data, attrs, _ = read_annotation(annotation) >>> points = Points(data, **attrs) """ + warnings.warn( + "read_annotation is deprecated from v0.4.0 because of Annotation schema changes. " + "Use read_annotation_files instead.", + category=DeprecationWarning) point_paths = tuple( f.https_path for f in annotation.files @@ -211,10 +229,74 @@ def read_annotation(annotation: Annotation, *, tomogram: Optional[Tomogram] = No else: attributes["name"] = f"{tomogram.name}-{name}" attributes["metadata"] = annotation.to_dict() - attributes["face_color"] = OBJECT_COLOR.get(name.lower(), DEFAULT_OBJECT_COLOR) + attributes["face_color"] = _annotation_color(annotation) + return data, attributes, layer_type + + +def read_annotation_files(annotation: Annotation, *, tomogram: Optional[Tomogram] = None) -> Generator[FullLayerData, None, None]: + """Reads multiple annotation layers. + + Parameters + ---------- + annotation : Annotation + The tomogram annotation. + tomogram : Tomogram, optional + The associated tomogram, which may be used for other metadata. + + Yields + ------- + napari layer data tuple + The data, attributes, and type name of the layer that would be + returned by `Points.as_layer_data_tuple` or `Labels.as_layer_data_tuple`. + + Examples + -------- + >>> client = Client() + >>> annotation = client.find_one(Annotation) + >>> for data, attrs, typ in read_annotation_files(annotation): + layer = Layer.create(data, attrs, typ) + """ + for f in annotation.files: + if (f.shape_type in ("Point", "OrientedPoint")) and (f.format == "ndjson"): + yield _read_points_annotation_file(f, anno=annotation, tomogram=tomogram) + elif (f.shape_type == "SegmentationMask") and (f.format == "zarr"): + yield _read_labels_annotation_file(f, anno=annotation, tomogram=tomogram) + else: + logger.warn("Found unsupported annotation file: %s, %s. Skipping.", f.shape_type, f.format) + + +def _read_points_annotation_file(anno_file: AnnotationFile, *, anno: Annotation, tomogram: Optional[Tomogram]) -> FullLayerData: + assert anno_file.shape_type in ("Point", "OrientedPoint") + assert anno_file.format == "ndjson" + data, attributes, layer_type = read_points_annotations_ndjson(anno_file.https_path) + name = anno.object_name + if tomogram is None: + attributes["name"] = name + else: + attributes["name"] = f"{tomogram.name}-{name}" + attributes["metadata"] = anno_file.to_dict() + attributes["face_color"] = _annotation_color(anno) return data, attributes, layer_type +def _read_labels_annotation_file(anno_file: AnnotationFile, *, anno: Annotation, tomogram: Optional[Tomogram]) -> FullLayerData: + assert anno_file.shape_type == "SegmentationMask" + assert anno_file.format == "zarr" + data, attributes, _ = read_tomogram_ome_zarr(anno_file.https_path) + name = anno.object_name + if tomogram is None: + attributes["name"] = name + else: + attributes["name"] = f"{tomogram.name}-{name}" + attributes["metadata"] = anno_file.to_dict() + attributes["opacity"] = 0.5 + attributes["colormap"] = direct_colormap({ + None: np.zeros(4), + 1: _annotation_color(anno), + }) + return data, attributes, "labels" + + def _read_points_data( path: str, ) -> List[Tuple[float, float, float]]: @@ -223,7 +305,7 @@ def _read_points_data( sub_data = [ _annotation_to_point(annotation) for annotation in ndjson.load(f) - if annotation["type"] == "point" + if annotation["type"] in ("point", "orientedPoint") ] data.extend(sub_data) return data diff --git a/src/napari_cryoet_data_portal/_sample_data.py b/src/napari_cryoet_data_portal/_sample_data.py index 225396e..735f94d 100644 --- a/src/napari_cryoet_data_portal/_sample_data.py +++ b/src/napari_cryoet_data_portal/_sample_data.py @@ -36,12 +36,12 @@ def _read_tomogram_from_10000(name: str) -> List[FullLayerData]: ribosome_annotations = [ item for item in annotations - if item.object_name.lower() == "cytosolic ribosome" + if "cytosolic ribosome" in item.object_name.lower() ].pop() fas_annotations = [ item for item in annotations - if item.object_name.lower() == "fatty acid synthase" + if "fatty acid synthase" in item.object_name.lower() ].pop() ribosome_points = read_annotation(ribosome_annotations, tomogram=tomogram) fatty_acid_points = read_annotation(fas_annotations, tomogram=tomogram) diff --git a/src/napari_cryoet_data_portal/_tests/conftest.py b/src/napari_cryoet_data_portal/_tests/conftest.py index 1ade18f..516ff68 100644 --- a/src/napari_cryoet_data_portal/_tests/conftest.py +++ b/src/napari_cryoet_data_portal/_tests/conftest.py @@ -1,5 +1,6 @@ import pytest -from cryoet_data_portal import Client, Dataset, Tomogram + +from cryoet_data_portal import Annotation, AnnotationFile, Client, Dataset, Tomogram @pytest.fixture() @@ -21,3 +22,9 @@ def tomogram(client: Client) -> Tomogram: Tomogram.https_omezarr_dir.like("%13.480%"), ], ).pop() + + +@pytest.fixture() +def annotation_with_points(client: Client) -> Annotation: + anno_file = AnnotationFile.find(client, [AnnotationFile.shape_type == "Point"]).pop() + return anno_file.annotation diff --git a/src/napari_cryoet_data_portal/_tests/test_open_widget.py b/src/napari_cryoet_data_portal/_tests/test_open_widget.py index db2a6d2..70f1bbf 100644 --- a/src/napari_cryoet_data_portal/_tests/test_open_widget.py +++ b/src/napari_cryoet_data_portal/_tests/test_open_widget.py @@ -36,7 +36,7 @@ def test_set_tomogram_adds_layers_to_viewer(widget: OpenWidget, tomogram: Tomogr with qtbot.waitSignal(widget._progress.finished, timeout=30000): widget.setTomogram(tomogram) - assert len(widget._viewer.layers) == 3 + assert len(widget._viewer.layers) > 1 def test_set_tomogram_adds_layers_to_viewer_without_clearing_existing(widget: OpenWidget, tomogram: Tomogram, qtbot: QtBot): @@ -48,5 +48,5 @@ def test_set_tomogram_adds_layers_to_viewer_without_clearing_existing(widget: Op with qtbot.waitSignal(widget._progress.finished, timeout=30000): widget.setTomogram(tomogram) - assert len(widget._viewer.layers) == 4 + assert len(widget._viewer.layers) > 1 \ No newline at end of file diff --git a/src/napari_cryoet_data_portal/_tests/test_reader.py b/src/napari_cryoet_data_portal/_tests/test_reader.py index 2657253..b5a5309 100644 --- a/src/napari_cryoet_data_portal/_tests/test_reader.py +++ b/src/napari_cryoet_data_portal/_tests/test_reader.py @@ -1,9 +1,14 @@ +import pytest from typing import Callable +import numpy as np +from cryoet_data_portal import Annotation from napari import Viewer from napari.layers import Points -from napari_cryoet_data_portal import ( +from napari_cryoet_data_portal._reader import ( + read_annotation, + read_annotation_files, read_points_annotations_ndjson, read_tomogram_ome_zarr, ) @@ -22,7 +27,7 @@ def test_read_tomogram_ome_zarr(): assert data[0].shape == (1000, 928, 960) assert data[1].shape == (500, 464, 480) assert data[2].shape == (250, 232, 240) - assert attrs["scale"] == (1, 1, 1) + np.testing.assert_allclose(attrs["scale"], (13.48, 13.48, 13.48), atol=0.01) assert layer_type == "image" @@ -44,3 +49,22 @@ def test_open_points_annotations(make_napari_viewer: Callable[[], Viewer]): assert len(layers) == 1 assert isinstance(layers[0], Points) + + +def test_read_annotation(annotation_with_points: Annotation): + with pytest.warns(DeprecationWarning): + data, attrs, layer_type = read_annotation(annotation_with_points) + + assert len(data) > 0 + assert len(attrs["name"]) > 0 + assert layer_type == "points" + + +def test_read_annotation_files(annotation_with_points: Annotation): + layers = list(read_annotation_files(annotation_with_points)) + + assert len(layers) == 1 + data, attrs, layer_type = layers[0] + assert len(data) > 0 + assert len(attrs["name"]) > 0 + assert layer_type == "points"