diff --git a/setup.cfg b/setup.cfg index 55657c6..7b572df 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,10 +31,12 @@ project_urls = [options] packages = find: install_requires = + cmap cryoet_data_portal ~= 3.0 fsspec[http,s3] npe2 numpy + napari>=0.4.19 napari_ome_zarr ndjson qtpy diff --git a/src/napari_cryoet_data_portal/_open_widget.py b/src/napari_cryoet_data_portal/_open_widget.py index 40db883..3b15f78 100644 --- a/src/napari_cryoet_data_portal/_open_widget.py +++ b/src/napari_cryoet_data_portal/_open_widget.py @@ -19,7 +19,7 @@ from napari_cryoet_data_portal._logging import logger from napari_cryoet_data_portal._progress_widget import ProgressWidget from napari_cryoet_data_portal._reader import ( - read_annotation, + read_annotation_files, read_tomogram, ) @@ -34,16 +34,6 @@ class Resolution: indices: Tuple[int, ...] scale: float - @property - def offset(self) -> float: - """The offset due to a larger first pixel for lower resolutions. - - When visualized in napari, this ensures that the different multi-scale - layers opened separately share the same visual extent in the canvas that - starts at (-0.5, -0.5, -0.5). - """ - return (self.scale - 1) / 2 - MULTI_RESOLUTION = Resolution(name="Multi", indices=(0, 1, 2), scale=1) HIGH_RESOLUTION = Resolution(name="High", indices=(0,), scale=1) @@ -138,24 +128,11 @@ def _loadTomogram( resolution: Resolution, ) -> Generator[FullLayerData, None, None]: logger.debug("OpenWidget._loadTomogram: %s", tomogram.name) - image_data, image_attrs, _ = read_tomogram(tomogram) - # Skip indexing for multi-resolution to avoid adding any - # unnecessary nodes to the dask compute graph. - if resolution is not MULTI_RESOLUTION: - image_data = image_data[resolution.indices[0]] - # Materialize low resolution immediately on this thread to prevent napari blocking. - if resolution is LOW_RESOLUTION: - image_data = np.asarray(image_data) - # Get scale before resolution scaling for annotations. - image_scale = image_attrs["scale"] - image_attrs["scale"] = tuple( - resolution.scale * s for s in image_scale - ) - image_translate = image_attrs.get("translate", (0,) * len(image_attrs["scale"])) - image_attrs["translate"] = tuple( - resolution.offset + t for t in image_translate - ) - yield image_data, image_attrs, "image" + image_layer = read_tomogram(tomogram) + # Extract image_scale before the resolution is taken into account, + # so we can use it to align other annotations later. + image_scale = image_layer[1]["scale"] + yield _handle_image_at_resolution(image_layer, resolution) # Looking up tomogram.tomogram_voxel_spacing.annotations triggers a query # using the client from where the tomogram was found. @@ -167,21 +144,12 @@ def _loadTomogram( ) for annotation in annotations: - point_paths = tuple( - f.https_path - for f in annotation.files - if f.shape_type == "Point" - ) - if len(point_paths) > 0: - anno_data, anno_attrs, anno_type = read_annotation(annotation, tomogram=tomogram) - # Inherit scale from full resolution image so that we can pick up - # that scale when it changes. - anno_attrs["scale"] = image_scale - # Scaling points also changes the size, so adjust accordingly. - anno_attrs["size"] /= np.mean(image_scale) - yield anno_data, anno_attrs, anno_type - else: - logger.warn("Found no points annotations. Skipping.") + for layer in read_annotation_files(annotation, tomogram=tomogram): + if layer[2] == "labels": + layer = _handle_image_at_resolution(layer, resolution) + elif layer[2] == "points": + layer = _handle_points_at_scale(layer, image_scale) + yield layer def _onLayerLoaded(self, layer_data: FullLayerData) -> None: logger.debug("OpenWidget._onLayerLoaded") @@ -190,5 +158,44 @@ def _onLayerLoaded(self, layer_data: FullLayerData) -> None: self._viewer.add_image(data, **attrs) elif layer_type == "points": self._viewer.add_points(data, **attrs) + elif layer_type == "labels": + self._viewer.add_labels(data, **attrs) else: raise AssertionError(f"Unexpected {layer_type=}") + + +def _handle_image_at_resolution(layer_data: FullLayerData, resolution: Resolution) -> FullLayerData: + data, attrs, layer_type = layer_data + # Skip indexing for multi-resolution to avoid adding any + # unnecessary nodes to the dask compute graph. + if resolution is not MULTI_RESOLUTION: + data = data[resolution.indices[0]] + + # Materialize low resolution immediately on this thread to prevent napari blocking. + # Once async loading is working on a stable napari release, we could remove this. + if resolution is LOW_RESOLUTION: + data = np.asarray(data) + + # Adjust the scale and and translation based on the resolution. + image_scale = attrs["scale"] + attrs["scale"] = tuple(resolution.scale * s for s in image_scale) + # Offset the translation due to a larger first pixel for lower resolutions. + # When visualized in napari, this ensures that the different multi-scale + # layers opened separately share the same visual extent in the canvas that + # starts at some scaled version of (-0.5, -0.5, -0.5). + image_translate = attrs.get("translate", (0,) * len(image_scale)) + attrs["translate"] = tuple( + (s * (resolution.scale - 1) / 2) + t + for s, t in zip(image_scale, image_translate) + ) + return data, attrs, layer_type + + +def _handle_points_at_scale(layer_data: FullLayerData, image_scale: Tuple[float, float, float]) -> FullLayerData: + data, attrs, layer_type = layer_data + # Inherit scale from full resolution image, so that points are visually + # aligned with the image. + attrs["scale"] = image_scale + # Scaling points also changes the size, so adjust accordingly. + attrs["size"] /= np.mean(image_scale) + return data, attrs, layer_type \ No newline at end of file diff --git a/src/napari_cryoet_data_portal/_reader.py b/src/napari_cryoet_data_portal/_reader.py index 86cb318..70ca7d5 100644 --- a/src/napari_cryoet_data_portal/_reader.py +++ b/src/napari_cryoet_data_portal/_reader.py @@ -1,22 +1,34 @@ """Functions to read data from the portal into napari types.""" -from typing import Any, Dict, List, Optional, Tuple +import warnings +from typing import Any, Dict, Generator, List, Optional, Tuple import fsspec +import numpy as np import ndjson from napari_ome_zarr import napari_get_reader from npe2.types import FullLayerData, PathOrPaths, ReaderFunction -from cryoet_data_portal import Annotation, Tomogram +from cryoet_data_portal import Annotation, AnnotationFile, Tomogram +from cmap import Colormap +from napari.utils.colormaps import direct_colormap from napari_cryoet_data_portal._logging import logger +# Maps integer value of Annotation.object_id to a color. +OBJECT_COLORMAP = Colormap("colorbrewer:set1_8") +# Fallback color when ID cannot be parsed. +DEFAULT_OBJECT_COLOR = np.array(OBJECT_COLORMAP(0).rgba) -OBJECT_COLOR = { - "ribosome": "red", - "ribosome, 80 s": "red", - "fatty acid synthase": "darkblue", -} -DEFAULT_OBJECT_COLOR = "red" + +def _annotation_color(annotation: Annotation) -> np.ndarray: + """Maps an annotation to a color based on its object_id.""" + try: + object_id = int(annotation.object_id.split(":")[-1]) + except RuntimeError as e: + logger.error("Failed to parse integer from object_id: %s\%s", annotation.object_id, e) + return DEFAULT_OBJECT_COLOR + color = OBJECT_COLORMAP(object_id % len(OBJECT_COLORMAP.color_stops)) + return np.array(color.rgba) def tomogram_ome_zarr_reader(path: PathOrPaths) -> Optional[ReaderFunction]: @@ -128,8 +140,8 @@ def points_annotations_reader(path: PathOrPaths) -> Optional[ReaderFunction]: -------- >>> annotation_dir = 's3://cryoet-data-portal-public/10000/TS_026/Tomograms/VoxelSpacing13.48/Annotations' >>> path = ( - f'{annotation_dir}/sara_goetz-ribosome-1.0.json', - f'{annotation_dir}/sara_goetz-fatty_acid_synthase-1.0.json', + f'{annotation_dir}/sara_goetz-ribosome-1.0.ndjson', + f'{annotation_dir}/sara_goetz-fatty_acid_synthase-1.0.ndjson', ) >>> reader = points_annotations_reader(path) >>> layers = reader(path) @@ -167,9 +179,11 @@ def read_points_annotations_ndjson(path: str) -> FullLayerData: attributes = { "name": "annotations", "size": 14, - "face_color": "red", + "face_color": DEFAULT_OBJECT_COLOR, "opacity": 0.5, - "out_of_slice_display": True, + # Disable out-of-slice display because of: + # https://github.com/napari/napari/issues/6914 + "out_of_slice_display": False, } return data, attributes, "points" @@ -197,6 +211,10 @@ def read_annotation(annotation: Annotation, *, tomogram: Optional[Tomogram] = No >>> data, attrs, _ = read_annotation(annotation) >>> points = Points(data, **attrs) """ + warnings.warn( + "read_annotation is deprecated from v0.4.0 because of Annotation schema changes. " + "Use read_annotation_files instead.", + category=DeprecationWarning) point_paths = tuple( f.https_path for f in annotation.files @@ -211,10 +229,74 @@ def read_annotation(annotation: Annotation, *, tomogram: Optional[Tomogram] = No else: attributes["name"] = f"{tomogram.name}-{name}" attributes["metadata"] = annotation.to_dict() - attributes["face_color"] = OBJECT_COLOR.get(name.lower(), DEFAULT_OBJECT_COLOR) + attributes["face_color"] = _annotation_color(annotation) + return data, attributes, layer_type + + +def read_annotation_files(annotation: Annotation, *, tomogram: Optional[Tomogram] = None) -> Generator[FullLayerData, None, None]: + """Reads multiple annotation layers. + + Parameters + ---------- + annotation : Annotation + The tomogram annotation. + tomogram : Tomogram, optional + The associated tomogram, which may be used for other metadata. + + Yields + ------- + napari layer data tuple + The data, attributes, and type name of the layer that would be + returned by `Points.as_layer_data_tuple` or `Labels.as_layer_data_tuple`. + + Examples + -------- + >>> client = Client() + >>> annotation = client.find_one(Annotation) + >>> for data, attrs, typ in read_annotation_files(annotation): + layer = Layer.create(data, attrs, typ) + """ + for f in annotation.files: + if (f.shape_type in ("Point", "OrientedPoint")) and (f.format == "ndjson"): + yield _read_points_annotation_file(f, anno=annotation, tomogram=tomogram) + elif (f.shape_type == "SegmentationMask") and (f.format == "zarr"): + yield _read_labels_annotation_file(f, anno=annotation, tomogram=tomogram) + else: + logger.warn("Found unsupported annotation file: %s, %s. Skipping.", f.shape_type, f.format) + + +def _read_points_annotation_file(anno_file: AnnotationFile, *, anno: Annotation, tomogram: Optional[Tomogram]) -> FullLayerData: + assert anno_file.shape_type in ("Point", "OrientedPoint") + assert anno_file.format == "ndjson" + data, attributes, layer_type = read_points_annotations_ndjson(anno_file.https_path) + name = anno.object_name + if tomogram is None: + attributes["name"] = name + else: + attributes["name"] = f"{tomogram.name}-{name}" + attributes["metadata"] = anno_file.to_dict() + attributes["face_color"] = _annotation_color(anno) return data, attributes, layer_type +def _read_labels_annotation_file(anno_file: AnnotationFile, *, anno: Annotation, tomogram: Optional[Tomogram]) -> FullLayerData: + assert anno_file.shape_type == "SegmentationMask" + assert anno_file.format == "zarr" + data, attributes, _ = read_tomogram_ome_zarr(anno_file.https_path) + name = anno.object_name + if tomogram is None: + attributes["name"] = name + else: + attributes["name"] = f"{tomogram.name}-{name}" + attributes["metadata"] = anno_file.to_dict() + attributes["opacity"] = 0.5 + attributes["colormap"] = direct_colormap({ + None: np.zeros(4), + 1: _annotation_color(anno), + }) + return data, attributes, "labels" + + def _read_points_data( path: str, ) -> List[Tuple[float, float, float]]: @@ -223,7 +305,7 @@ def _read_points_data( sub_data = [ _annotation_to_point(annotation) for annotation in ndjson.load(f) - if annotation["type"] == "point" + if annotation["type"] in ("point", "orientedPoint") ] data.extend(sub_data) return data diff --git a/src/napari_cryoet_data_portal/_sample_data.py b/src/napari_cryoet_data_portal/_sample_data.py index 225396e..735f94d 100644 --- a/src/napari_cryoet_data_portal/_sample_data.py +++ b/src/napari_cryoet_data_portal/_sample_data.py @@ -36,12 +36,12 @@ def _read_tomogram_from_10000(name: str) -> List[FullLayerData]: ribosome_annotations = [ item for item in annotations - if item.object_name.lower() == "cytosolic ribosome" + if "cytosolic ribosome" in item.object_name.lower() ].pop() fas_annotations = [ item for item in annotations - if item.object_name.lower() == "fatty acid synthase" + if "fatty acid synthase" in item.object_name.lower() ].pop() ribosome_points = read_annotation(ribosome_annotations, tomogram=tomogram) fatty_acid_points = read_annotation(fas_annotations, tomogram=tomogram) diff --git a/src/napari_cryoet_data_portal/_tests/conftest.py b/src/napari_cryoet_data_portal/_tests/conftest.py index 1ade18f..516ff68 100644 --- a/src/napari_cryoet_data_portal/_tests/conftest.py +++ b/src/napari_cryoet_data_portal/_tests/conftest.py @@ -1,5 +1,6 @@ import pytest -from cryoet_data_portal import Client, Dataset, Tomogram + +from cryoet_data_portal import Annotation, AnnotationFile, Client, Dataset, Tomogram @pytest.fixture() @@ -21,3 +22,9 @@ def tomogram(client: Client) -> Tomogram: Tomogram.https_omezarr_dir.like("%13.480%"), ], ).pop() + + +@pytest.fixture() +def annotation_with_points(client: Client) -> Annotation: + anno_file = AnnotationFile.find(client, [AnnotationFile.shape_type == "Point"]).pop() + return anno_file.annotation diff --git a/src/napari_cryoet_data_portal/_tests/test_open_widget.py b/src/napari_cryoet_data_portal/_tests/test_open_widget.py index db2a6d2..70f1bbf 100644 --- a/src/napari_cryoet_data_portal/_tests/test_open_widget.py +++ b/src/napari_cryoet_data_portal/_tests/test_open_widget.py @@ -36,7 +36,7 @@ def test_set_tomogram_adds_layers_to_viewer(widget: OpenWidget, tomogram: Tomogr with qtbot.waitSignal(widget._progress.finished, timeout=30000): widget.setTomogram(tomogram) - assert len(widget._viewer.layers) == 3 + assert len(widget._viewer.layers) > 1 def test_set_tomogram_adds_layers_to_viewer_without_clearing_existing(widget: OpenWidget, tomogram: Tomogram, qtbot: QtBot): @@ -48,5 +48,5 @@ def test_set_tomogram_adds_layers_to_viewer_without_clearing_existing(widget: Op with qtbot.waitSignal(widget._progress.finished, timeout=30000): widget.setTomogram(tomogram) - assert len(widget._viewer.layers) == 4 + assert len(widget._viewer.layers) > 1 \ No newline at end of file diff --git a/src/napari_cryoet_data_portal/_tests/test_reader.py b/src/napari_cryoet_data_portal/_tests/test_reader.py index 2657253..b5a5309 100644 --- a/src/napari_cryoet_data_portal/_tests/test_reader.py +++ b/src/napari_cryoet_data_portal/_tests/test_reader.py @@ -1,9 +1,14 @@ +import pytest from typing import Callable +import numpy as np +from cryoet_data_portal import Annotation from napari import Viewer from napari.layers import Points -from napari_cryoet_data_portal import ( +from napari_cryoet_data_portal._reader import ( + read_annotation, + read_annotation_files, read_points_annotations_ndjson, read_tomogram_ome_zarr, ) @@ -22,7 +27,7 @@ def test_read_tomogram_ome_zarr(): assert data[0].shape == (1000, 928, 960) assert data[1].shape == (500, 464, 480) assert data[2].shape == (250, 232, 240) - assert attrs["scale"] == (1, 1, 1) + np.testing.assert_allclose(attrs["scale"], (13.48, 13.48, 13.48), atol=0.01) assert layer_type == "image" @@ -44,3 +49,22 @@ def test_open_points_annotations(make_napari_viewer: Callable[[], Viewer]): assert len(layers) == 1 assert isinstance(layers[0], Points) + + +def test_read_annotation(annotation_with_points: Annotation): + with pytest.warns(DeprecationWarning): + data, attrs, layer_type = read_annotation(annotation_with_points) + + assert len(data) > 0 + assert len(attrs["name"]) > 0 + assert layer_type == "points" + + +def test_read_annotation_files(annotation_with_points: Annotation): + layers = list(read_annotation_files(annotation_with_points)) + + assert len(layers) == 1 + data, attrs, layer_type = layers[0] + assert len(data) > 0 + assert len(attrs["name"]) > 0 + assert layer_type == "points"