diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0d63d4d6..d7260e43 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,27 +12,69 @@ and this project adheres to [Semantic Versioning][].

 ### Added

+#### Major
+
+- Implemented support in SpatialData for storing multiple tables. These tables can annotate a SpatialElement, but do
+  not have to.
+- Added SQL-like joins that can be executed by calling the single public function `join_sdata_spatialelement_table`.
+  The following joins are supported: `left`, `left_exclusive`, `right`, `right_exclusive` and `inner`. The function
+  also has an option to match rows: for a `left` join only `left` matching of rows is supported, and for a `right`
+  join only `right` matching of rows is supported. Not all joins are supported for `Labels` elements.
+- Added function `match_element_to_table`, which allows the user to perform a right join of `SpatialElement`(s) with a
+  table, with the rows of the elements matching the row order of the table.
+- Increased control over in-memory vs on-disk state: changes performed in-memory (e.g. adding a new image) are not
+  automatically performed on-disk.
+
+#### Minor
+
+- Added public helper function `get_table_keys` in `spatialdata.models` to retrieve the annotation metadata of a given
+  table.
+- Added public helper function `check_target_region_column_symmetry` in `spatialdata.models` to check whether the
+  annotation metadata in `table.uns['spatialdata_attrs']` corresponds to the respective columns in `table.obs`.
+- Added function `validate_table_in_spatialdata` in `SpatialData` to validate that the annotation target of a table is
+  present in the `SpatialData` object.
+- Added function `get_annotated_regions` in `SpatialData` to get the regions annotated by a given table.
+- Added function `get_region_key_column` in `SpatialData` to get the `region_key` column in `table.obs`.
+- Added function `get_instance_key_column` in `SpatialData` to get the `instance_key` column in `table.obs`.
+- Added function `set_table_annotates_spatialelement` in `SpatialData` to set or change the annotation metadata of
+  a table in a given `SpatialData` object.
+- Added `table_name` parameter to the `aggregate` function to allow users to give a custom name to the table resulting
+  from the aggregation.
+- Added `table_name` parameter to the `get_values` function.
+- Added `tables` property in `SpatialData`.
+- Added `tables` setter in `SpatialData`.
+- Added `gen_spatial_elements` generator in `SpatialData` to generate the SpatialElements in a given `SpatialData` object.
+- Added `gen_elements` generator in `SpatialData` to generate the elements of a `SpatialData` object, including tables.
 - added SpatialData.subset() API
 - added SpatialData.locate_element() API
-- added transform_to_data_extent()
+- added utils function: transform_to_data_extent()
 - added utils function: are_extents_equal()
 - added utils function: postpone_transformation()
 - added utils function: remove_transformations_to_coordinate_system()
 - added utils function: get_centroids()
+- added utils function: deepcopy()
+- added operation: to_circles()
+- added testing utilities: assert_spatial_data_objects_are_identical(), assert_elements_are_identical(),
+  assert_elements_dict_are_identical()

-### Minor
+### Changed

-- improved usability and robustness of sdata.write() when overwrite=True @aeisenbarth
+#### Major
+
+- refactored data loader for deep learning
+
+#### Minor
+
+- Changed the string representation of SpatialData to reflect the support for multiple tables.
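For illustration, a minimal sketch of the multi-table API summarized in the entries above; the element and table names (`cells`, `measurements`) are hypothetical and not part of this changelog.

```python
# Minimal sketch of the multi-table API described above; names are hypothetical.
import numpy as np
from anndata import AnnData
from spatialdata import SpatialData, join_sdata_spatialelement_table
from spatialdata.models import ShapesModel, TableModel

# A shapes element with three circles, indexed 0..2.
circles = ShapesModel.parse(
    np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]), geometry=0, radius=np.array([1.0, 1.0, 1.0])
)

# A table annotating the shapes element via region/instance metadata.
table = AnnData(X=np.ones((3, 1)))
table.obs["region"] = "cells"
table.obs["instance_id"] = [0, 1, 2]
table = TableModel.parse(table, region="cells", region_key="region", instance_key="instance_id")

# Multiple named tables can now be stored side by side.
sdata = SpatialData(shapes={"cells": circles}, tables={"measurements": table})

# SQL-like join of the element with the table.
elements, joined_table = join_sdata_spatialelement_table(
    sdata, spatial_element_name="cells", table_name="measurements", how="inner"
)
```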
 ### Fixed

+#### Major
+
+- improved usability and robustness of sdata.write() when overwrite=True @aeisenbarth
 - generalized queries to any combination of 2D/3D data and 2D/3D query region #409
 - fixed warnings for categorical dtypes in tables in TableModel and PointsModel

-#### Minor
-
-- refactored data loader for deep learning
-
 ## [0.0.14] - 2023-10-11

 ### Added

diff --git a/docs/api.md b/docs/api.md
index 2388c9f6..c3e39478 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -28,10 +28,13 @@ Operations on `SpatialData` objects.

     get_values
     get_extent
     get_centroids
+    join_sdata_spatialelement_table
+    match_element_to_table
     match_table_to_element
     concatenate
     transform
     rasterize
+    to_circles
     aggregate
 ```

@@ -43,6 +46,7 @@ Operations on `SpatialData` objects.

     unpad_raster
     are_extents_equal
+    deepcopy
 ```

 ## Models

@@ -139,3 +143,16 @@ The transformations that can be defined between elements and coordinate systems

     save_transformations
     get_dask_backing_files
 ```
+
+## Testing utilities
+
+```{eval-rst}
+.. currentmodule:: spatialdata.testing
+
+.. autosummary::
+    :toctree: generated
+
+    assert_spatial_data_objects_are_identical
+    assert_elements_are_identical
+    assert_elements_dict_are_identical
+```

diff --git a/docs/conf.py b/docs/conf.py
index b394f0b0..8262f9d4 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -133,10 +133,11 @@
 html_title = project_name
 html_logo = "_static/img/spatialdata_horizontal.png"

-# html_theme_options = {
-#     "repository_url": repository_url,
-#     "use_repository_button": True,
-# }
+html_theme_options = {
+    "navigation_with_keys": True,
+    # "repository_url": repository_url,
+    # "use_repository_button": True,
+}

 pygments_style = "default"

diff --git a/pyproject.toml b/pyproject.toml
index a6d8de38..8f020ada 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,7 +41,8 @@ dependencies = [
     "xarray-spatial>=0.3.5",
     "tqdm",
     "fsspec<=2023.6",
-    "dask<=2024.2.1"
+    "dask<=2024.2.1",
+    "pooch",
 ]

 [project.optional-dependencies]
@@ -58,6 +59,7 @@ docs = [
     # For notebooks
     "ipython>=8.6.0",
     "sphinx-copybutton",
+    "sphinx-pytest",
 ]
 test = [
     "pytest",

diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py
index ced697ca..d899a6ec 100644
--- a/src/spatialdata/__init__.py
+++ b/src/spatialdata/__init__.py
@@ -17,11 +17,14 @@
     "dataloader",
     "concatenate",
     "rasterize",
+    "to_circles",
     "transform",
     "aggregate",
     "bounding_box_query",
     "polygon_query",
     "get_values",
+    "join_sdata_spatialelement_table",
+    "match_element_to_table",
     "match_table_to_element",
     "SpatialData",
     "get_extent",
@@ -31,17 +34,25 @@
     "save_transformations",
     "get_dask_backing_files",
     "are_extents_equal",
+    "deepcopy",
 ]

 from spatialdata import dataloader, models, transformations
+from spatialdata._core._deepcopy import deepcopy
 from spatialdata._core.centroids import get_centroids
 from spatialdata._core.concatenate import concatenate
 from spatialdata._core.data_extent import are_extents_equal, get_extent
 from spatialdata._core.operations.aggregate import aggregate
 from spatialdata._core.operations.rasterize import rasterize
 from spatialdata._core.operations.transform import transform
+from spatialdata._core.operations.vectorize import to_circles
 from spatialdata._core.query._utils import circles_to_polygons, get_bounding_box_corners
-from spatialdata._core.query.relational_query import get_values, match_table_to_element
+from spatialdata._core.query.relational_query import (
+    get_values,
+    join_sdata_spatialelement_table,
+    match_element_to_table,
+    match_table_to_element,
+)
 from
spatialdata._core.query.spatial_query import bounding_box_query, polygon_query
 from spatialdata._core.spatialdata import SpatialData
 from spatialdata._io._utils import get_dask_backing_files, save_transformations

diff --git a/src/spatialdata/_core/_deepcopy.py b/src/spatialdata/_core/_deepcopy.py
new file mode 100644
index 00000000..e3634df4
--- /dev/null
+++ b/src/spatialdata/_core/_deepcopy.py
@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+from copy import deepcopy as _deepcopy
+from functools import singledispatch
+
+from anndata import AnnData
+from dask.array.core import Array as DaskArray
+from dask.array.core import from_array
+from dask.dataframe.core import DataFrame as DaskDataFrame
+from geopandas import GeoDataFrame
+from multiscale_spatial_image import MultiscaleSpatialImage
+from spatial_image import SpatialImage
+
+from spatialdata._core.spatialdata import SpatialData
+from spatialdata._utils import multiscale_spatial_image_from_data_tree
+from spatialdata.models._utils import SpatialElement
+from spatialdata.models.models import Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, get_model
+
+
+@singledispatch
+def deepcopy(element: SpatialData | SpatialElement | AnnData) -> SpatialData | SpatialElement | AnnData:
+    """
+    Deepcopy a SpatialData or SpatialElement object.
+
+    Deepcopy will load the data in memory. Using this function for large Dask-backed objects is discouraged. In that
+    case, please save the SpatialData object to a different disk location and read it back again.
+
+    Parameters
+    ----------
+    element
+        The SpatialData or SpatialElement object to deepcopy.
+
+    Returns
+    -------
+    A deepcopy of the SpatialData or SpatialElement object.
+
+    Notes
+    -----
+    The order of the columns for a deepcopied points element may differ from the original one; see
+    https://github.com/scverse/spatialdata/issues/486
+    """
+    raise RuntimeError(f"Wrong type for deepcopy: {type(element)}")
+
+
+# In the implementations below, when the data is Dask-backed, we first call compute() and then deepcopy the data.
+# This copies the data twice, but since we expect the data to be small, this is acceptable.
+@deepcopy.register(SpatialData)
+def _(sdata: SpatialData) -> SpatialData:
+    elements_dict = {}
+    for _, element_name, element in sdata.gen_elements():
+        elements_dict[element_name] = deepcopy(element)
+    return SpatialData.from_elements_dict(elements_dict)
+
+
+@deepcopy.register(SpatialImage)
+def _(element: SpatialImage) -> SpatialImage:
+    model = get_model(element)
+    if isinstance(element.data, DaskArray):
+        element = element.compute()
+    if model in [Image2DModel, Image3DModel]:
+        return model.parse(element.copy(deep=True), c_coords=element["c"])  # type: ignore[call-arg]
+    assert model in [Labels2DModel, Labels3DModel]
+    return model.parse(element.copy(deep=True))
+
+
+@deepcopy.register(MultiscaleSpatialImage)
+def _(element: MultiscaleSpatialImage) -> MultiscaleSpatialImage:
+    # the complexity here is due to the fact that the parsers don't accept MultiscaleSpatialImage types and that we need
+    # to convert the DataTree to a MultiscaleSpatialImage.
This will be simplified once we support + # multiscale_spatial_image 1.0.0 + model = get_model(element) + for key in element: + ds = element[key].ds + assert len(ds) == 1 + variable = ds.__iter__().__next__() + if isinstance(element[key][variable].data, DaskArray): + element[key][variable] = element[key][variable].compute() + msi = multiscale_spatial_image_from_data_tree(element.copy(deep=True)) + for key in msi: + ds = msi[key].ds + variable = ds.__iter__().__next__() + msi[key][variable].data = from_array(msi[key][variable].data) + element[key][variable].data = from_array(element[key][variable].data) + assert model in [Image2DModel, Image3DModel, Labels2DModel, Labels3DModel] + model().validate(msi) + return msi + + +@deepcopy.register(GeoDataFrame) +def _(gdf: GeoDataFrame) -> GeoDataFrame: + new_gdf = _deepcopy(gdf) + # temporary fix for https://github.com/scverse/spatialdata/issues/286. + new_attrs = _deepcopy(gdf.attrs) + new_gdf.attrs = new_attrs + return new_gdf + + +@deepcopy.register(DaskDataFrame) +def _(df: DaskDataFrame) -> DaskDataFrame: + return PointsModel.parse(df.compute().copy(deep=True)) + + +@deepcopy.register(AnnData) +def _(adata: AnnData) -> AnnData: + return adata.copy() diff --git a/src/spatialdata/_core/_elements.py b/src/spatialdata/_core/_elements.py index f9abe4d9..862a380b 100644 --- a/src/spatialdata/_core/_elements.py +++ b/src/spatialdata/_core/_elements.py @@ -7,6 +7,7 @@ from typing import Any from warnings import warn +from anndata import AnnData from dask.dataframe.core import DataFrame as DaskDataFrame from datatree import DataTree from geopandas import GeoDataFrame @@ -20,6 +21,7 @@ Labels3DModel, PointsModel, ShapesModel, + TableModel, get_axes_names, get_model, ) @@ -103,3 +105,13 @@ def __setitem__(self, key: str, value: DaskDataFrame) -> None: raise TypeError(f"Unknown element type with schema: {schema!r}.") PointsModel().validate(value) super().__setitem__(key, value) + + +class Tables(Elements): + def __setitem__(self, key: str, value: AnnData) -> None: + self._check_key(key, self.keys(), self._shared_keys) + schema = get_model(value) + if schema != TableModel: + raise TypeError(f"Unknown element type with schema: {schema!r}.") + TableModel().validate(value) + super().__setitem__(key, value) diff --git a/src/spatialdata/_core/_utils.py b/src/spatialdata/_core/_utils.py new file mode 100644 index 00000000..1c22c802 --- /dev/null +++ b/src/spatialdata/_core/_utils.py @@ -0,0 +1,22 @@ +from spatialdata._core.spatialdata import SpatialData + + +def _find_common_table_keys(sdatas: list[SpatialData]) -> set[str]: + """ + Find table keys present in more than one SpatialData object. + + Parameters + ---------- + sdatas + A list of SpatialData objects. + + Returns + ------- + A set of common keys that are present in the tables of more than one SpatialData object. 
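A short usage sketch of the `deepcopy` semantics defined above; the element name is hypothetical, and note that Dask-backed data is loaded into memory, so this is intended for small objects.

```python
# Sketch: deep-copying a whole SpatialData object and a single element.
from spatialdata import deepcopy

sdata_copy = deepcopy(sdata)             # SpatialData: copies all elements, including tables
points_copy = deepcopy(sdata["points"])  # hypothetical element name; dispatches on the element type
# Column order of a copied points element may differ from the original
# (see https://github.com/scverse/spatialdata/issues/486).
```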
+ """ + common_keys = set(sdatas[0].tables.keys()) + + for sdata in sdatas[1:]: + common_keys.intersection_update(sdata.tables.keys()) + + return common_keys diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index eadc1d6b..8312d660 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -1,12 +1,15 @@ from __future__ import annotations +from collections import defaultdict from copy import copy # Should probably go up at the top from itertools import chain from typing import Any +from warnings import warn import numpy as np from anndata import AnnData +from spatialdata._core._utils import _find_common_table_keys from spatialdata._core.spatialdata import SpatialData from spatialdata.models import TableModel @@ -23,6 +26,8 @@ def _concatenate_tables( ) -> AnnData: import anndata as ad + if not all(TableModel.ATTRS_KEY in table.uns for table in tables): + raise ValueError("Not all tables are annotating a spatial element") region_keys = [table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] for table in tables] instance_keys = [table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] for table in tables] regions = [table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] for table in tables] @@ -71,6 +76,7 @@ def concatenate( sdatas: list[SpatialData], region_key: str | None = None, instance_key: str | None = None, + concatenate_tables: bool = False, **kwargs: Any, ) -> SpatialData: """ @@ -85,6 +91,8 @@ def concatenate( If all region_keys are the same, the `region_key` is used. instance_key The key to use for the instance column in the concatenated object. + concatenate_tables + Whether to merge the tables in case of having the same element name. kwargs See :func:`anndata.concat` for more details. @@ -108,16 +116,43 @@ def concatenate( assert isinstance(sdatas, list), "sdatas must be a list" assert len(sdatas) > 0, "sdatas must be a non-empty list" - merged_table = _concatenate_tables( - [sdata.table for sdata in sdatas if sdata.table is not None], region_key, instance_key, **kwargs - ) + if not concatenate_tables: + key_counts: dict[str, int] = defaultdict(int) + for sdata in sdatas: + for k in sdata.tables: + key_counts[k] += 1 + + if any(value > 1 for value in key_counts.values()): + warn( + "Duplicate table names found. Tables will be added with integer suffix. 
Set concatenate_tables to True"
+                " if concatenation is desired instead.",
+                UserWarning,
+                stacklevel=2,
+            )
+        merged_tables = {}
+        count_dict: dict[str, int] = defaultdict(int)
+
+        for sdata in sdatas:
+            for k, v in sdata.tables.items():
+                new_key = f"{k}_{count_dict[k]}" if key_counts[k] > 1 else k
+                count_dict[k] += 1
+                merged_tables[new_key] = v
+    else:
+        common_keys = _find_common_table_keys(sdatas)
+        merged_tables = {}
+        for sdata in sdatas:
+            for k, v in sdata.tables.items():
+                if k in common_keys and merged_tables.get(k) is not None:
+                    merged_tables[k] = _concatenate_tables([merged_tables[k], v], region_key, instance_key, **kwargs)
+                else:
+                    merged_tables[k] = v

     return SpatialData(
         images=merged_images,
         labels=merged_labels,
         points=merged_points,
         shapes=merged_shapes,
-        table=merged_table,
+        tables=merged_tables,
     )

diff --git a/src/spatialdata/_core/operations/_utils.py b/src/spatialdata/_core/operations/_utils.py
index d34d536f..a5fbd7be 100644
--- a/src/spatialdata/_core/operations/_utils.py
+++ b/src/spatialdata/_core/operations/_utils.py
@@ -98,7 +98,7 @@ def transform_to_data_extent(
         scale_to_target_d["z"] = target_depth / sizes[data_extent_axes.index("z")]
     scale_to_target = Scale([scale_to_target_d[ax] for ax in data_extent_axes], axes=data_extent_axes)

-    for el in sdata_vector_transformed._gen_elements_values():
+    for el in sdata_vector_transformed._gen_spatial_element_values():
         t = get_transformation(el, to_coordinate_system=coordinate_system)
         assert isinstance(t, BaseTransformation)
         sequence = Sequence([t, translation_to_origin, scale_to_target])
@@ -112,7 +112,7 @@
         **sdata_vector_transformed_inplace.points,
     }

-    for _, element_name, element in sdata_raster._gen_elements():
+    for _, element_name, element in sdata_raster.gen_spatial_elements():
         if isinstance(element, (MultiscaleSpatialImage, SpatialImage)):
             rasterized = rasterize(
                 element,
@@ -128,9 +128,9 @@
             sdata_to_return_elements[element_name] = rasterized
         else:
             sdata_to_return_elements[element_name] = element
-    if sdata.table is not None:
-        sdata_to_return_elements["table"] = sdata.table
     if not maintain_positioning:
         for el in sdata_to_return_elements.values():
             set_transformation(el, transformation={coordinate_system: Identity()}, set_all=True)
+    for k, v in sdata.tables.items():
+        sdata_to_return_elements[k] = v.copy()
     return SpatialData.from_elements_dict(sdata_to_return_elements)

diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py
index eaadc92c..e1a13b6b 100644
--- a/src/spatialdata/_core/operations/aggregate.py
+++ b/src/spatialdata/_core/operations/aggregate.py
@@ -64,6 +64,7 @@ def aggregate(
     region_key: str = "region",
     instance_key: str = "instance_id",
     deepcopy: bool = True,
+    table_name: str | None = None,
     **kwargs: Any,
 ) -> SpatialData:
     """
@@ -125,6 +126,8 @@ def aggregate(
     deepcopy
         Whether to deepcopy the shapes in the returned `SpatialData` object. If the shapes are large (e.g. large
         multiscale labels), you may consider disabling the deepcopy to use a lazy Dask representation.
+    table_name
+        The name of the table optionally containing the value_key column, and the name given to the table in the returned `SpatialData` object.
     kwargs
         Additional keyword arguments to pass to :func:`xrspatial.zonal_stats`.
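A sketch of the two table-handling modes of `concatenate` introduced above, assuming `sdata1` and `sdata2` both contain a table named `table`:

```python
from spatialdata import concatenate

# Default: duplicate table names are kept separate and suffixed ("table_0", "table_1", ...).
merged = concatenate([sdata1, sdata2])

# Opt-in: tables sharing a name are concatenated into a single table.
merged = concatenate([sdata1, sdata2], concatenate_tables=True)
```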
@@ -200,6 +203,7 @@ def aggregate(
             value_key=value_key,
             agg_func=agg_func,
             fractions=fractions,
+            table_name=table_name,
         )

     # eventually remove the colum of ones if it was added
@@ -214,10 +218,12 @@
     if adata is None:
         raise NotImplementedError(f"Cannot aggregate {values_type} by {by_type}")

+    table_name = table_name if table_name is not None else "table"
     # create a SpatialData object with the aggregated table and the "by" shapes
     shapes_name = by if isinstance(by, str) else "by"
     return _create_sdata_from_table_and_shapes(
         table=adata,
+        table_name=table_name,
         shapes_name=shapes_name,
         shapes=by_,
         region_key=region_key,
@@ -228,15 +234,23 @@
 def _create_sdata_from_table_and_shapes(
     table: ad.AnnData,
+    table_name: str,
     shapes: GeoDataFrame | SpatialImage | MultiscaleSpatialImage,
     shapes_name: str,
     region_key: str,
     instance_key: str,
     deepcopy: bool,
 ) -> SpatialData:
-    from spatialdata._utils import _deepcopy_geodataframe
-
-    table.obs[instance_key] = table.obs_names.copy()
+    from spatialdata._core._deepcopy import deepcopy as _deepcopy
+
+    shapes_index_dtype = shapes.index.dtype if isinstance(shapes, GeoDataFrame) else shapes.dtype
+    try:
+        table.obs[instance_key] = table.obs_names.copy().astype(shapes_index_dtype)
+    except ValueError as err:
+        raise TypeError(
+            f"Instance key column dtype in table resulting from aggregation cannot be cast to the dtype of "
+            f"element {shapes_name}.index"
+        ) from err
     table.obs[region_key] = shapes_name
     table = TableModel.parse(table, region=shapes_name, region_key=region_key, instance_key=instance_key)

@@ -245,9 +259,9 @@
         table.obs[instance_key] = table.obs[instance_key].astype(int)

     if deepcopy:
-        shapes = _deepcopy_geodataframe(shapes)
+        shapes = _deepcopy(shapes)

-    return SpatialData.from_elements_dict({shapes_name: shapes, "": table})
+    return SpatialData.from_elements_dict({shapes_name: shapes, table_name: table})


 def _aggregate_image_by_labels(
@@ -317,6 +331,7 @@ def _aggregate_shapes(
     by: gpd.GeoDataFrame,
     values_sdata: SpatialData | None = None,
     values_element_name: str | None = None,
+    table_name: str | None = None,
     value_key: str | list[str] | None = None,
     agg_func: str | list[str] = "count",
     fractions: bool = False,
@@ -343,13 +358,17 @@
         Column in value dataframe to perform aggregation on.
     agg_func
         Aggregation function to apply over grouped values. Passed to pandas.DataFrame.groupby.agg.
+    table_name
+        Name of the table optionally containing the value_key column.
""" from spatialdata.models import points_dask_dataframe_to_geopandas assert value_key is not None assert (values_sdata is None) == (values_element_name is None) if values_sdata is not None: - actual_values = get_values(value_key=value_key, sdata=values_sdata, element_name=values_element_name) + actual_values = get_values( + value_key=value_key, sdata=values_sdata, element_name=values_element_name, table_name=table_name + ) else: actual_values = get_values(value_key=value_key, element=values) assert isinstance(actual_values, pd.DataFrame), f"Expected pd.DataFrame, got {type(actual_values)}" diff --git a/src/spatialdata/_core/operations/rasterize.py b/src/spatialdata/_core/operations/rasterize.py index e682c88c..f37f9068 100644 --- a/src/spatialdata/_core/operations/rasterize.py +++ b/src/spatialdata/_core/operations/rasterize.py @@ -230,7 +230,7 @@ def _( ) new_name = f"{name}_rasterized_{element_type}" new_images[new_name] = rasterized - return SpatialData(images=new_images, table=sdata.table) + return SpatialData(images=new_images, tables=sdata.tables) # get xdata diff --git a/src/spatialdata/_core/operations/vectorize.py b/src/spatialdata/_core/operations/vectorize.py new file mode 100644 index 00000000..7f496b1c --- /dev/null +++ b/src/spatialdata/_core/operations/vectorize.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +from functools import singledispatch + +import numpy as np +import pandas as pd +from geopandas import GeoDataFrame +from multiscale_spatial_image import MultiscaleSpatialImage +from shapely import MultiPolygon, Point, Polygon +from spatial_image import SpatialImage + +from spatialdata._core.centroids import get_centroids +from spatialdata._core.operations.aggregate import aggregate +from spatialdata.models import ( + Image2DModel, + Image3DModel, + Labels3DModel, + ShapesModel, + SpatialElement, + get_axes_names, + get_model, +) +from spatialdata.transformations.operations import get_transformation +from spatialdata.transformations.transformations import Identity + +INTRINSIC_COORDINATE_SYSTEM = "__intrinsic" + + +@singledispatch +def to_circles( + data: SpatialElement, +) -> GeoDataFrame: + """ + Convert a set of geometries (2D/3D labels, 2D shapes) to approximated circles/spheres. + + Parameters + ---------- + data + The SpatialElement representing the geometries to approximate as circles/spheres. + + Returns + ------- + The approximated circles/spheres. + + Notes + ----- + The approximation is done by computing the centroids and the area/volume of the geometries. The geometries are then + replaced by circles/spheres with the same centroids and area/volume. 
+ """ + raise RuntimeError(f"Unsupported type: {type(data)}") + + +@to_circles.register(SpatialImage) +@to_circles.register(MultiscaleSpatialImage) +def _( + element: SpatialImage | MultiscaleSpatialImage, +) -> GeoDataFrame: + model = get_model(element) + if model in (Image2DModel, Image3DModel): + raise RuntimeError("Cannot apply to_circles() to images.") + if model == Labels3DModel: + raise RuntimeError("to_circles() is not yet implemented for 3D labels.") + + # reduce to the single scale case + if isinstance(element, MultiscaleSpatialImage): + element_single_scale = SpatialImage(element["scale0"].values().__iter__().__next__()) + else: + element_single_scale = element + shape = element_single_scale.shape + + # find the area of labels, estimate the radius from it; find the centroids + axes = get_axes_names(element) + model = Image3DModel if "z" in axes else Image2DModel + ones = model.parse(np.ones((1,) + shape), dims=("c",) + axes) + aggregated = aggregate(values=ones, by=element_single_scale, agg_func="sum")["table"] + areas = aggregated.X.todense().A1.reshape(-1) + aobs = aggregated.obs + aobs["areas"] = areas + aobs["radius"] = np.sqrt(areas / np.pi) + + # get the centroids; remove the background if present (the background is not considered during aggregation) + centroids = _get_centroids(element) + if 0 in centroids.index: + centroids = centroids.drop(index=0) + # instance_id is the key used by the aggregation APIs + aobs.index = aobs["instance_id"] + aobs.index.name = None + assert len(aobs) == len(centroids) + obs = pd.merge(aobs, centroids, left_index=True, right_index=True, how="inner") + assert len(obs) == len(centroids) + return _make_circles(element, obs) + + +@to_circles.register(GeoDataFrame) +def _( + element: GeoDataFrame, +) -> GeoDataFrame: + if isinstance(element.geometry.iloc[0], (Polygon, MultiPolygon)): + radius = np.sqrt(element.geometry.area / np.pi) + centroids = _get_centroids(element) + obs = pd.DataFrame({"radius": radius}) + obs = pd.merge(obs, centroids, left_index=True, right_index=True, how="inner") + return _make_circles(element, obs) + assert isinstance(element.geometry.iloc[0], Point), ( + f"Unsupported geometry type: " f"{type(element.geometry.iloc[0])}" + ) + return element + + +def _get_centroids(element: SpatialElement) -> pd.DataFrame: + d = get_transformation(element, get_all=True) + assert isinstance(d, dict) + if INTRINSIC_COORDINATE_SYSTEM in d: + raise RuntimeError(f"The name {INTRINSIC_COORDINATE_SYSTEM} is reserved.") + d[INTRINSIC_COORDINATE_SYSTEM] = Identity() + centroids = get_centroids(element, coordinate_system=INTRINSIC_COORDINATE_SYSTEM).compute() + del d[INTRINSIC_COORDINATE_SYSTEM] + return centroids + + +def _make_circles(element: SpatialImage | MultiscaleSpatialImage | GeoDataFrame, obs: pd.DataFrame) -> GeoDataFrame: + spatial_axes = sorted(get_axes_names(element)) + centroids = obs[spatial_axes].values + transformations = get_transformation(element, get_all=True) + assert isinstance(transformations, dict) + return ShapesModel.parse( + centroids, + geometry=0, + index=obs.index, + radius=obs["radius"].values, + transformations=transformations.copy(), + ) + + +# TODO: depending of the implementation, add a parameter to control the degree of approximation of the constructed +# polygons/multipolygons +@singledispatch +def to_polygons( + data: SpatialElement, + target_coordinate_system: str, +) -> GeoDataFrame: + """ + Convert a set of geometries (2D labels, 2D shapes) to approximated 2D polygons/multypolygons. 
+
+    Parameters
+    ----------
+    data
+        The SpatialElement representing the geometries to approximate as 2D polygons/multipolygons.
+    target_coordinate_system
+        The coordinate system to which the geometries should be transformed.
+
+    Returns
+    -------
+    The approximated 2D polygons/multipolygons in the specified coordinate system.
+    """
+    raise RuntimeError(f"Unsupported type: {type(data)}")

diff --git a/src/spatialdata/_core/query/_utils.py b/src/spatialdata/_core/query/_utils.py
index 25e8caa9..15fbe5c9 100644
--- a/src/spatialdata/_core/query/_utils.py
+++ b/src/spatialdata/_core/query/_utils.py
@@ -1,8 +1,13 @@
 from __future__ import annotations

+from typing import Any
+
 import geopandas as gpd
+from anndata import AnnData
 from xarray import DataArray

+from spatialdata._core._elements import Tables
+from spatialdata._core.spatialdata import SpatialData
 from spatialdata._types import ArrayLike
 from spatialdata._utils import Number, _parse_list_into_array

@@ -78,3 +83,39 @@ def get_bounding_box_corners(
         ],
         coords={"corner": range(8), "axis": list(axes)},
     )
+
+
+def _get_filtered_or_unfiltered_tables(
+    filter_table: bool, elements: dict[str, Any], sdata: SpatialData
+) -> dict[str, AnnData] | Tables:
+    """
+    Get the tables in a SpatialData object.
+
+    The tables of the SpatialData object are either filtered to only include the tables that annotate an element in
+    `elements`, or all tables are returned.
+
+    Parameters
+    ----------
+    filter_table
+        Whether to filter the tables to only include those annotating elements in the SpatialData object resulting
+        from the query.
+    elements
+        A dictionary containing the elements to use for filtering the tables.
+    sdata
+        The SpatialData object that contains the tables to filter.
+
+    Returns
+    -------
+    A dictionary containing the filtered or unfiltered tables, depending on the value of the `filter_table` parameter.
+
+    """
+    if filter_table:
+        from spatialdata._core.query.relational_query import _filter_table_by_elements

+        return {
+            name: filtered_table
+            for name, table in sdata.tables.items()
+            if (filtered_table := _filter_table_by_elements(table, elements)) and len(filtered_table) != 0
+        }
+
+    return sdata.tables

diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py
index d75897e7..46713e64 100644
--- a/src/spatialdata/_core/query/relational_query.py
+++ b/src/spatialdata/_core/query/relational_query.py
@@ -1,7 +1,12 @@
 from __future__ import annotations

+import math
+import warnings
+from collections import defaultdict
 from dataclasses import dataclass
-from typing import Any
+from enum import Enum
+from functools import partial
+from typing import Any, Literal

 import dask.array as da
 import numpy as np
@@ -21,10 +26,35 @@
     SpatialElement,
     TableModel,
     get_model,
+    get_table_keys,
 )


-def _filter_table_by_coordinate_system(table: AnnData | None, coordinate_system: str | list[str]) -> AnnData | None:
+def _get_element_annotators(sdata: SpatialData, element_name: str) -> set[str]:
+    """
+    Retrieve the names of the tables that annotate a SpatialElement in a SpatialData object.
+
+    Parameters
+    ----------
+    sdata
+        SpatialData object.
+    element_name
+        The name of the SpatialElement.
+
+    Returns
+    -------
+    The names of the tables annotating the SpatialElement.
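The helper above backs the `filter_table` flag of the spatial queries; a sketch of how it surfaces to users (the coordinate-system name `global` is an assumption):

```python
from spatialdata import bounding_box_query

# filter_table=True keeps only table rows annotating elements that survive the
# query; filter_table=False passes all tables through unchanged.
queried = bounding_box_query(
    sdata,
    axes=("x", "y"),
    min_coordinate=[0, 0],
    max_coordinate=[100, 100],
    target_coordinate_system="global",  # assumed coordinate system name
    filter_table=True,
)
```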
+ """ + table_names = set() + for name, table in sdata.tables.items(): + if table.uns.get(TableModel.ATTRS_KEY): + regions, _, _ = get_table_keys(table) + if element_name in regions: + table_names.add(name) + return table_names + + +def _filter_table_by_element_names(table: AnnData | None, element_names: str | list[str]) -> AnnData | None: """ Filter an AnnData table to keep only the rows that are in the coordinate system. @@ -32,23 +62,38 @@ def _filter_table_by_coordinate_system(table: AnnData | None, coordinate_system: ---------- table The table to filter; if None, returns None - coordinate_system - The coordinate system to keep + element_names + The element_names to keep in the tables obs.region column Returns ------- The filtered table, or None if the input table was None """ - if table is None: + if table is None or not table.uns.get(TableModel.ATTRS_KEY): return None table_mapping_metadata = table.uns[TableModel.ATTRS_KEY] region_key = table_mapping_metadata[TableModel.REGION_KEY_KEY] table.obs = pd.DataFrame(table.obs) - table = table[table.obs[region_key].isin(coordinate_system)].copy() + table = table[table.obs[region_key].isin(element_names)].copy() table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = table.obs[region_key].unique().tolist() return table +def _get_unique_label_values_as_index(element: SpatialElement) -> pd.Index: + if isinstance(element, SpatialImage): + # get unique labels value (including 0 if present) + instances = da.unique(element.data).compute() + else: + assert isinstance(element, MultiscaleSpatialImage) + v = element["scale0"].values() + assert len(v) == 1 + xdata = next(iter(v)) + # can be slow + instances = da.unique(xdata.data).compute() + return pd.Index(np.sort(instances)) + + +# TODO: replace function use throughout repo by `join_sdata_spatialelement_table` def _filter_table_by_elements( table: AnnData | None, elements_dict: dict[str, dict[str, Any]], match_rows: bool = False ) -> AnnData | None: @@ -129,7 +174,390 @@ def _filter_table_by_elements( return table -def match_table_to_element(sdata: SpatialData, element_name: str) -> AnnData: +def _get_joined_table_indices( + joined_indices: pd.Index | None, + element_indices: pd.RangeIndex, + table_instance_key_column: pd.Series, + match_rows: Literal["left", "no", "right"], +) -> pd.Index: + """ + Get indices of the table that are present in element_indices. + + Parameters + ---------- + joined_indices + Current indices that have been found to match indices of an element + element_indices + Element indices to match against table_instance_key_column. + table_instance_key_column + The column of a table containing the instance ids. + match_rows + Whether to match the indices of the element and table and if so how. If left, element_indices take priority and + if right table instance ids take priority. + + Returns + ------- + The indices that of the table that match the SpatialElement indices. 
+ """ + mask = table_instance_key_column.isin(element_indices) + if joined_indices is None: + if match_rows == "left": + joined_indices = _match_rows(table_instance_key_column, mask, element_indices, match_rows) + else: + joined_indices = table_instance_key_column[mask].index + else: + if match_rows == "left": + add_indices = _match_rows(table_instance_key_column, mask, element_indices, match_rows) + joined_indices = joined_indices.append(add_indices) + # in place append does not work with pd.Index + else: + joined_indices = joined_indices.append(table_instance_key_column[mask].index) + return joined_indices + + +def _get_masked_element( + element_indices: pd.RangeIndex, + element: SpatialElement, + table_instance_key_column: pd.Series, + match_rows: Literal["left", "no", "right"], +) -> SpatialElement: + """ + Get element rows matching the instance ids in the table_instance_key_column. + + Parameters + ---------- + element_indices + The indices of an element. + element + The spatial element to be masked. + table_instance_key_column + The column of a table containing the instance ids + match_rows + Whether to match the indices of the element and table and if so how. If left, element_indices take priority and + if right table instance ids take priority. + + Returns + ------- + The masked spatial element based on the provided indices and match rows. + """ + mask = table_instance_key_column.isin(element_indices) + masked_table_instance_key_column = table_instance_key_column[mask] + mask_values = mask_values if len(mask_values := masked_table_instance_key_column.values) != 0 else None + if match_rows == "right": + mask_values = _match_rows(table_instance_key_column, mask, element_indices, match_rows) + + return element.loc[mask_values, :] + + +def _right_exclusive_join_spatialelement_table( + element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] +) -> tuple[dict[str, Any], AnnData | None]: + regions, region_column_name, instance_key = get_table_keys(table) + groups_df = table.obs.groupby(by=region_column_name) + mask = [] + for element_type, name_element in element_dict.items(): + for name, element in name_element.items(): + if name in regions: + group_df = groups_df.get_group(name) + table_instance_key_column = group_df[instance_key] + if element_type in ["points", "shapes"]: + element_indices = element.index + else: + element_indices = _get_unique_label_values_as_index(element) + + element_dict[element_type][name] = None + submask = ~table_instance_key_column.isin(element_indices) + mask.append(submask) + else: + warnings.warn( + f"The element `{name}` is not annotated by the table. 
Skipping", UserWarning, stacklevel=2 + ) + element_dict[element_type][name] = None + continue + + if len(mask) != 0: + mask = pd.concat(mask) + exclusive_table = table[mask, :].copy() if mask.sum() != 0 else None # type: ignore[attr-defined] + else: + exclusive_table = None + + return element_dict, exclusive_table + + +def _right_join_spatialelement_table( + element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] +) -> tuple[dict[str, Any], AnnData]: + if match_rows == "left": + warnings.warn("Matching rows 'left' is not supported for 'right' join.", UserWarning, stacklevel=2) + regions, region_column_name, instance_key = get_table_keys(table) + groups_df = table.obs.groupby(by=region_column_name) + for element_type, name_element in element_dict.items(): + for name, element in name_element.items(): + if name in regions: + group_df = groups_df.get_group(name) + table_instance_key_column = group_df[instance_key] + if element_type in ["points", "shapes"]: + element_indices = element.index + else: + warnings.warn( + f"Element type `labels` not supported for 'right' join. Skipping `{name}`", + UserWarning, + stacklevel=2, + ) + continue + + masked_element = _get_masked_element(element_indices, element, table_instance_key_column, match_rows) + element_dict[element_type][name] = masked_element + else: + warnings.warn( + f"The element `{name}` is not annotated by the table. Skipping", UserWarning, stacklevel=2 + ) + continue + return element_dict, table + + +def _inner_join_spatialelement_table( + element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] +) -> tuple[dict[str, Any], AnnData]: + regions, region_column_name, instance_key = get_table_keys(table) + groups_df = table.obs.groupby(by=region_column_name) + joined_indices = None + for element_type, name_element in element_dict.items(): + for name, element in name_element.items(): + if name in regions: + group_df = groups_df.get_group(name) + table_instance_key_column = group_df[instance_key] # This is always a series + if element_type in ["points", "shapes"]: + element_indices = element.index + else: + warnings.warn( + f"Element type `labels` not supported for 'inner' join. Skipping `{name}`", + UserWarning, + stacklevel=2, + ) + continue + + masked_element = _get_masked_element(element_indices, element, table_instance_key_column, match_rows) + element_dict[element_type][name] = masked_element + + joined_indices = _get_joined_table_indices( + joined_indices, element_indices, table_instance_key_column, match_rows + ) + else: + warnings.warn( + f"The element `{name}` is not annotated by the table. 
Skipping", UserWarning, stacklevel=2 + ) + element_dict[element_type][name] = None + continue + + joined_table = table[joined_indices, :].copy() if joined_indices is not None else None + return element_dict, joined_table + + +def _left_exclusive_join_spatialelement_table( + element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] +) -> tuple[dict[str, Any], AnnData | None]: + regions, region_column_name, instance_key = get_table_keys(table) + groups_df = table.obs.groupby(by=region_column_name) + for element_type, name_element in element_dict.items(): + for name, element in name_element.items(): + if name in regions: + group_df = groups_df.get_group(name) + table_instance_key_column = group_df[instance_key] + if element_type in ["points", "shapes"]: + mask = np.full(len(element), True, dtype=bool) + mask[table_instance_key_column.values] = False + masked_element = element.loc[mask, :] if mask.sum() != 0 else None + element_dict[element_type][name] = masked_element + else: + warnings.warn( + f"Element type `labels` not supported for left exclusive join. Skipping `{name}`", + UserWarning, + stacklevel=2, + ) + continue + else: + warnings.warn( + f"The element `{name}` is not annotated by the table. Skipping", UserWarning, stacklevel=2 + ) + continue + + return element_dict, None + + +def _left_join_spatialelement_table( + element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] +) -> tuple[dict[str, Any], AnnData]: + if match_rows == "right": + warnings.warn("Matching rows 'right' is not supported for 'left' join.", UserWarning, stacklevel=2) + regions, region_column_name, instance_key = get_table_keys(table) + groups_df = table.obs.groupby(by=region_column_name) + joined_indices = None + for element_type, name_element in element_dict.items(): + for name, element in name_element.items(): + if name in regions: + group_df = groups_df.get_group(name) + table_instance_key_column = group_df[instance_key] # This is always a series + if element_type in ["points", "shapes"]: + element_indices = element.index + else: + element_indices = _get_unique_label_values_as_index(element) + + joined_indices = _get_joined_table_indices( + joined_indices, element_indices, table_instance_key_column, match_rows + ) + else: + warnings.warn( + f"The element `{name}` is not annotated by the table. 
Skipping", UserWarning, stacklevel=2 + ) + continue + + joined_indices = joined_indices.dropna() if joined_indices is not None else None + joined_table = table[joined_indices, :].copy() if joined_indices is not None else None + + return element_dict, joined_table + + +def _match_rows( + table_instance_key_column: pd.Series, + mask: pd.Series, + element_indices: pd.RangeIndex, + match_rows: str, +) -> pd.Index: + instance_id_df = pd.DataFrame( + {"instance_id": table_instance_key_column[mask].values, "index_right": table_instance_key_column[mask].index} + ) + element_index_df = pd.DataFrame({"index_left": element_indices}) + index_col = "index_left" if match_rows == "right" else "index_right" + + merged_df = pd.merge( + element_index_df, instance_id_df, left_on="index_left", right_on="instance_id", how=match_rows + )[index_col] + + # With labels it can be that index 0 is NaN + if isinstance(merged_df.iloc[0], float) and math.isnan(merged_df.iloc[0]): + merged_df = merged_df.iloc[1:] + + return pd.Index(merged_df) + + +class JoinTypes(Enum): + """Available join types for matching elements to tables and vice versa.""" + + left = partial(_left_join_spatialelement_table) + left_exclusive = partial(_left_exclusive_join_spatialelement_table) + inner = partial(_inner_join_spatialelement_table) + right = partial(_right_join_spatialelement_table) + right_exclusive = partial(_right_exclusive_join_spatialelement_table) + + def __call__(self, *args: Any) -> tuple[dict[str, Any], AnnData]: + return self.value(*args) + + +class MatchTypes(Enum): + """Available match types for matching rows of elements and tables.""" + + left = "left" + right = "right" + no = "no" + + +def join_sdata_spatialelement_table( + sdata: SpatialData, + spatial_element_name: str | list[str], + table_name: str, + how: str = "left", + match_rows: Literal["no", "left", "right"] = "no", +) -> tuple[dict[str, Any], AnnData]: + """Join SpatialElement(s) and table together in SQL like manner. + + The function allows the user to perform SQL like joins of SpatialElements and a table. The elements are not + returned together in one dataframe like structure, but instead filtered elements are returned. To determine matches, + for the SpatialElement the index is used and for the table the region key column and instance key column. The + elements are not overwritten in the `SpatialData` object. + + The following joins are supported: ``'left'``, ``'left_exclusive'``, ``'inner'``, ``'right'`` and + ``'right_exclusive'``. In case of a ``'left'`` join the SpatialElements are returned in a dictionary as is + while the table is filtered to only include matching rows. In case of ``'left_exclusive'`` join None is returned + for table while the SpatialElements returned are filtered to only include indices not present in the table. The + cases for ``'right'`` joins are symmetric to the ``'left'`` joins. In case of an ``'inner'`` join of + SpatialElement(s) and a table, for each an element is returned only containing the rows that are present in + both the SpatialElement and table. + + For Points and Shapes elements every valid join for argument how is supported. For Labels elements only + the ``'left'`` and ``'right_exclusive'`` joins are supported. + + Parameters + ---------- + sdata + The SpatialData object containing the tables and spatial elements. + spatial_element_name + The name(s) of the spatial elements to be joined with the table. + table_name + The name of the table to join with the spatial elements. 
+    how
+        The type of SQL-like join to perform; the default is ``'left'``. Options are ``'left'``,
+        ``'left_exclusive'``, ``'inner'``, ``'right'`` and ``'right_exclusive'``.
+    match_rows
+        Whether to match the rows of the element and the table, and if so, how. If ``'left'``, the element indices
+        take priority; if ``'right'``, the table instance ids take priority.
+
+    Returns
+    -------
+    A tuple containing the joined elements as a dictionary and the joined table as an AnnData object.
+
+    Raises
+    ------
+    AssertionError
+        If no table with the given table_name exists in the SpatialData object.
+    ValueError
+        If the provided join type is not supported.
+    """
+    assert sdata.tables.get(table_name) is not None, f"No table with `{table_name}` exists in the SpatialData object."
+    table = sdata.tables[table_name]
+    if isinstance(spatial_element_name, str):
+        spatial_element_name = [spatial_element_name]
+
+    elements_dict: dict[str, dict[str, Any]] = defaultdict(lambda: defaultdict(dict))
+    for name in spatial_element_name:
+        if name in sdata.tables:
+            warnings.warn(
+                f"Table: `{name}` given in spatial_element_name cannot be joined with a table using this function.",
+                UserWarning,
+                stacklevel=2,
+            )
+        elif name in sdata.images:
+            warnings.warn(
+                f"Image: `{name}` cannot be joined with a table",
+                UserWarning,
+                stacklevel=2,
+            )
+        else:
+            element_type, _, element = sdata._find_element(name)
+            elements_dict[element_type][name] = element
+
+    assert any(key in elements_dict for key in ["labels", "shapes", "points"]), (
+        "No valid element to join in spatial_element_name. Must provide at least one of either `labels`, `points` or "
+        "`shapes`."
+    )
+
+    if match_rows not in MatchTypes.__dict__["_member_names_"]:
+        raise TypeError(
+            f"`{match_rows}` is an invalid argument for `match_rows`. Can be either 'no', 'left' or 'right'."
+        )
+    if how in JoinTypes.__dict__["_member_names_"]:
+        elements_dict, table = JoinTypes[how](elements_dict, table, match_rows)
+    else:
+        raise TypeError(f"`{how}` is not a valid type of join.")
+
+    elements_dict = {
+        name: element for outer_key, dict_val in elements_dict.items() for name, element in dict_val.items()
+    }
+    return elements_dict, table
+
+
+def match_table_to_element(sdata: SpatialData, element_name: str, table_name: str = "table") -> AnnData:
     """
-    Filter the table and reorders the rows to match the instances (rows/labels) of the specified SpatialElement.
+    Filter the table and reorder the rows to match the instances (rows/labels) of the specified SpatialElement.
@@ -138,17 +566,53 @@
     sdata
         SpatialData object
     element_name
-        Name of the element to match the table to
+        The name of the spatial element to be joined with the table.
+    table_name
+        The name of the table to match to the element.

     Returns
     -------
     Table with the rows matching the instances of the element
     """
+    # TODO: refactor this to make use of the new join_sdata_spatialelement_table function.
+    # if table_name is None:
+    #     warnings.warn(
+    #         "Assumption of table with name `table` being present is being deprecated in SpatialData v0.1. 
" + # "Please provide the name of the table as argument to table_name.", + # DeprecationWarning, + # stacklevel=2, + # ) + # table_name = "table" + # _, table = join_sdata_spatialelement_table(sdata, element_name, table_name, "left", match_rows="left") + # return table + assert sdata[table_name] is not None, "No table found in the SpatialData" element_type, _, element = sdata._find_element(element_name) assert element_type in ["labels", "shapes"], f"Element {element_name} ({element_type}) is not supported" elements_dict = {element_type: {element_name: element}} - return _filter_table_by_elements(sdata.table, elements_dict, match_rows=True) + return _filter_table_by_elements(sdata[table_name], elements_dict, match_rows=True) + + +def match_element_to_table( + sdata: SpatialData, element_name: str | list[str], table_name: str +) -> tuple[dict[str, Any], AnnData]: + """ + Filter the elements and make the indices match those in the table. + + Parameters + ---------- + sdata + SpatialData object + element_name + The name(s) of the spatial elements to be joined with the table. Not supported for Label elements. + table_name + The name of the table to join with the spatial elements. + + Returns + ------- + A tuple containing the joined elements as a dictionary and the joined table as an AnnData object. + """ + element_dict, table = join_sdata_spatialelement_table(sdata, element_name, table_name, "right", match_rows="right") + return element_dict, table @dataclass @@ -173,6 +637,7 @@ def _locate_value( element: SpatialElement | None = None, sdata: SpatialData | None = None, element_name: str | None = None, + table_name: str | None = None, ) -> list[_ValueOrigin]: el = _get_element(element=element, sdata=sdata, element_name=element_name) origins = [] @@ -187,7 +652,7 @@ def _locate_value( # adding from the obs columns or var if model in [ShapesModel, Labels2DModel, Labels3DModel] and sdata is not None: - table = sdata.table + table = sdata.tables.get(table_name) if table_name is not None else None if table is not None: # check if the table is annotating the element region = table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] @@ -208,6 +673,7 @@ def get_values( element: SpatialElement | None = None, sdata: SpatialData | None = None, element_name: str | None = None, + table_name: str | None = None, ) -> pd.DataFrame: """ Get the values from the element, from any location: df columns, obs or var columns (table). @@ -222,6 +688,8 @@ def get_values( SpatialData object; either element or (sdata, element_name) must be provided element_name Name of the element; either element or (sdata, element_name) must be provided + table_name + Name of the table to get the values from. 
Returns ------- @@ -236,7 +704,9 @@ def get_values( value_keys = [value_key] if isinstance(value_key, str) else value_key locations = [] for vk in value_keys: - origins = _locate_value(value_key=vk, element=element, sdata=sdata, element_name=element_name) + origins = _locate_value( + value_key=vk, element=element, sdata=sdata, element_name=element_name, table_name=table_name + ) if len(origins) > 1: raise ValueError( f"{vk} has been found in multiple locations of (element, sdata, element_name) = " @@ -266,9 +736,9 @@ def get_values( if isinstance(el, DaskDataFrame): df = df.compute() return df - if sdata is not None: + if sdata is not None and table_name is not None: assert element_name is not None - matched_table = match_table_to_element(sdata=sdata, element_name=element_name) + matched_table = match_table_to_element(sdata=sdata, element_name=element_name, table_name=table_name) region_key = matched_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] instance_key = matched_table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] obs = matched_table.obs diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 017b4842..b6874193 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -17,7 +17,11 @@ from spatial_image import SpatialImage from xarray import DataArray -from spatialdata._core.query._utils import circles_to_polygons, get_bounding_box_corners +from spatialdata._core.query._utils import ( + _get_filtered_or_unfiltered_tables, + circles_to_polygons, + get_bounding_box_corners, +) from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata._utils import Number, _parse_list_into_array @@ -442,8 +446,6 @@ def _( target_coordinate_system: str, filter_table: bool = True, ) -> SpatialData: - from spatialdata._core.query.relational_query import _filter_table_by_elements - min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) new_elements = {} @@ -459,14 +461,9 @@ def _( ) new_elements[element_type] = queried_elements - if sdata.table is not None: - table = _filter_table_by_elements(sdata.table, new_elements) if filter_table else sdata.table - if len(table) == 0: - table = None - else: - table = None + tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata) - return SpatialData(**new_elements, table=table) + return SpatialData(**new_elements, tables=tables) @bounding_box_query.register(SpatialImage) @@ -566,6 +563,8 @@ def _( if 0 in query_result.shape: return None assert isinstance(query_result, SpatialImage) + # rechunk the data to avoid irregular chunks + image = image.chunk("auto") else: assert isinstance(image, MultiscaleSpatialImage) assert isinstance(query_result, DataTree) @@ -582,6 +581,9 @@ def _( else: d[k] = xdata query_result = MultiscaleSpatialImage.from_dict(d) + # rechunk the data to avoid irregular chunks + for scale in query_result: + query_result[scale]["image"] = query_result[scale]["image"].chunk("auto") query_result = compute_coordinates(query_result) # the bounding box, mapped back to the intrinsic coordinate system is a set of points. The bounding box of these @@ -763,8 +765,8 @@ def polygon_query( target_coordinate_system The coordinate system of the polygon/multipolygon. filter_table - If `True`, the table is filtered to only contain rows that are annotating regions - contained within the query polygon/multipolygon. 
+ Specifies whether to filter the tables to only include tables that annotate elements in the retrieved + SpatialData object of the query. shapes [Deprecated] This argument is now ignored and will be removed. Please filter the SpatialData object before calling this function. @@ -803,7 +805,6 @@ def _( images: bool = True, labels: bool = True, ) -> SpatialData: - from spatialdata._core.query.relational_query import _filter_table_by_elements _check_deprecated_kwargs({"shapes": shapes, "points": points, "images": images, "labels": labels}) new_elements = {} @@ -817,14 +818,9 @@ def _( ) new_elements[element_type] = queried_elements - if sdata.table is not None: - table = _filter_table_by_elements(sdata.table, new_elements) if filter_table else sdata.table - if len(table) == 0: - table = None - else: - table = None + tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata) - return SpatialData(**new_elements, table=table) + return SpatialData(**new_elements, tables=tables) @polygon_query.register(SpatialImage) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 4a9acb83..f636afd8 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -2,11 +2,13 @@ import hashlib import os +import warnings from collections.abc import Generator from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal +import pandas as pd import zarr from anndata import AnnData from dask.dataframe import read_parquet @@ -19,10 +21,10 @@ from shapely import MultiPolygon, Polygon from spatial_image import SpatialImage -from spatialdata._core._elements import Images, Labels, Points, Shapes +from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables from spatialdata._logging import logger from spatialdata._types import ArrayLike, Raster_T -from spatialdata._utils import _error_message_add_element +from spatialdata._utils import _error_message_add_element, deprecation_alias from spatialdata.models import ( Image2DModel, Image3DModel, @@ -31,7 +33,9 @@ PointsModel, ShapesModel, TableModel, + check_target_region_column_symmetry, get_model, + get_table_keys, ) from spatialdata.models._utils import SpatialElement, get_axes_names @@ -106,13 +110,14 @@ class SpatialData: """ + @deprecation_alias(table="tables") def __init__( self, images: dict[str, Raster_T] | None = None, labels: dict[str, Raster_T] | None = None, points: dict[str, DaskDataFrame] | None = None, shapes: dict[str, GeoDataFrame] | None = None, - table: AnnData | None = None, + tables: dict[str, AnnData] | Tables | None = None, ) -> None: self._path: Path | None = None @@ -121,7 +126,11 @@ def __init__( self._labels: Labels = Labels(shared_keys=self._shared_keys) self._points: Points = Points(shared_keys=self._shared_keys) self._shapes: Shapes = Shapes(shared_keys=self._shared_keys) - self._table: AnnData | None = None + self._tables: Tables = Tables(shared_keys=self._shared_keys) + + # Workaround to allow for backward compatibility + if isinstance(tables, AnnData): + tables = {"table": tables} self._validate_unique_element_names( list(chain.from_iterable([e.keys() for e in [images, labels, points, shapes] if e is not None])) @@ -143,12 +152,72 @@ def __init__( for k, v in points.items(): self.points[k] = v - if table is not None: - Table_s.validate(table) - self._table = table + if tables is not None: + for k, v in tables.items(): + self.validate_table_in_spatialdata(v) 
+                self.tables[k] = v

         self._query = QueryManager(self)

+    def validate_table_in_spatialdata(self, table: AnnData) -> None:
+        """
+        Validate the presence of the annotation target of a SpatialData table in the SpatialData object.
+
+        This method validates a table in the SpatialData object to ensure that, if annotation metadata is present,
+        the annotation target (SpatialElement) is present in the SpatialData object and the dtype of the instance key
+        column in the table matches the dtype of the indices of the annotation target. Otherwise, a warning is raised.
+
+        Parameters
+        ----------
+        table
+            The table potentially annotating a SpatialElement.
+
+        Raises
+        ------
+        UserWarning
+            If the table is annotating elements not present in the SpatialData object.
+        UserWarning
+            If the dtypes of the instance key column in the table and the annotation target do not match.
+        """
+        TableModel().validate(table)
+        if TableModel.ATTRS_KEY in table.uns:
+            region, _, instance_key = get_table_keys(table)
+            region = region if isinstance(region, list) else [region]
+            for r in region:
+                element = self.get(r)
+                if element is None:
+                    warnings.warn(
+                        f"The table is annotating {r!r}, which is not present in the SpatialData object.",
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                else:
+                    if isinstance(element, SpatialImage):
+                        dtype = element.dtype
+                    elif isinstance(element, MultiscaleSpatialImage):
+                        dtype = element.scale0.ds.dtypes["image"]
+                    else:
+                        dtype = element.index.dtype
+                    if dtype != table.obs[instance_key].dtype:
+                        if dtype == str or table.obs[instance_key].dtype == str:
+                            raise TypeError(
+                                f"Table instance_key column ({instance_key}) has a dtype "
+                                f"({table.obs[instance_key].dtype}) that does not match the dtype of the indices of "
+                                f"the annotated element ({dtype})."
+                            )
+
+                        warnings.warn(
+                            (
+                                f"Table instance_key column ({instance_key}) has a dtype "
+                                f"({table.obs[instance_key].dtype}) that does not match the dtype of the indices of "
+                                f"the annotated element ({dtype}). Please note that cases like int16 vs int32 may be "
+                                "tolerated in downstream methods, but it is recommended to make the dtypes match."
+                            ),
+                            UserWarning,
+                            stacklevel=2,
+                        )
+
     @staticmethod
     def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> SpatialData:
         """
@@ -169,7 +238,7 @@ def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> Sp
             "labels": {},
             "points": {},
             "shapes": {},
-            "table": None,
+            "tables": {},
         }
         for k, e in elements_dict.items():
             schema = get_model(e)
@@ -186,13 +255,200 @@ def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> Sp
                 assert isinstance(d["shapes"], dict)
                 d["shapes"][k] = e
             elif schema == TableModel:
-                if d["table"] is not None:
-                    raise ValueError("Only one table can be present in the dataset.")
-                d["table"] = e
+                assert isinstance(d["tables"], dict)
+                d["tables"][k] = e
             else:
                 raise ValueError(f"Unknown schema {schema}")
         return SpatialData(**d)  # type: ignore[arg-type]
+
+    @staticmethod
+    def get_annotated_regions(table: AnnData) -> str | list[str]:
+        """
+        Get the regions annotated by a table.
+
+        Parameters
+        ----------
+        table
+            The AnnData table for which to retrieve annotated regions.
+
+        Returns
+        -------
+        The annotated regions.
+        """
+        regions, _, _ = get_table_keys(table)
+        return regions
+
+    @staticmethod
+    def get_region_key_column(table: AnnData) -> pd.Series:
+        """Get the column of table.obs containing, for each row, the region annotated by that row.
+
+        Parameters
+        ----------
+        table
+            The AnnData table.
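The tables plumbing above can be exercised end to end; below is a minimal sketch (the table content and names are illustrative only) showing the new `tables` argument and the backward-compatible `table` keyword handled by `deprecation_alias`:

```python
import numpy as np
from anndata import AnnData
from spatialdata import SpatialData

adata = AnnData(X=np.zeros((5, 3)))  # hypothetical orphan table, no annotation metadata

sdata = SpatialData(tables={"table": adata})  # new style: a dict of named tables

# old-style keyword still works, but emits a DeprecationWarning and is mapped to `tables`
sdata_old = SpatialData(table=adata)
assert "table" in sdata_old.tables
```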
+
+        Returns
+        -------
+        The region key column.
+
+        Raises
+        ------
+        KeyError
+            If the region key column is not found in table.obs.
+        """
+        _, region_key, _ = get_table_keys(table)
+        if table.obs.get(region_key) is not None:
+            return table.obs[region_key]
+        raise KeyError(f"{region_key} is set as region key column. However, the column is not found in table.obs.")
+
+    @staticmethod
+    def get_instance_key_column(table: AnnData) -> pd.Series:
+        """
+        Return the instance key column in table.obs containing, for each row, the instance id of that row.
+
+        Parameters
+        ----------
+        table
+            The AnnData table.
+
+        Returns
+        -------
+        The instance key column.
+
+        Raises
+        ------
+        KeyError
+            If the instance key column is not found in table.obs.
+        """
+        _, _, instance_key = get_table_keys(table)
+        if table.obs.get(instance_key) is not None:
+            return table.obs[instance_key]
+        raise KeyError(f"{instance_key} is set as instance key column. However, the column is not found in table.obs.")
+
+    @staticmethod
+    def _set_table_annotation_target(
+        table: AnnData,
+        region: str | pd.Series,
+        region_key: str,
+        instance_key: str,
+    ) -> None:
+        """
+        Set the SpatialElement annotation target of an AnnData table.
+
+        This method sets the target annotation element of a table based on the specified parameters. It creates the
+        `attrs` dictionary for `table.uns` and updates the annotation metadata of the table only after validating
+        that the regions are present in the region_key column of table.obs.
+
+        Parameters
+        ----------
+        table
+            The AnnData object containing the data table.
+        region
+            The name of the target element for the table annotation.
+        region_key
+            The key for the region annotation column in `table.obs`.
+        instance_key
+            The key for the instance annotation column in `table.obs`.
+
+        Raises
+        ------
+        ValueError
+            If `region_key` is not present in the `table.obs` columns.
+        ValueError
+            If `instance_key` is not present in the `table.obs` columns.
+        """
+        TableModel()._validate_set_region_key(table, region_key)
+        TableModel()._validate_set_instance_key(table, instance_key)
+        attrs = {
+            TableModel.REGION_KEY: region,
+            TableModel.REGION_KEY_KEY: region_key,
+            TableModel.INSTANCE_KEY: instance_key,
+        }
+        check_target_region_column_symmetry(table, region_key, region)
+        table.uns[TableModel.ATTRS_KEY] = attrs
+
+    @staticmethod
+    def _change_table_annotation_target(
+        table: AnnData,
+        region: str | pd.Series,
+        region_key: None | str = None,
+        instance_key: None | str = None,
+    ) -> None:
+        """Change the annotation target of a table that already has annotation metadata.
+
+        Parameters
+        ----------
+        table
+            The table already annotating a SpatialElement.
+        region
+            The name of the target SpatialElement for which the table annotation will be changed.
+        region_key
+            The name of the region key column in the table. If not provided, it will be extracted from the table's
+            uns attribute. If provided, it overwrites the value stored in the table's uns attribute.
+        instance_key
+            The name of the instance key column in the table. If not provided, it will be extracted from the table's
+            uns attribute. If provided, it overwrites the value stored in the table's uns attribute.
+
+        Raises
+        ------
+        ValueError
+            If no region_key is provided and a valid one is not already present in table.uns['spatialdata_attrs'] and
+            table.obs.
+        ValueError
+            If the provided region_key is not present in table.obs.
+ """ + attrs = table.uns[TableModel.ATTRS_KEY] + table_region_key = region_key if region_key else attrs.get(TableModel.REGION_KEY_KEY) + + TableModel()._validate_set_region_key(table, region_key) + TableModel()._validate_set_instance_key(table, instance_key) + check_target_region_column_symmetry(table, table_region_key, region) + attrs[TableModel.REGION_KEY] = region + + def set_table_annotates_spatialelement( + self, + table_name: str, + region: str | pd.Series, + region_key: None | str = None, + instance_key: None | str = None, + ) -> None: + """ + Set the SpatialElement annotation target of a given AnnData table. + + Parameters + ---------- + table_name + The name of the table to set the annotation target for. + region + The name of the target element for the annotation. This can either be a string or a pandas Series object. + region_key + The region key for the annotation. If not specified, defaults to None which means the currently set region + key is reused. + instance_key + The instance key for the annotation. If not specified, defaults to None which means the currently set + instance key is reused. + + Raises + ------ + ValueError + If the annotation SpatialElement target is not present in the SpatialData object. + TypeError + If no current annotation metadata is found and both region_key and instance_key are not specified. + """ + table = self.tables[table_name] + element_names = {element[1] for element in self._gen_elements()} + if region not in element_names: + raise ValueError(f"Annotation target '{region}' not present as SpatialElement in SpatialData object.") + + if table.uns.get(TableModel.ATTRS_KEY): + self._change_table_annotation_target(table, region, region_key, instance_key) + elif isinstance(region_key, str) and isinstance(instance_key, str): + self._set_table_annotation_target(table, region, region_key, instance_key) + else: + raise TypeError("No current annotation metadata found. Please specify both region_key and instance_key.") + @property def query(self) -> QueryManager: return self._query @@ -210,6 +466,7 @@ def aggregate( region_key: str = "region", instance_key: str = "instance_id", deepcopy: bool = True, + table_name: str = "table", **kwargs: Any, ) -> SpatialData: """ @@ -242,6 +499,7 @@ def aggregate( region_key=region_key, instance_key=instance_key, deepcopy=deepcopy, + table_name=table_name, **kwargs, ) @@ -380,7 +638,10 @@ def _write_transformations_to_disk(self, element: SpatialElement) -> None: else: raise ValueError("Unknown element type") - def filter_by_coordinate_system(self, coordinate_system: str | list[str], filter_table: bool = True) -> SpatialData: + @deprecation_alias(filter_table="filter_tables") + def filter_by_coordinate_system( + self, coordinate_system: str | list[str], filter_tables: bool = True, include_orphan_tables: bool = False + ) -> SpatialData: """ Filter the SpatialData by one (or a list of) coordinate system. @@ -391,37 +652,104 @@ def filter_by_coordinate_system(self, coordinate_system: str | list[str], filter ---------- coordinate_system The coordinate system(s) to filter by. - filter_table - If True (default), the table will be filtered to only contain regions + filter_tables + If True (default), the tables will be filtered to only contain regions of an element belonging to the specified coordinate system(s). + include_orphan_tables + If True (not default), include tables that do not annotate SpatialElement(s). Only has an effect if + filter_tables is also set to True. Returns ------- The filtered SpatialData. 
""" - from spatialdata._core.query.relational_query import _filter_table_by_coordinate_system + # TODO: decide whether to add parameter to filter only specific table. + from spatialdata.transformations.operations import get_transformation elements: dict[str, dict[str, SpatialElement]] = {} - element_paths_in_coordinate_system = [] + element_names_in_coordinate_system = [] if isinstance(coordinate_system, str): coordinate_system = [coordinate_system] for element_type, element_name, element in self._gen_elements(): - transformations = get_transformation(element, get_all=True) - assert isinstance(transformations, dict) - for cs in coordinate_system: - if cs in transformations: - if element_type not in elements: - elements[element_type] = {} - elements[element_type][element_name] = element - element_paths_in_coordinate_system.append(element_name) - - if filter_table: - table = _filter_table_by_coordinate_system(self.table, element_paths_in_coordinate_system) + if element_type != "tables": + transformations = get_transformation(element, get_all=True) + assert isinstance(transformations, dict) + for cs in coordinate_system: + if cs in transformations: + if element_type not in elements: + elements[element_type] = {} + elements[element_type][element_name] = element + element_names_in_coordinate_system.append(element_name) + tables = self._filter_tables( + set(), filter_tables, "cs", include_orphan_tables, element_names=element_names_in_coordinate_system + ) + + return SpatialData(**elements, tables=tables) + + # TODO: move to relational query with refactor + def _filter_tables( + self, + names_tables_to_keep: set[str], + filter_tables: bool = True, + by: Literal["cs", "elements"] | None = None, + include_orphan_tables: bool = False, + element_names: str | list[str] | None = None, + elements_dict: dict[str, dict[str, Any]] | None = None, + ) -> Tables | dict[str, AnnData]: + """ + Filter tables by coordinate system or elements or return tables. + + Parameters + ---------- + names_tables_to_keep + The names of the tables to keep even when filter_tables is True. + filter_tables + If True (default), the tables will be filtered to only contain regions + of an element belonging to the specified coordinate system(s) or including only rows annotating specified + elements. + by + Filter mode. Valid values are "cs" or "elements". Default is None. + include_orphan_tables + Flag indicating whether to include orphan tables. Default is False. + element_names + Element names of elements present in specific coordinate system. + elements_dict + Dictionary of elements for filtering the tables. Default is None. + + Returns + ------- + The filtered tables if filter_tables was True, otherwise tables of the SpatialData object. + + """ + if filter_tables: + tables: dict[str, AnnData] | Tables = {} + for table_name, table in self._tables.items(): + if include_orphan_tables and not table.uns.get(TableModel.ATTRS_KEY): + tables[table_name] = table + continue + if table_name in names_tables_to_keep: + tables[table_name] = table + continue + # each mode here requires paths or elements, using assert here to avoid mypy errors. 
+ if by == "cs": + from spatialdata._core.query.relational_query import _filter_table_by_element_names + + assert element_names is not None + table = _filter_table_by_element_names(table, element_names) + if len(table) != 0: + tables[table_name] = table + elif by == "elements": + from spatialdata._core.query.relational_query import _filter_table_by_elements + + assert elements_dict is not None + table = _filter_table_by_elements(table, elements_dict=elements_dict) + if len(table) != 0: + tables[table_name] = table else: - table = self.table + tables = self.tables - return SpatialData(**elements, table=table) + return tables def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: """ @@ -453,7 +781,7 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: new_names.append(new_cs) # rename the coordinate systems - for element in self._gen_elements_values(): + for element in self._gen_spatial_element_values(): # get the transformations transformations = get_transformation(element, get_all=True) assert isinstance(transformations, dict) @@ -574,16 +902,17 @@ def transform_to_coordinate_system( ------- The transformed SpatialData. """ - sdata = self.filter_by_coordinate_system(target_coordinate_system, filter_table=False) + sdata = self.filter_by_coordinate_system(target_coordinate_system, filter_tables=False) elements: dict[str, dict[str, SpatialElement]] = {} for element_type, element_name, element in sdata._gen_elements(): - transformed = sdata.transform_element_to_coordinate_system( - element, target_coordinate_system, maintain_positioning=maintain_positioning - ) - if element_type not in elements: - elements[element_type] = {} - elements[element_type][element_name] = transformed - return SpatialData(**elements, table=sdata.table) + if element_type != "tables": + transformed = sdata.transform_element_to_coordinate_system( + element, target_coordinate_system, maintain_positioning=maintain_positioning + ) + if element_type not in elements: + elements[element_type] = {} + elements[element_type][element_name] = transformed + return SpatialData(**elements, tables=sdata.tables) def write( self, @@ -726,9 +1055,10 @@ def write( name=name, ) - if self.table is not None: - elem_group = root.create_group(name="table") - write_table(table=self.table, group=elem_group, name="table") + if len(self.tables): + elem_group = root.create_group(name="tables") + for key in self.tables: + write_table(table=self.tables[key], group=elem_group, name=key) except Exception as e: # noqa: B902 self._path = None @@ -775,52 +1105,70 @@ def write( assert isinstance(self.path, Path) @property - def table(self) -> AnnData: + def tables(self) -> Tables: """ - Return the table. + Return tables dictionary. Returns ------- - The table. + dict[str, AnnData] + Either the empty dictionary or a dictionary with as values the strings representing the table names and + as values the AnnData tables themselves. """ - return self._table + return self._tables - @table.setter - def table(self, table: AnnData) -> None: - """ - Set the table of a SpatialData object in a object that doesn't contain a table. + @tables.setter + def tables(self, shapes: dict[str, GeoDataFrame]) -> None: + """Set tables.""" + self._shared_keys = self._shared_keys - set(self._tables.keys()) + self._tables = Tables(shared_keys=self._shared_keys) + for k, v in shapes.items(): + self._tables[k] = v - Parameters - ---------- - table - The table to set. 
+ @property + def table(self) -> None | AnnData: + """ + Return table with name table from tables if it exists. - Notes - ----- - If a table is already present, it needs to be removed first. - The table needs to pass validation (see :class:`~spatialdata.TableModel`). - If the SpatialData object is backed by a Zarr storage, the table will be written to the Zarr storage. + Returns + ------- + The table. """ - from spatialdata._io.io_table import write_table + warnings.warn( + "Table accessor will be deprecated with SpatialData version 0.1, use sdata.tables instead.", + DeprecationWarning, + stacklevel=2, + ) + # Isinstance will still return table if anndata has 0 rows. + if isinstance(self.tables.get("table"), AnnData): + return self.tables["table"] + return None + @table.setter + def table(self, table: AnnData) -> None: + warnings.warn( + "Table setter will be deprecated with SpatialData version 0.1, use tables instead.", + DeprecationWarning, + stacklevel=2, + ) TableModel().validate(table) - if self.table is not None: - raise ValueError("The table already exists. Use del sdata.table to remove it first.") - self._table = table - if self.is_backed(): - store = parse_url(self.path, mode="r+").store - root = zarr.group(store=store) - elem_group = root.require_group(name="table") - write_table(table=self.table, group=elem_group, name="table") + if self.tables.get("table") is not None: + raise ValueError("The table already exists. Use del sdata.tables['table'] to remove it first.") + self.tables["table"] = table @table.deleter def table(self) -> None: """Delete the table.""" - self._table = None - if self.is_backed(): - store = parse_url(self.path, mode="r+").store - root = zarr.group(store=store) - del root["table/table"] + warnings.warn( + "del sdata.table will be deprecated with SpatialData version 0.1, use del sdata.tables['table'] instead.", + DeprecationWarning, + stacklevel=2, + ) + if self.tables.get("table"): + del self.tables["table"] + else: + # More informative than the error in the zarr library. + raise KeyError("table with name 'table' not present in the SpatialData object.") @staticmethod def read(file_path: Path | str, selection: tuple[str] | None = None) -> SpatialData: @@ -933,7 +1281,7 @@ def coordinate_systems(self) -> list[str]: from spatialdata.transformations.operations import get_transformation all_cs = set() - gen = self._gen_elements_values() + gen = self._gen_spatial_element_values() for obj in gen: transformations = get_transformation(obj, get_all=True) assert isinstance(transformations, dict) @@ -949,7 +1297,7 @@ def _non_empty_elements(self) -> list[str]: non_empty_elements The names of the elements that are not empty. 
""" - all_elements = ["images", "labels", "points", "shapes", "table"] + all_elements = ["images", "labels", "points", "shapes", "tables"] return [ element for element in all_elements @@ -987,71 +1335,64 @@ def h(s: str) -> str: attribute = getattr(self, attr) descr += f"\n{h('level0')}{attr.capitalize()}" - if isinstance(attribute, AnnData): + + unsorted_elements = attribute.items() + sorted_elements = sorted(unsorted_elements, key=lambda x: _natural_keys(x[0])) + for k, v in sorted_elements: descr += f"{h('empty_line')}" - descr_class = attribute.__class__.__name__ - descr += f"{h('level1.0')}{attribute!r}: {descr_class} {attribute.shape}" - descr = rreplace(descr, h("level1.0"), " └── ", 1) - else: - unsorted_elements = attribute.items() - sorted_elements = sorted(unsorted_elements, key=lambda x: _natural_keys(x[0])) - for k, v in sorted_elements: - descr += f"{h('empty_line')}" - descr_class = v.__class__.__name__ - if attr == "shapes": - descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} " f"shape: {v.shape} (2D shapes)" - elif attr == "points": - length: int | None = None - if len(v.dask.layers) == 1: - name, layer = v.dask.layers.items().__iter__().__next__() - if "read-parquet" in name: - t = layer.creation_info["args"] - assert isinstance(t, tuple) - assert len(t) == 1 - parquet_file = t[0] - table = read_parquet(parquet_file) - length = len(table) - else: - # length = len(v) - length = None + descr_class = v.__class__.__name__ + if attr == "shapes": + descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} " f"shape: {v.shape} (2D shapes)" + elif attr == "points": + length: int | None = None + if len(v.dask.layers) == 1: + name, layer = v.dask.layers.items().__iter__().__next__() + if "read-parquet" in name: + t = layer.creation_info["args"] + assert isinstance(t, tuple) + assert len(t) == 1 + parquet_file = t[0] + table = read_parquet(parquet_file) + length = len(table) else: + # length = len(v) length = None + else: + length = None - n = len(get_axes_names(v)) - dim_string = f"({n}D points)" + n = len(get_axes_names(v)) + dim_string = f"({n}D points)" - assert len(v.shape) == 2 - if length is not None: - shape_str = f"({length}, {v.shape[1]})" - else: - shape_str = ( - "(" - + ", ".join( - [str(dim) if not isinstance(dim, Delayed) else "" for dim in v.shape] - ) - + ")" - ) - descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} " f"with shape: {shape_str} {dim_string}" + assert len(v.shape) == 2 + if length is not None: + shape_str = f"({length}, {v.shape[1]})" else: - if isinstance(v, SpatialImage): - descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class}[{''.join(v.dims)}] {v.shape}" - elif isinstance(v, MultiscaleSpatialImage): - shapes = [] - dims: str | None = None - for pyramid_level in v: - dataset_names = list(v[pyramid_level].keys()) - assert len(dataset_names) == 1 - dataset_name = dataset_names[0] - vv = v[pyramid_level][dataset_name] - shape = vv.shape - if dims is None: - dims = "".join(vv.dims) - shapes.append(shape) - descr += ( - f"{h(attr + 'level1.1')}{k!r}: {descr_class}[{dims}] " f"{', '.join(map(str, shapes))}" - ) - else: - raise TypeError(f"Unknown type {type(v)}") + shape_str = ( + "(" + + ", ".join([str(dim) if not isinstance(dim, Delayed) else "" for dim in v.shape]) + + ")" + ) + descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} " f"with shape: {shape_str} {dim_string}" + elif attr == "tables": + descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} {v.shape}" + else: + if isinstance(v, SpatialImage): + descr += f"{h(attr + 'level1.1')}{k!r}: 
{descr_class}[{''.join(v.dims)}] {v.shape}" + elif isinstance(v, MultiscaleSpatialImage): + shapes = [] + dims: str | None = None + for pyramid_level in v: + dataset_names = list(v[pyramid_level].keys()) + assert len(dataset_names) == 1 + dataset_name = dataset_names[0] + vv = v[pyramid_level][dataset_name] + shape = vv.shape + if dims is None: + dims = "".join(vv.dims) + shapes.append(shape) + descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class}[{dims}] " f"{', '.join(map(str, shapes))}" + else: + raise TypeError(f"Unknown type {type(v)}") if last_attr is True: descr = descr.replace(h("empty_line"), "\n ") else: @@ -1060,7 +1401,7 @@ def h(s: str) -> str: descr = rreplace(descr, h("level0"), "└── ", 1) descr = descr.replace(h("level0"), "├── ") - for attr in ["images", "labels", "points", "table", "shapes"]: + for attr in ["images", "labels", "points", "tables", "shapes"]: descr = rreplace(descr, h(attr + "level1.1"), " └── ", 1) descr = descr.replace(h(attr + "level1.1"), " ├── ") @@ -1074,13 +1415,14 @@ def h(s: str) -> str: gen = self._gen_elements() elements_in_cs: dict[str, list[str]] = {} for k, name, obj in gen: - transformations = get_transformation(obj, get_all=True) - assert isinstance(transformations, dict) - target_css = transformations.keys() - if cs in target_css: - if k not in elements_in_cs: - elements_in_cs[k] = [] - elements_in_cs[k].append(name) + if not isinstance(obj, AnnData): + transformations = get_transformation(obj, get_all=True) + assert isinstance(transformations, dict) + target_css = transformations.keys() + if cs in target_css: + if k not in elements_in_cs: + elements_in_cs[k] = [] + elements_in_cs[k].append(name) for element_names in elements_in_cs.values(): element_names.sort(key=_natural_keys) if len(elements_in_cs) > 0: @@ -1096,26 +1438,98 @@ def h(s: str) -> str: descr += "\n" return descr - def _gen_elements_values(self) -> Generator[SpatialElement, None, None]: + def _gen_spatial_element_values(self) -> Generator[SpatialElement, None, None]: + """ + Generate spatial element objects contained in the SpatialData instance. + + Returns + ------- + Generator[SpatialElement, None, None] + A generator that yields spatial element objects contained in the SpatialData instance. + + """ for element_type in ["images", "labels", "points", "shapes"]: d = getattr(SpatialData, element_type).fget(self) yield from d.values() - def _gen_elements(self) -> Generator[tuple[str, str, SpatialElement], None, None]: - for element_type in ["images", "labels", "points", "shapes"]: + def _gen_elements( + self, include_table: bool = False + ) -> Generator[tuple[str, str, SpatialElement | AnnData], None, None]: + """ + Generate elements contained in the SpatialData instance. + + Parameters + ---------- + include_table + Whether to also generate table elements. + + Returns + ------- + A generator object that returns a tuple containing the type of the element, its name, and the element + itself. + """ + element_types = ["images", "labels", "points", "shapes"] + if include_table: + element_types.append("tables") + for element_type in element_types: d = getattr(SpatialData, element_type).fget(self) for k, v in d.items(): yield element_type, k, v - def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement]: - for element_type, element_name_, element in self._gen_elements(): + def gen_spatial_elements(self) -> Generator[tuple[str, str, SpatialElement], None, None]: + """ + Generate spatial elements within the SpatialData object. 
+
+        This method generates spatial elements (images, labels, points and shapes).
+
+        Returns
+        -------
+        A generator that yields tuples containing the element type (string), name, and SpatialElement objects
+        themselves.
+        """
+        return self._gen_elements()
+
+    def gen_elements(self) -> Generator[tuple[str, str, SpatialElement | AnnData], None, None]:
+        """
+        Generate elements within the SpatialData object.
+
+        This method generates elements in the SpatialData object (images, labels, points, shapes and tables).
+
+        Returns
+        -------
+        A generator that yields tuples containing the element type, name, and element objects themselves.
+        """
+        return self._gen_elements(include_table=True)
+
+    def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement | AnnData]:
+        """
+        Retrieve the element from the SpatialData instance matching element_name.
+
+        Parameters
+        ----------
+        element_name
+            The name of the element to find.
+
+        Returns
+        -------
+        A tuple containing the element type, element name, and the retrieved element itself.
+
+        Raises
+        ------
+        KeyError
+            If the element with the given name cannot be found.
+        """
+        for element_type, element_name_, element in self.gen_elements():
             if element_name_ == element_name:
                 return element_type, element_name_, element
         else:
             raise KeyError(f"Could not find element with name {element_name!r}")

     @classmethod
-    def init_from_elements(cls, elements: dict[str, SpatialElement], table: AnnData | None = None) -> SpatialData:
+    @deprecation_alias(table="tables")
+    def init_from_elements(
+        cls, elements: dict[str, SpatialElement], tables: AnnData | dict[str, AnnData] | None = None
+    ) -> SpatialData:
         """
         Create a SpatialData object from a dict of named elements and an optional table.

@@ -1123,8 +1537,8 @@ def init_from_elements(cls, elements: dict[str, SpatialElement], table: AnnData
         ----------
         elements
             A dict of named elements.
-        table
-            An optional table.
+        tables
+            An optional table or dictionary of tables.

         Returns
         -------
@@ -1143,34 +1557,46 @@ def init_from_elements(cls, elements: dict[str, SpatialElement], table: AnnData
             assert model == ShapesModel
             element_type = "shapes"
             elements_dict.setdefault(element_type, {})[name] = element
-        return cls(**elements_dict, table=table)
+        return cls(**elements_dict, tables=tables)

-    def subset(self, element_names: list[str], filter_table: bool = True) -> SpatialData:
+    def subset(
+        self, element_names: list[str], filter_tables: bool = True, include_orphan_tables: bool = False
+    ) -> SpatialData:
         """
         Subset the SpatialData object.

         Parameters
         ----------
         element_names
-            The names of the element_names to subset.
+            The names of the elements to subset. If an element name refers to a table, that table is included in the
+            subset in full, even if filter_tables is True.
         filter_table
             If True (default), the table is filtered to only contain rows that are annotating regions
             contained within the element_names.
+        include_orphan_tables
+            If True (not default), include tables that do not annotate SpatialElement(s). Only has an effect if
+            filter_tables is also set to True.

         Returns
         -------
         The subsetted SpatialData object.
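A sketch of the extended `subset` behavior documented above, assuming element names from the bundled blobs demo data:

```python
# "blobs_labels" is kept as an element; the table named "table" is kept in full,
# while any other table is filtered to rows annotating the kept elements
sub = sdata.subset(["blobs_labels", "table"], filter_tables=True)
```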
""" - from spatialdata._core.query.relational_query import _filter_table_by_elements - elements_dict: dict[str, SpatialElement] = {} - for element_type, element_name, element in self._gen_elements(): + names_tables_to_keep: set[str] = set() + for element_type, element_name, element in self._gen_elements(include_table=True): if element_name in element_names: - elements_dict.setdefault(element_type, {})[element_name] = element - table = _filter_table_by_elements(self.table, elements_dict=elements_dict) if filter_table else self.table - if len(table) == 0: - table = None - return SpatialData(**elements_dict, table=table) + if element_type != "tables": + elements_dict.setdefault(element_type, {})[element_name] = element + else: + names_tables_to_keep.add(element_name) + tables = self._filter_tables( + names_tables_to_keep, + filter_tables, + "elements", + include_orphan_tables, + elements_dict=elements_dict, + ) + return SpatialData(**elements_dict, tables=tables) def __getitem__(self, item: str) -> SpatialElement: """ @@ -1188,6 +1614,33 @@ def __getitem__(self, item: str) -> SpatialElement: _, _, element = self._find_element(item) return element + def __contains__(self, key: str) -> bool: + element_dict = { + element_name: element_value for _, element_name, element_value in self._gen_elements(include_table=True) + } + return key in element_dict + + def get(self, key: str, default_value: SpatialElement | AnnData | None = None) -> SpatialElement | AnnData: + """ + Get element from SpatialData object based on corresponding name. + + Parameters + ---------- + key + The key to lookup in the spatial elements. + default_value + The default value (a SpatialElement or a table) to return if the key is not found. Default is None. + + Returns + ------- + The SpatialData element associated with the given key, if found. Otherwise, the default value is returned. + """ + for _, element_name_, element in self.gen_elements(): + if element_name_ == key: + return element + else: + return default_value + def __setitem__(self, key: str, value: SpatialElement | AnnData) -> None: """ Add the element to the SpatialData object. @@ -1209,7 +1662,7 @@ def __setitem__(self, key: str, value: SpatialElement | AnnData) -> None: elif schema == ShapesModel: self.shapes[key] = value elif schema == TableModel: - raise TypeError("Use the table property to set the table (e.g. 
sdata.table = value).") + self.tables[key] = value else: raise TypeError(f"Unknown element type with schema: {schema!r}.") diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 37be41fa..f5caa59d 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -10,7 +10,11 @@ from functools import singledispatch from typing import Any +import numpy as np import zarr +from anndata import AnnData +from anndata import read_zarr as read_anndata_zarr +from anndata.experimental import read_elem from dask.array.core import Array as DaskArray from dask.dataframe.core import DataFrame as DaskDataFrame from multiscale_spatial_image import MultiscaleSpatialImage @@ -19,7 +23,9 @@ from spatial_image import SpatialImage from spatialdata._core.spatialdata import SpatialData +from spatialdata._logging import logger from spatialdata._utils import iterate_pyramid_levels +from spatialdata.models import TableModel from spatialdata.models._utils import ( MappingToCoordinateSystem_t, ValidAxis_t, @@ -173,8 +179,8 @@ def _are_directories_identical( if _root_dir2 is None: _root_dir2 = dir2 if exclude_regexp is not None and ( - re.match(rf"{_root_dir1}/" + exclude_regexp, str(dir1)) - or re.match(rf"{_root_dir2}/" + exclude_regexp, str(dir2)) + re.match(rf"{re.escape(str(_root_dir1))}/" + exclude_regexp, str(dir1)) + or re.match(rf"{re.escape(str(_root_dir2))}/" + exclude_regexp, str(dir2)) ): return True @@ -227,7 +233,7 @@ def get_dask_backing_files(element: SpatialData | SpatialImage | MultiscaleSpati @get_dask_backing_files.register(SpatialData) def _(element: SpatialData) -> list[str]: files: set[str] = set() - for e in element._gen_elements_values(): + for e in element._gen_spatial_element_values(): if isinstance(e, (SpatialImage, MultiscaleSpatialImage, DaskDataFrame)): files = files.union(get_dask_backing_files(e)) return list(files) @@ -304,6 +310,57 @@ def save_transformations(sdata: SpatialData) -> None: """ from spatialdata.transformations import get_transformation, set_transformation - for element in sdata._gen_elements_values(): + for element in sdata._gen_spatial_element_values(): transformations = get_transformation(element, get_all=True) set_transformation(element, transformations, set_all=True, write_to_sdata=sdata) + + +def read_table_and_validate( + zarr_store_path: str, group: zarr.Group, subgroup: zarr.Group, tables: dict[str, AnnData] +) -> dict[str, AnnData]: + """ + Read in tables in the tables Zarr.group of a SpatialData Zarr store. + + Parameters + ---------- + zarr_store_path + The path to the Zarr store. + group + The parent group containing the subgroup. + subgroup + The subgroup containing the tables. + tables + A dictionary of tables. + + Returns + ------- + The modified dictionary with the tables. 
+ """ + count = 0 + for table_name in subgroup: + f_elem = subgroup[table_name] + f_elem_store = os.path.join(zarr_store_path, f_elem.path) + if isinstance(group.store, zarr.storage.ConsolidatedMetadataStore): + tables[table_name] = read_elem(f_elem) + # we can replace read_elem with read_anndata_zarr after this PR gets into a release (>= 0.6.5) + # https://github.com/scverse/anndata/pull/1057#pullrequestreview-1530623183 + # table = read_anndata_zarr(f_elem) + else: + tables[table_name] = read_anndata_zarr(f_elem_store) + if TableModel.ATTRS_KEY in tables[table_name].uns: + # fill out eventual missing attributes that has been omitted because their value was None + attrs = tables[table_name].uns[TableModel.ATTRS_KEY] + if "region" not in attrs: + attrs["region"] = None + if "region_key" not in attrs: + attrs["region_key"] = None + if "instance_key" not in attrs: + attrs["instance_key"] = None + # fix type for region + if "region" in attrs and isinstance(attrs["region"], np.ndarray): + attrs["region"] = attrs["region"].tolist() + + count += 1 + + logger.debug(f"Found {count} elements in {subgroup}") + return tables diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index f82b74f3..75ebab8d 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -3,7 +3,7 @@ from anndata import AnnData from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage from ome_zarr.format import CurrentFormat -from pandas.api.types import is_categorical_dtype +from pandas.api.types import CategoricalDtype from shapely import GeometryType from spatial_image import SpatialImage @@ -166,7 +166,7 @@ def validate_table( ) -> None: if not isinstance(table, AnnData): raise TypeError(f"`table` must be `anndata.AnnData`, was {type(table)}.") - if region_key is not None and not is_categorical_dtype(table.obs[region_key]): + if region_key is not None and not isinstance(table.obs[region_key].dtype, CategoricalDtype): raise ValueError( f"`table.obs[region_key]` must be of type `categorical`, not `{type(table.obs[region_key])}`." 
) diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 72ae5f4c..ead604af 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -4,6 +4,7 @@ from ome_zarr.format import Format from spatialdata._io.format import CurrentTablesFormat +from spatialdata.models import TableModel def write_table( @@ -13,10 +14,13 @@ def write_table( group_type: str = "ngff:regions_table", fmt: Format = CurrentTablesFormat(), ) -> None: - region = table.uns["spatialdata_attrs"]["region"] - region_key = table.uns["spatialdata_attrs"].get("region_key", None) - instance_key = table.uns["spatialdata_attrs"].get("instance_key", None) - fmt.validate_table(table, region_key, instance_key) + if TableModel.ATTRS_KEY in table.uns: + region = table.uns["spatialdata_attrs"]["region"] + region_key = table.uns["spatialdata_attrs"].get("region_key", None) + instance_key = table.uns["spatialdata_attrs"].get("instance_key", None) + fmt.validate_table(table, region_key, instance_key) + else: + region, region_key, instance_key = (None, None, None) write_adata(group, name, table) # creates group[name] tables_group = group[name] tables_group.attrs["spatialdata-encoding-type"] = group_type diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 94a86c04..f5e378a2 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -1,21 +1,18 @@ import logging import os +import warnings from pathlib import Path from typing import Optional, Union -import numpy as np import zarr from anndata import AnnData -from anndata import read_zarr as read_anndata_zarr -from anndata.experimental import read_elem from spatialdata._core.spatialdata import SpatialData -from spatialdata._io._utils import ome_zarr_logger +from spatialdata._io._utils import ome_zarr_logger, read_table_and_validate from spatialdata._io.io_points import _read_points from spatialdata._io.io_raster import _read_multiscale from spatialdata._io.io_shapes import _read_shapes from spatialdata._logging import logger -from spatialdata.models import TableModel def _open_zarr_store(store: Union[str, Path, zarr.Group]) -> tuple[zarr.Group, str]: @@ -61,10 +58,11 @@ def read_zarr(store: Union[str, Path, zarr.Group], selection: Optional[tuple[str images = {} labels = {} points = {} - table: Optional[AnnData] = None + tables: dict[str, AnnData] = {} shapes = {} - selector = {"images", "labels", "points", "shapes", "table"} if not selection else set(selection or []) + # TODO: remove table once deprecated. 
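`write_table` above now tolerates tables without `spatialdata_attrs` (region, region_key and instance_key fall back to None), so orphan tables can be persisted. A minimal sketch, with illustrative names:

```python
import numpy as np
from anndata import AnnData
from spatialdata import SpatialData

orphan = AnnData(X=np.ones((3, 2)))  # no annotation metadata attached
sdata = SpatialData(tables={"orphan": orphan})
sdata.write("orphan.zarr")           # previously this required `spatialdata_attrs` to be present
```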
+    selector = {"images", "labels", "points", "shapes", "tables", "table"} if not selection else set(selection or [])
     logger.debug(f"Reading selection {selector}")

     # read multiscale images
@@ -123,36 +121,21 @@ def read_zarr(store: Union[str, Path, zarr.Group], selection: Optional[tuple[str
                 shapes[subgroup_name] = _read_shapes(f_elem_store)
                 count += 1
             logger.debug(f"Found {count} elements in {group}")
+    if "tables" in selector and "tables" in f:
+        group = f["tables"]
+        tables = read_table_and_validate(f_store_path, f, group, tables)

     if "table" in selector and "table" in f:
-        group = f["table"]
-        count = 0
-        for subgroup_name in group:
-            if Path(subgroup_name).name.startswith("."):
-                # skip hidden files like .zgroup or .zmetadata
-                continue
-            f_elem = group[subgroup_name]
-            f_elem_store = os.path.join(f_store_path, f_elem.path)
-            if isinstance(f.store, zarr.storage.ConsolidatedMetadataStore):
-                table = read_elem(f_elem)
-                # we can replace read_elem with read_anndata_zarr after this PR gets into a release (>= 0.6.5)
-                # https://github.com/scverse/anndata/pull/1057#pullrequestreview-1530623183
-                # table = read_anndata_zarr(f_elem)
-            else:
-                table = read_anndata_zarr(f_elem_store)
-            if TableModel.ATTRS_KEY in table.uns:
-                # fill out eventual missing attributes that has been omitted because their value was None
-                attrs = table.uns[TableModel.ATTRS_KEY]
-                if "region" not in attrs:
-                    attrs["region"] = None
-                if "region_key" not in attrs:
-                    attrs["region_key"] = None
-                if "instance_key" not in attrs:
-                    attrs["instance_key"] = None
-                # fix type for region
-                if "region" in attrs and isinstance(attrs["region"], np.ndarray):
-                    attrs["region"] = attrs["region"].tolist()
-            count += 1
+        warnings.warn(
+            f"Table group found in zarr store at location {f_store_path}. Please update the zarr store "
+            f"to use tables instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        subgroup_name = "table"
+        group = f[subgroup_name]
+        tables = read_table_and_validate(f_store_path, f, group, tables)

         logger.debug(f"Found {count} elements in {group}")

     sdata = SpatialData(
@@ -160,7 +143,7 @@ def read_zarr(store: Union[str, Path, zarr.Group], selection: Optional[tuple[str
         labels=labels,
         points=points,
         shapes=shapes,
-        table=table,
+        tables=tables,
     )
     sdata._path = Path(store)
     return sdata
diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py
index dc360509..0e59d63d 100644
--- a/src/spatialdata/_utils.py
+++ b/src/spatialdata/_utils.py
@@ -1,16 +1,16 @@
 from __future__ import annotations

+import functools
 import re
+import warnings
 from collections.abc import Generator
-from copy import deepcopy
-from typing import Union
+from typing import Any, Callable, TypeVar, Union

 import numpy as np
 import pandas as pd
 from anndata import AnnData
 from dask import array as da
 from datatree import DataTree
-from geopandas import GeoDataFrame
 from multiscale_spatial_image import MultiscaleSpatialImage
 from spatial_image import SpatialImage
 from xarray import DataArray
@@ -25,6 +25,7 @@

 # I was using "from numbers import Number" but this led to mypy errors, so I switched to the following:
 Number = Union[int, float]
+RT = TypeVar("RT")


 def _parse_list_into_array(array: list[Number] | ArrayLike) -> ArrayLike:
@@ -159,7 +160,10 @@ def multiscale_spatial_image_from_data_tree(data_tree: DataTree) -> MultiscaleSp
         assert len(v) == 1
         xdata = v.__iter__().__next__()
         d[k] = xdata
+    # this stopped working; we should add support for multiscale_spatial_image 1.0.0 so that the problem is solved
     return MultiscaleSpatialImage.from_dict(d)
+    # data_tree.__class__ = MultiscaleSpatialImage
+    # return cast(MultiscaleSpatialImage, data_tree)


 # TODO: this functions is similar to _iter_multiscale(), the latter is more powerful but not exposed to the user.
@@ -210,24 +214,67 @@ def _inplace_fix_subset_categorical_obs(subset_adata: AnnData, original_adata: A
     subset_adata.obs = obs


-def _deepcopy_geodataframe(gdf: GeoDataFrame) -> GeoDataFrame:
+# TODO: change to paramspec as soon as we drop support for python 3.9, see https://stackoverflow.com/a/68290080
+def deprecation_alias(**aliases: str) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
     """
-    temporary fix for https://github.com/scverse/spatialdata/issues/286.
+    Decorate a function to warn the user when arguments set for deprecation are used.

     Parameters
     ----------
-    gdf
-        The GeoDataFrame to deepcopy
+    aliases
+        Deprecated argument aliases to be mapped to the new arguments.

     Returns
     -------
-    A deepcopy of the GeoDataFrame
+    A decorator that can be used to mark an argument for deprecation and substitute it with the new argument.
+
+    Raises
+    ------
+    TypeError
+        If the provided aliases are not of string type.
+
+    Example
+    -------
+    Assuming we have an argument 'table' set for deprecation and we want to warn the user and substitute it with
+    'tables':
+
+    ```python
+    @deprecation_alias(table="tables")
+    def my_function(tables: AnnData | dict[str, AnnData]):
+        pass
+    ```
     """
-    #
-    new_gdf = deepcopy(gdf)
-    new_attrs = deepcopy(gdf.attrs)
-    new_gdf.attrs = new_attrs
-    return new_gdf
+
+    def deprecation_decorator(f: Callable[..., RT]) -> Callable[..., RT]:
+        @functools.wraps(f)
+        def wrapper(*args: Any, **kwargs: Any) -> RT:
+            class_name = f.__qualname__
+            rename_kwargs(f.__name__, kwargs, aliases, class_name)
+            return f(*args, **kwargs)
+
+        return wrapper
+
+    return deprecation_decorator
+
+
+def rename_kwargs(func_name: str, kwargs: dict[str, Any], aliases: dict[str, str], class_name: None | str) -> None:
+    """Rename function arguments set for deprecation and warn in case these arguments are used."""
+    for alias, new in aliases.items():
+        if alias in kwargs:
+            class_name = class_name + "." if class_name else ""
+            if new in kwargs:
+                raise TypeError(
+                    f"{class_name}{func_name} received both {alias} and {new} as arguments!"
+                    f" {alias} is being deprecated in SpatialData version 0.1, only use {new} instead."
+                )
+            warnings.warn(
+                message=(
+                    f"`{alias}` is being deprecated as an argument to `{class_name}{func_name}` in SpatialData "
+                    f"version 0.1, switch to `{new}` instead."
+ ), + category=DeprecationWarning, + stacklevel=3, + ) + kwargs[new] = kwargs.pop(alias) def _error_message_add_element() -> None: diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 388db612..66fc5b4c 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -130,9 +130,9 @@ def _validate( regions_to_coordinate_systems: dict[str, str], ) -> None: """Validate input parameters.""" - self._region_key = sdata.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] - self._instance_key = sdata.table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] - available_regions = sdata.table.obs[self._region_key].cat.categories + self._region_key = sdata.tables["table"].uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] + self._instance_key = sdata.tables["table"].uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] + available_regions = sdata.tables["table"].obs[self._region_key].cat.categories cs_region_image = [] # list of tuples (coordinate_system, region, image) # check unique matching between regions and images and coordinate systems @@ -173,8 +173,8 @@ def _validate( self.regions = list(regions_to_coordinate_systems.keys()) # all regions for the dataloader self.sdata = sdata - self.dataset_table = self.sdata.table[ - self.sdata.table.obs[self._region_key].isin(self.regions) + self.dataset_table = self.sdata.tables["table"][ + self.sdata.tables["table"].obs[self._region_key].isin(self.regions) ] # filtered table for the data loader self._cs_region_image = tuple(cs_region_image) # tuple of tuples (coordinate_system, region_key, image_key) @@ -188,7 +188,7 @@ def _preprocess( tile_coords_df = [] dims_l = [] shapes_l = [] - + table = self.sdata.tables["table"] for cs, region, image in self._cs_region_image: # get dims and transformations for the region element dims = get_axes_names(self.sdata[region]) @@ -197,7 +197,7 @@ def _preprocess( assert isinstance(t, BaseTransformation) # get instances from region - inst = self.sdata.table.obs[self.sdata.table.obs[self._region_key] == region][self._instance_key].values + inst = table.obs[table.obs[self._region_key] == region][self._instance_key].values # subset the regions by instances subset_region = self.sdata[region].iloc[inst] @@ -219,7 +219,7 @@ def _preprocess( self.dataset_index = pd.concat(index_df).reset_index(drop=True) self.tiles_coords = pd.concat(tile_coords_df).reset_index(drop=True) # get table filtered by regions - self.filtered_table = self.sdata.table.obs[self.sdata.table.obs[self._region_key].isin(self.regions)] + self.filtered_table = table.obs[table.obs[self._region_key].isin(self.regions)] assert len(self.tiles_coords) == len(self.dataset_index) dims_ = set(chain(*dims_l)) diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index 43f6cef7..546a0927 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -86,7 +86,7 @@ def raccoon( self, ) -> SpatialData: """Raccoon dataset.""" - im_data = scipy.misc.face() + im_data = scipy.datasets.face() im = Image2DModel.parse(im_data, dims=["y", "x", "c"]) labels_data = slic(im_data, n_segments=100, compactness=10, sigma=1) labels = Labels2DModel.parse(labels_data, dims=["y", "x"]) @@ -154,7 +154,7 @@ def blobs( circles = self._circles_blobs(self.transformations, self.length, self.n_shapes) polygons = self._polygons_blobs(self.transformations, self.length, self.n_shapes) multipolygons = self._polygons_blobs(self.transformations, self.length, self.n_shapes, 
multipolygons=True) - adata = aggregate(values=image, by=labels).table + adata = aggregate(values=image, by=labels).tables["table"] adata.obs["region"] = pd.Categorical(["blobs_labels"] * len(adata)) adata.obs["instance_id"] = adata.obs_names.astype(int) del adata.uns[TableModel.ATTRS_KEY] @@ -165,7 +165,7 @@ def blobs( labels={"blobs_labels": labels, "blobs_multiscale_labels": multiscale_labels}, points={"blobs_points": points}, shapes={"blobs_circles": circles, "blobs_polygons": polygons, "blobs_multipolygons": multipolygons}, - table=table, + tables=table, ) def _image_blobs( @@ -241,7 +241,7 @@ def _points_blobs( arr = rng.integers(padding, length - padding, size=(n_points, 2)).astype(np.int64) # randomly assign some values from v to the points points_assignment0 = rng.integers(0, 10, size=arr.shape[0]).astype(np.int64) - genes = rng.choice(["a", "b"], size=arr.shape[0]) + genes = rng.choice(["gene_a", "gene_b"], size=arr.shape[0]) annotation = pd.DataFrame( { "genes": genes, diff --git a/src/spatialdata/models/__init__.py b/src/spatialdata/models/__init__.py index 9a6cf64b..df370e4a 100644 --- a/src/spatialdata/models/__init__.py +++ b/src/spatialdata/models/__init__.py @@ -21,7 +21,9 @@ PointsModel, ShapesModel, TableModel, + check_target_region_column_symmetry, get_model, + get_table_keys, ) __all__ = [ @@ -44,4 +46,6 @@ "get_axes_names", "points_geopandas_to_dask_dataframe", "points_dask_dataframe_to_geopandas", + "check_target_region_column_symmetry", + "get_table_keys", ] diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 0daeb40a..c7234175 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -471,7 +471,8 @@ def validate(cls, data: DaskDataFrame) -> None: """ for ax in [X, Y, Z]: if ax in data.columns: - assert data[ax].dtype in [np.float32, np.float64, np.int64] + # TODO: check why this can return int32 on windows. + assert data[ax].dtype in [np.int32, np.float32, np.float64, np.int64] if cls.TRANSFORM_KEY not in data.attrs: raise ValueError(f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`.") if cls.ATTRS_KEY in data.attrs and "feature_key" in data.attrs[cls.ATTRS_KEY]: @@ -517,6 +518,11 @@ def parse(cls, data: Any, **kwargs: Any) -> DaskDataFrame: Returns ------- :class:`dask.dataframe.core.DataFrame` + + Notes + ----- + The order of the columns of the dataframe returned by the parser is not guaranteed to be the same as the order + of the columns in the dataframe passed as an argument. 
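To illustrate the column-order caveat added to the `PointsModel.parse` notes above, here is a small sketch; it assumes the parser picks up `x`/`y` as coordinate columns when `coordinates` is not given, as the surrounding code suggests:

```python
import pandas as pd
from spatialdata.models import PointsModel

df = pd.DataFrame({"x": [0.0, 1.0], "y": [1.0, 2.0], "genes": ["gene_a", "gene_b"]})
points = PointsModel.parse(df)
# the order of points.columns may differ from df.columns; rely on names, not positions
```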
""" raise NotImplementedError() @@ -607,6 +613,15 @@ def _( logger.info(f"Column `{Z}` in `data` will be ignored since the data is 2D.") for c in set(data.columns) - {feature_key, instance_key, *coordinates.values(), X, Y, Z}: table[c] = data[c] + + # when `coordinates` is None, and no columns have been added or removed, preserves the original order + # here I tried to fix https://github.com/scverse/spatialdata/issues/486, didn't work + # old_columns = list(data.columns) + # new_columns = list(table.columns) + # if new_columns == set(old_columns) and new_columns != old_columns: + # col_order = [col for col in old_columns if col in new_columns] + # table = table[col_order] + return cls._add_metadata_and_validate( table, feature_key=feature_key, instance_key=instance_key, transformations=transformations ) @@ -651,24 +666,121 @@ class TableModel: REGION_KEY_KEY = "region_key" INSTANCE_KEY = "instance_key" - def validate( - self, - data: AnnData, - ) -> AnnData: + def _validate_set_region_key(self, data: AnnData, region_key: str | None = None) -> None: """ - Validate the data. + Validate the region key in table.uns or set a new region key as the region key column. Parameters ---------- data - The data to validate. + The AnnData table. + region_key + The region key to be validated and set in table.uns. + + + Raises + ------ + ValueError + If no region_key is found in table.uns and no region_key is provided as an argument. + ValueError + If the specified region_key in table.uns is not present as a column in table.obs. + ValueError + If the specified region key column is not present in table.obs. + """ + attrs = data.uns.get(self.ATTRS_KEY) + if attrs is None: + data.uns[self.ATTRS_KEY] = attrs = {} + table_region_key = attrs.get(self.REGION_KEY_KEY) + if not region_key: + if not table_region_key: + raise ValueError( + "No region_key in table.uns and no region_key provided as argument. Please specify 'region_key'." + ) + if data.obs.get(attrs[TableModel.REGION_KEY_KEY]) is None: + raise ValueError( + f"Specified region_key in table.uns '{table_region_key}' is not " + f"present as column in table.obs. Please specify region_key." + ) + else: + if region_key not in data.obs: + raise ValueError(f"'{region_key}' column not present in table.obs") + attrs[self.REGION_KEY_KEY] = region_key + + def _validate_set_instance_key(self, data: AnnData, instance_key: str | None = None) -> None: + """ + Validate the instance_key in table.uns or set a new instance_key as the instance_key column. + + If no instance_key is provided as argument, the presence of instance_key in table.uns is checked and validated. + If instance_key is provided, presence in table.obs will be validated and if present it will be set as the new + instance_key in table.uns. + + Parameters + ---------- + data + The AnnData table. + + instance_key + The instance_key to be validated and set in table.uns. + + Raises + ------ + ValueError + If no instance_key is provided as argument and no instance_key is found in the `uns` attribute of table. + ValueError + If no instance_key is provided and the instance_key in table.uns does not match any column in table.obs. + ValueError + If provided instance_key is not present as table.obs column. + """ + attrs = data.uns.get(self.ATTRS_KEY) + if attrs is None: + data.uns[self.ATTRS_KEY] = {} + + if not instance_key: + if not attrs.get(TableModel.INSTANCE_KEY): + raise ValueError( + "No instance_key in table.uns and no instance_key provided as argument. Please " + "specify instance_key." 
+                )
+            if data.obs.get(attrs[self.INSTANCE_KEY]) is None:
+                raise ValueError(
+                    f"Specified instance_key in table.uns '{attrs.get(self.INSTANCE_KEY)}' is not present"
+                    f" as column in table.obs. Please specify instance_key."
+                )
+        if instance_key:
+            if instance_key in data.obs:
+                attrs[self.INSTANCE_KEY] = instance_key
+            else:
+                raise ValueError(f"Instance key column '{instance_key}' not found in table.obs.")
+
+    def _validate_table_annotation_metadata(self, data: AnnData) -> None:
+        """
+        Validate annotation metadata.
+
+        Parameters
+        ----------
+        data
+            The AnnData object containing the table annotation data.
+
+        Raises
+        ------
+        ValueError
+            If any of the required metadata keys are not found in the `adata.uns` dictionary or the `adata.obs`
+            dataframe.
+
+            - If "region" is not found in `adata.uns['ATTRS_KEY']`.
+            - If "region_key" is not found in `adata.uns['ATTRS_KEY']`.
+            - If "instance_key" is not found in `adata.uns['ATTRS_KEY']`.
+            - If `attr[self.REGION_KEY_KEY]` is not found in `adata.obs`, with attr = adata.uns['ATTRS_KEY'].
+            - If `attr[self.INSTANCE_KEY]` is not found in `adata.obs`.
+            - If the regions in `adata.uns['ATTRS_KEY']['self.REGION_KEY']` and the unique values of
+              `attr[self.REGION_KEY_KEY]` do not match.
+
+        Notes
+        -----
+        This does not check whether the annotation target of the table is present in a given SpatialData object.
+        Rather, it is an internal validation of the annotation metadata of the table.
+        """
-        if self.ATTRS_KEY not in data.uns:
-            raise ValueError(f"`{self.ATTRS_KEY}` not found in `adata.uns`.")
         attr = data.uns[self.ATTRS_KEY]

         if "region" not in attr:
@@ -682,11 +794,37 @@ def validate(
             raise ValueError(f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`.")
         if attr[self.INSTANCE_KEY] not in data.obs:
             raise ValueError(f"`{attr[self.INSTANCE_KEY]}` not found in `adata.obs`.")
+        if (dtype := data.obs[attr[self.INSTANCE_KEY]].dtype) not in [np.int16, np.int32, np.int64, str]:
+            raise TypeError(
+                f"Only np.int16, np.int32, np.int64 or string allowed as dtype for "
+                f"instance_key column in obs. Dtype found to be {dtype}."
+            )
         expected_regions = attr[self.REGION_KEY] if isinstance(attr[self.REGION_KEY], list) else [attr[self.REGION_KEY]]
         found_regions = data.obs[attr[self.REGION_KEY_KEY]].unique().tolist()
         if len(set(expected_regions).symmetric_difference(set(found_regions))) > 0:
             raise ValueError(f"Regions in the AnnData object and `{attr[self.REGION_KEY_KEY]}` do not match.")

+    def validate(
+        self,
+        data: AnnData,
+    ) -> AnnData:
+        """
+        Validate the data.
+
+        Parameters
+        ----------
+        data
+            The data to validate.
+
+        Returns
+        -------
+        The validated data.
+        """
+        if self.ATTRS_KEY not in data.uns:
+            return data
+
+        self._validate_table_annotation_metadata(data)
+
         return data

     @classmethod
@@ -713,15 +851,17 @@ def parse(

         Returns
         -------
-        :class:`anndata.AnnData`.
+        The parsed data.
         """
         # either all live in adata.uns or all be passed in as argument
         n_args = sum([region is not None, region_key is not None, instance_key is not None])
+        if n_args == 0:
+            return adata
         if n_args > 0:
             if cls.ATTRS_KEY in adata.uns:
                 raise ValueError(
-                    f"Either pass `{cls.REGION_KEY}`, `{cls.REGION_KEY_KEY}` and `{cls.INSTANCE_KEY}`"
-                    f"as arguments or have them in `adata.uns[{cls.ATTRS_KEY!r}]`."
+                    f"`{cls.REGION_KEY}`, `{cls.REGION_KEY_KEY}` and / or `{cls.INSTANCE_KEY}` have been passed as "
+                    f"argument(s). However, `adata.uns[{cls.ATTRS_KEY!r}]` has already been set."
@@ -801,3 +951,73 @@ def _validate_and_return( if isinstance(e, AnnData): return _validate_and_return(TableModel, e) raise TypeError(f"Unsupported type {type(e)}") + + +def get_table_keys(table: AnnData) -> tuple[str | list[str], str, str]: + """ + Get the table keys giving information about what spatial element is annotated. + + The first element returned indicates which spatial element(s) are annotated by the table; the second is the name + of the column in table.obs that contains, for each row, the spatial element annotated by that row; the instance + key is the name of the column in table.obs that contains the id of each row. + + Parameters + ---------- + table + AnnData table for which to retrieve the spatialdata_attrs keys. + + Returns + ------- + The keys in table.uns['spatialdata_attrs']. + """ + if table.uns.get(TableModel.ATTRS_KEY): + attrs = table.uns[TableModel.ATTRS_KEY] + return attrs[TableModel.REGION_KEY], attrs[TableModel.REGION_KEY_KEY], attrs[TableModel.INSTANCE_KEY] + + raise ValueError( + "No spatialdata_attrs key found in table.uns, therefore no table keys found. Please parse the table." + ) + + +def check_target_region_column_symmetry(table: AnnData, region_key: str, target: str | pd.Series) -> None: + """ + Check region and region_key column symmetry. + + This checks whether the specified targets and the values of the region key column in obs match exactly, and + raises an error if they do not. + + Parameters + ---------- + table + The table annotating specific SpatialElements. + region_key + The column in obs that contains, for each row, which SpatialElement is annotated by that row. + target + The name(s) of the target SpatialElement(s). + + Raises + ------ + ValueError + If there is a mismatch between the specified target regions and the regions in the region key column of + table.obs. + + Example + ------- + Assume we have a table whose region column in obs, given by `region_key`, is called 'region', and we want to + check whether it contains the annotation targets specified in the `target` variable, e.g. + `pd.Series(['region1', 'region2'])`: + + ```python + check_target_region_column_symmetry(table, region_key=region_key, target=target) + ``` + + This returns None if both specified targets are present in the region_key obs column; in this case the annotation + targets can be safely set. If not, a ValueError is raised stating the elements that are not shared between + the region_key column in obs and the specified targets. + """ + found_regions = set(table.obs[region_key].unique().tolist()) + target_element_set = [target] if isinstance(target, str) else target + symmetric_difference = found_regions.symmetric_difference(target_element_set) + if symmetric_difference: + raise ValueError( + f"Mismatch(es) found between regions in region column in obs and target element: " + f"{', '.join(symmetric_difference)}" + )
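A quick sketch of `get_table_keys` on a parsed table (here on the bundled `blobs` dataset; the returned names depend on how the table was parsed):

```python
from spatialdata.datasets import blobs
from spatialdata.models import get_table_keys

sdata = blobs()
region, region_key, instance_key = get_table_keys(sdata["table"])
# `region` names the annotated element(s); table.obs[region_key] stores the annotated
# element per row and table.obs[instance_key] the id of the annotated instance
print(region, region_key, instance_key)
```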
+ """ + found_regions = set(table.obs[region_key].unique().tolist()) + target_element_set = [target] if isinstance(target, str) else target + symmetric_difference = found_regions.symmetric_difference(target_element_set) + if symmetric_difference: + raise ValueError( + f"Mismatch(es) found between regions in region column in obs and target element: " + f"{', '.join(diff for diff in symmetric_difference)}" + ) diff --git a/src/spatialdata/testing.py b/src/spatialdata/testing.py new file mode 100644 index 00000000..16f155bd --- /dev/null +++ b/src/spatialdata/testing.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +from anndata import AnnData +from anndata.tests.helpers import assert_equal as assert_anndata_equal +from dask.dataframe import DataFrame as DaskDataFrame +from dask.dataframe.tests.test_dataframe import assert_eq as assert_dask_dataframe_equal +from datatree.testing import assert_equal as assert_datatree_equal +from geopandas import GeoDataFrame +from geopandas.testing import assert_geodataframe_equal +from multiscale_spatial_image import MultiscaleSpatialImage +from spatial_image import SpatialImage +from xarray.testing import assert_equal as assert_xarray_equal + +from spatialdata import SpatialData +from spatialdata._core._elements import Elements +from spatialdata.models._utils import SpatialElement +from spatialdata.transformations.operations import get_transformation + + +def assert_elements_dict_are_identical( + elements0: Elements, elements1: Elements, check_transformations: bool = True +) -> None: + """ + Compare two dictionaries of elements and assert that they are identical (except for the order of the keys). + + The dictionaries of elements can be obtained from a SpatialData object using the `.shapes`, `.labels`, `.points`, + `.images` and `.tables` properties. + + Parameters + ---------- + elements0 + The first dictionary of elements. + elements1 + The second dictionary of elements. + + Returns + ------- + None + + Raises + ------ + AssertionError + If the two dictionaries of elements are not identical. + + Notes + ----- + With the current implementation, the transformations Translate([1.0, 2.0], + axes=('x', 'y')) and Translate([2.0, 1.0], axes=('y', 'x')) are considered different. + A quick way to avoid an error in this case is to use the check_transformations=False parameter. + """ + assert set(elements0.keys()) == set(elements1.keys()) + for k in elements0: + element0 = elements0[k] + element1 = elements1[k] + assert_elements_are_identical(element0, element1, check_transformations=check_transformations) + + +def assert_elements_are_identical( + element0: SpatialElement | AnnData, element1: SpatialElement | AnnData, check_transformations: bool = True +) -> None: + """ + Compare two elements (two SpatialElements or two tables) and assert that they are identical. + + Parameters + ---------- + element0 + The first element. + element1 + The second element. + check_transformations + Whether to check if the transformations are identical, for each element. + + Returns + ------- + None + + Raises + ------ + AssertionError + If the two elements are not identical. + + Notes + ----- + With the current implementation, the transformations Translate([1.0, 2.0], + axes=('x', 'y')) and Translate([2.0, 1.0], axes=('y', 'x')) are considered different. + A quick way to avoid an error in this case is to use the check_transformations=False parameter. 
+ """ + assert type(element0) == type(element1) + + # compare transformations (only for SpatialElements) + if not isinstance(element0, AnnData): + transformations0 = get_transformation(element0, get_all=True) + transformations1 = get_transformation(element1, get_all=True) + assert isinstance(transformations0, dict) + assert isinstance(transformations1, dict) + if check_transformations: + assert transformations0.keys() == transformations1.keys() + for key in transformations0: + assert ( + transformations0[key] == transformations1[key] + ), f"transformations0[{key}] != transformations1[{key}]" + + # compare the elements + if isinstance(element0, AnnData): + assert_anndata_equal(element0, element1) + elif isinstance(element0, SpatialImage): + assert_xarray_equal(element0, element1) + elif isinstance(element0, MultiscaleSpatialImage): + assert_datatree_equal(element0, element1) + elif isinstance(element0, GeoDataFrame): + assert_geodataframe_equal(element0, element1, check_less_precise=True) + else: + assert isinstance(element0, DaskDataFrame) + assert_dask_dataframe_equal(element0, element1) + + +def assert_spatial_data_objects_are_identical( + sdata0: SpatialData, sdata1: SpatialData, check_transformations: bool = True +) -> None: + """ + Compare two SpatialData objects and assert that they are identical. + + Parameters + ---------- + sdata0 + The first SpatialData object. + sdata1 + The second SpatialData object. + check_transformations + Whether to check if the transformations are identical, for each element. + + Returns + ------- + None + + Raises + ------ + AssertionError + If the two SpatialData objects are not identical. + + Notes + ----- + With the current implementation, the transformations Translate([1.0, 2.0], + axes=('x', 'y')) and Translate([2.0, 1.0], axes=('y', 'x')) are considered different. + A quick way to avoid an error in this case is to use the check_transformations=False parameter. 
+ """ + # this is not a full comparison, but it's fine anyway + element_names0 = [element_name for _, element_name, _ in sdata0.gen_elements()] + element_names1 = [element_name for _, element_name, _ in sdata1.gen_elements()] + assert len(set(element_names0)) == len(element_names0) + assert len(set(element_names1)) == len(element_names1) + assert set(sdata0.coordinate_systems) == set(sdata1.coordinate_systems) + for element_name in element_names0: + element0 = sdata0[element_name] + element1 = sdata1[element_name] + assert_elements_are_identical(element0, element1, check_transformations=check_transformations) diff --git a/src/spatialdata/transformations/operations.py b/src/spatialdata/transformations/operations.py index d639e025..9206d3e3 100644 --- a/src/spatialdata/transformations/operations.py +++ b/src/spatialdata/transformations/operations.py @@ -16,8 +16,8 @@ if TYPE_CHECKING: from spatialdata._core.spatialdata import SpatialData - from spatialdata.models import SpatialElement - from spatialdata.transformations import Affine, BaseTransformation + from spatialdata.models._utils import SpatialElement + from spatialdata.transformations.transformations import Affine, BaseTransformation def set_transformation( @@ -180,7 +180,7 @@ def remove_transformation( def _build_transformations_graph(sdata: SpatialData) -> nx.Graph: g = nx.DiGraph() - gen = sdata._gen_elements_values() + gen = sdata._gen_spatial_element_values() for cs in sdata.coordinate_systems: g.add_node(cs) for e in gen: @@ -329,7 +329,7 @@ def get_transformation_between_landmarks( example on how to call this function on two sets of numpy arrays describing x, y coordinates. >>> import numpy as np >>> from spatialdata.models import PointsModel - >>> from spatialdata.transform import get_transformation_between_landmarks + >>> from spatialdata.transformations import get_transformation_between_landmarks >>> points_moving = np.array([[0, 0], [1, 1], [2, 2]]) >>> points_reference = np.array([[0, 0], [10, 10], [20, 20]]) >>> moving_coords = PointsModel(points_moving) @@ -473,5 +473,5 @@ def remove_transformations_to_coordinate_system(sdata: SpatialData, coordinate_s coordinate_system The coordinate system to remove the transformations from. 
""" - for element in sdata._gen_elements_values(): + for element in sdata._gen_spatial_element_values(): remove_transformation(element, coordinate_system) diff --git a/tests/conftest.py b/tests/conftest.py index 8e95a0cf..66f65e0c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,8 @@ # isort: off import os +from typing import Any +from collections.abc import Sequence os.environ["USE_PYGEOS"] = "0" # isort:on @@ -33,7 +35,7 @@ from spatialdata.datasets import BlobsDataset import geopandas as gpd import dask.dataframe as dd -from spatialdata._utils import _deepcopy_geodataframe +from spatialdata._core._deepcopy import deepcopy as _deepcopy RNG = default_rng(seed=0) @@ -64,12 +66,12 @@ def points() -> SpatialData: @pytest.fixture() def table_single_annotation() -> SpatialData: - return SpatialData(table=_get_table(region="sample1")) + return SpatialData(tables=_get_table(region="labels2d")) @pytest.fixture() def table_multiple_annotations() -> SpatialData: - return SpatialData(table=_get_table(region=["sample1", "sample2"])) + return SpatialData(table=_get_table(region=["labels2d", "poly"])) @pytest.fixture() @@ -91,7 +93,7 @@ def full_sdata() -> SpatialData: labels=_get_labels(), shapes=_get_shapes(), points=_get_points(), - table=_get_table(region="sample1"), + tables=_get_table(region="labels2d"), ) @@ -126,7 +128,7 @@ def sdata(request) -> SpatialData: labels=_get_labels(), shapes=_get_shapes(), points=_get_points(), - table=_get_table("sample1"), + tables=_get_table("labels2d"), ) if request.param == "empty": return SpatialData() @@ -244,7 +246,7 @@ def _get_shapes() -> dict[str, GeoDataFrame]: points["radius"] = np.abs(rng.normal(size=(len(points), 1))) out["poly"] = ShapesModel.parse(poly) - out["poly"].index = ["a", "b", "c", "d", "e"] + out["poly"].index = [0, 1, 2, 3, 4] out["multipoly"] = ShapesModel.parse(multipoly) out["circles"] = ShapesModel.parse(points) @@ -277,11 +279,13 @@ def _get_points() -> dict[str, DaskDataFrame]: def _get_table( - region: str | list[str] = "sample1", - region_key: str = "region", - instance_key: str = "instance_id", + region: None | str | list[str] = "sample1", + region_key: None | str = "region", + instance_key: None | str = "instance_id", ) -> AnnData: adata = AnnData(RNG.normal(size=(100, 10)), obs=pd.DataFrame(RNG.normal(size=(100, 3)), columns=["a", "b", "c"])) + if not all(var for var in (region, region_key, instance_key)): + return TableModel.parse(adata=adata) adata.obs[instance_key] = np.arange(adata.n_obs) if isinstance(region, str): adata.obs[region_key] = region @@ -290,6 +294,11 @@ def _get_table( return TableModel.parse(adata=adata, region=region, region_key=region_key, instance_key=instance_key) +def _get_new_table(spatial_element: None | str | Sequence[str], instance_id: None | Sequence[Any]) -> AnnData: + adata = AnnData(np.random.default_rng(seed=0).random(10, 20000)) + return TableModel.parse(adata=adata, spatial_element=spatial_element, instance_id=instance_id) + + @pytest.fixture() def labels_blobs() -> ArrayLike: """Create a 2D labels.""" @@ -304,7 +313,7 @@ def sdata_blobs() -> SpatialData: sdata = deepcopy(blobs(256, 300, 3)) for k, v in sdata.shapes.items(): - sdata.shapes[k] = _deepcopy_geodataframe(v) + sdata.shapes[k] = _deepcopy(v) from spatialdata._utils import multiscale_spatial_image_from_data_tree sdata.images["blobs_multiscale_image"] = multiscale_spatial_image_from_data_tree( diff --git a/tests/core/operations/test_aggregations.py b/tests/core/operations/test_aggregations.py index 36609464..fcbf89b9 
diff --git a/tests/core/operations/test_aggregations.py b/tests/core/operations/test_aggregations.py index 36609464..fcbf89b9 100644 --- a/tests/core/operations/test_aggregations.py +++ b/tests/core/operations/test_aggregations.py @@ -1,4 +1,3 @@ -from copy import deepcopy from typing import Optional import geopandas @@ -10,9 +9,9 @@ from geopandas import GeoDataFrame from numpy.random import default_rng from spatialdata import aggregate +from spatialdata._core._deepcopy import deepcopy as _deepcopy from spatialdata._core.query._utils import circles_to_polygons from spatialdata._core.spatialdata import SpatialData -from spatialdata._utils import _deepcopy_geodataframe from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, TableModel from spatialdata.transformations import Affine, Identity, set_transformation @@ -44,10 +43,10 @@ def test_aggregate_points_by_shapes(sdata_query_aggregation, by_shapes: str, val shapes = sdata[by_shapes] # testing that we can call aggregate with the two equivalent syntaxes for the values argument - result_adata = aggregate(values=points, by=shapes, value_key=value_key, agg_func="sum").table + result_adata = aggregate(values=points, by=shapes, value_key=value_key, agg_func="sum").tables["table"] result_adata_bis = aggregate( - values_sdata=sdata, values="points", by=shapes, value_key=value_key, agg_func="sum" - ).table + values_sdata=sdata, values="points", by=shapes, value_key=value_key, agg_func="sum", table_name="table" + ).tables["table"] np.testing.assert_equal(result_adata.X.A, result_adata_bis.X.A) # check that the obs of aggregated values are correct @@ -75,12 +74,12 @@ # id_key can be implicit for points points.attrs[PointsModel.ATTRS_KEY][PointsModel.FEATURE_KEY] = value_key - result_adata_implicit = aggregate(values=points, by=shapes, agg_func="sum").table + result_adata_implicit = aggregate(values=points, by=shapes, agg_func="sum").tables["table"] assert_equal(result_adata, result_adata_implicit) # in the categorical case, check that sum and count behave the same if value_key == "categorical_in_ddf": - result_adata_count = aggregate(values=points, by=shapes, value_key=value_key, agg_func="count").table + result_adata_count = aggregate(values=points, by=shapes, value_key=value_key, agg_func="count").tables["table"] assert_equal(result_adata, result_adata_count) # querying multiple values at the same time @@ -91,7 +90,9 @@ aggregate(values=points, by=shapes, value_key=new_value_key, agg_func="sum") else: points["another_" + value_key] = points[value_key] + 10 - result_adata_multiple = aggregate(values=points, by=shapes, value_key=new_value_key, agg_func="sum").table + result_adata_multiple = aggregate(values=points, by=shapes, value_key=new_value_key, agg_func="sum").tables[ + "table" + ] assert result_adata_multiple.var_names.to_list() == new_value_key if by_shapes == "by_circles": row = ( @@ -124,6 +125,7 @@ ) +# TODO: refactor into smaller functions for easier understanding @pytest.mark.parametrize("by_shapes", ["by_circles", "by_polygons"]) @pytest.mark.parametrize("values_shapes", ["values_circles", "values_polygons"]) @pytest.mark.parametrize( @@ -143,12 +145,14 @@ def test_aggregate_shapes_by_shapes( by = _parse_shapes(sdata, by_shapes=by_shapes) values = _parse_shapes(sdata, values_shapes=values_shapes) - result_adata = aggregate(values_sdata=sdata, values=values_shapes, by=by, value_key=value_key,
agg_func="sum").table + result_adata = aggregate( + values_sdata=sdata, values=values_shapes, by=by, value_key=value_key, agg_func="sum", table_name="table" + ).tables["table"] # testing that we can call aggregate with the two equivalent syntaxes for the values argument (only relevant when # the values to aggregate are not in the table, for which only one of the two syntaxes is possible) if value_key.endswith("_in_gdf"): - result_adata_bis = aggregate(values=values, by=by, value_key=value_key, agg_func="sum").table + result_adata_bis = aggregate(values=values, by=by, value_key=value_key, agg_func="sum").tables["table"] np.testing.assert_equal(result_adata.X.A, result_adata_bis.X.A) # check that the obs of the aggregated values are correct @@ -161,41 +165,41 @@ def test_aggregate_shapes_by_shapes( if value_key == "numerical_in_var": if values_shapes == "values_circles": if by_shapes == "by_circles": - s = sdata.table[np.array([0, 1, 2, 3]), "numerical_in_var"].X.sum() + s = sdata.tables["table"][np.array([0, 1, 2, 3]), "numerical_in_var"].X.sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s], [0]]))) else: - s0 = sdata.table[np.array([5, 6, 7, 8]), "numerical_in_var"].X.sum() + s0 = sdata.tables["table"][np.array([5, 6, 7, 8]), "numerical_in_var"].X.sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s0], [0], [0], [0], [0]]))) else: if by_shapes == "by_circles": - s = sdata.table[np.array([9, 10, 11, 12]), "numerical_in_var"].X.sum() + s = sdata.tables["table"][np.array([9, 10, 11, 12]), "numerical_in_var"].X.sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s], [0]]))) else: - s0 = sdata.table[np.array([14, 15, 16, 17]), "numerical_in_var"].X.sum() - s1 = sdata.table[np.array([20]), "numerical_in_var"].X.sum() - s2 = sdata.table[np.array([20]), "numerical_in_var"].X.sum() + s0 = sdata.tables["table"][np.array([14, 15, 16, 17]), "numerical_in_var"].X.sum() + s1 = sdata.tables["table"][np.array([20]), "numerical_in_var"].X.sum() + s2 = sdata.tables["table"][np.array([20]), "numerical_in_var"].X.sum() s3 = 0 - s4 = sdata.table[np.array([18, 19]), "numerical_in_var"].X.sum() + s4 = sdata.tables["table"][np.array([18, 19]), "numerical_in_var"].X.sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s0], [s1], [s2], [s3], [s4]]))) elif value_key == "numerical_in_obs": # these cases are basically identically to the one above if values_shapes == "values_circles": if by_shapes == "by_circles": - s = sdata.table[np.array([0, 1, 2, 3]), :].obs["numerical_in_obs"].sum() + s = sdata.tables["table"][np.array([0, 1, 2, 3]), :].obs["numerical_in_obs"].sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s], [0]]))) else: - s0 = sdata.table[np.array([5, 6, 7, 8]), :].obs["numerical_in_obs"].sum() + s0 = sdata.tables["table"][np.array([5, 6, 7, 8]), :].obs["numerical_in_obs"].sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s0], [0], [0], [0], [0]]))) else: if by_shapes == "by_circles": - s = sdata.table[np.array([9, 10, 11, 12]), :].obs["numerical_in_obs"].sum() + s = sdata.tables["table"][np.array([9, 10, 11, 12]), :].obs["numerical_in_obs"].sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s], [0]]))) else: - s0 = sdata.table[np.array([14, 15, 16, 17]), :].obs["numerical_in_obs"].sum() - s1 = sdata.table[np.array([20]), :].obs["numerical_in_obs"].sum() - s2 = sdata.table[np.array([20]), :].obs["numerical_in_obs"].sum() + s0 = sdata.tables["table"][np.array([14, 15, 16, 17]), :].obs["numerical_in_obs"].sum() + s1 = 
sdata.tables["table"][np.array([20]), :].obs["numerical_in_obs"].sum() + s2 = sdata.tables["table"][np.array([20]), :].obs["numerical_in_obs"].sum() s3 = 0 - s4 = sdata.table[np.array([18, 19]), :].obs["numerical_in_obs"].sum() + s4 = sdata.tables["table"][np.array([18, 19]), :].obs["numerical_in_obs"].sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s0], [s1], [s2], [s3], [s4]]))) elif value_key == "numerical_in_gdf": if values_shapes == "values_circles": @@ -250,8 +254,8 @@ def test_aggregate_shapes_by_shapes( # in the categorical case, check that sum and count behave the same if value_key in ["categorical_in_obs", "categorical_in_gdf"]: result_adata_count = aggregate( - values_sdata=sdata, values=values_shapes, by=by, value_key=value_key, agg_func="count" - ).table + values_sdata=sdata, values=values_shapes, by=by, value_key=value_key, agg_func="count", table_name="table" + ).tables["table"] assert_equal(result_adata, result_adata_count) # querying multiple values at the same time @@ -259,23 +263,30 @@ def test_aggregate_shapes_by_shapes( if value_key in ["categorical_in_obs", "categorical_in_gdf"]: # can't aggregate multiple categorical values with pytest.raises(ValueError): - aggregate(values_sdata=sdata, values=values_shapes, by=by, value_key=new_value_key, agg_func="sum") + aggregate( + values_sdata=sdata, + values=values_shapes, + by=by, + value_key=new_value_key, + agg_func="sum", + table_name="table", + ) else: if value_key == "numerical_in_obs": - sdata.table.obs["another_numerical_in_obs"] = 1.0 + sdata.tables["table"].obs["another_numerical_in_obs"] = 1.0 elif value_key == "numerical_in_gdf": values["another_numerical_in_gdf"] = 1.0 else: assert value_key == "numerical_in_var" - new_var = pd.concat((sdata.table.var, pd.DataFrame(index=["another_numerical_in_var"]))) - new_x = np.concatenate((sdata.table.X, np.ones_like(sdata.table.X[:, :1])), axis=1) - new_table = AnnData(X=new_x, obs=sdata.table.obs, var=new_var, uns=sdata.table.uns) - del sdata.table - sdata.table = new_table + new_var = pd.concat((sdata.tables["table"].var, pd.DataFrame(index=["another_numerical_in_var"]))) + new_x = np.concatenate((sdata.tables["table"].X, np.ones_like(sdata.tables["table"].X[:, :1])), axis=1) + new_table = AnnData(X=new_x, obs=sdata.tables["table"].obs, var=new_var, uns=sdata.tables["table"].uns) + del sdata.tables["table"] + sdata.tables["table"] = new_table result_adata = aggregate( - values_sdata=sdata, values=values_shapes, by=by, value_key=new_value_key, agg_func="sum" - ).table + values_sdata=sdata, values=values_shapes, by=by, value_key=new_value_key, agg_func="sum", table_name="table" + ).tables["table"] assert result_adata.var_names.to_list() == new_value_key # since we added only columns of 1., we just have 4 cases to check all the aggregations, and not 12 like before @@ -306,6 +317,7 @@ def test_aggregate_shapes_by_shapes( by=by, value_key=value_key, agg_func="sum", + table_name="table", ) # test we can't aggregate from mixed categorical and numerical sources (let's just test one case) with pytest.raises(ValueError): @@ -315,6 +327,7 @@ def test_aggregate_shapes_by_shapes( by=by, value_key=["numerical_values_in_obs", "categorical_values_in_obs"], agg_func="sum", + table_name="table", ) @@ -326,15 +339,16 @@ def test_aggregate_image_by_labels(labels_blobs, image_schema, labels_schema) -> image = image_schema.parse(image) labels = labels_schema.parse(labels_blobs) - out = aggregate(values=image, by=labels, agg_func="mean").table + out_sdata = aggregate(values=image, 
@@ -347,12 +361,11 @@ def test_aggregate_requiring_alignment(sdata_blobs: SpatialData, values, by) -> by = sdata_blobs[by] if id(values) == id(by): # warning: this will give problems when aggregating labels by labels (not supported yet), because of this: https://github.com/scverse/spatialdata/issues/269 - by = deepcopy(by) - by = _deepcopy_geodataframe(by) + by = _deepcopy(by) assert by.attrs["transform"] is not values.attrs["transform"] sdata = SpatialData.init_from_elements({"values": values, "by": by}) - out0 = aggregate(values=values, by=by, agg_func="sum").table + out0 = aggregate(values=values, by=by, agg_func="sum").tables["table"] theta = np.pi / 7 affine = Affine( @@ -374,12 +387,12 @@ # both values and by map to the "other" coordinate system, but they are not aligned set_transformation(by, Identity(), "other") - out1 = aggregate(values=values, by=by, target_coordinate_system="other", agg_func="sum").table + out1 = aggregate(values=values, by=by, target_coordinate_system="other", agg_func="sum").tables["table"] assert not np.allclose(out0.X.A, out1.X.A) # both values and by map to the "other" coordinate system, and they are aligned set_transformation(by, affine, "other") - out2 = aggregate(values=values, by=by, target_coordinate_system="other", agg_func="sum").table + out2 = aggregate(values=values, by=by, target_coordinate_system="other", agg_func="sum").tables["table"] assert np.allclose(out0.X.A, out2.X.A) # actually transforming the data still leads to a correct result @@ -387,7 +400,9 @@ sdata2 = SpatialData.init_from_elements({"values": sdata["values"], "by": transformed_sdata["by"]}) # let's take values from the original sdata (non-transformed but aligned to 'other'); let's take by from the # transformed sdata - out3 = aggregate(values=sdata["values"], by=sdata2["by"], target_coordinate_system="other", agg_func="sum").table + out3 = aggregate(values=sdata["values"], by=sdata2["by"], target_coordinate_system="other", agg_func="sum").tables[ + "table" + ] assert np.allclose(out0.X.A, out3.X.A) @@ -406,7 +421,7 @@ def test_aggregate_considering_fractions_single_values( sdata = sdata_query_aggregation values = sdata[values_name] by = sdata[by_name] - result_adata = aggregate(values=values, by=by, value_key=value_key, agg_func="sum", fractions=True).table + result_adata = aggregate(values=values, by=by, value_key=value_key, agg_func="sum", fractions=True).tables["table"] # to manually compute the fractions of overlap that we use to test that aggregate() works values = circles_to_polygons(values) values["__index"] = values.index @@ -474,11 +489,11 @@ def test_aggregate_considering_fractions_multiple_values( sdata_query_aggregation: SpatialData, by_name,
values_name, value_key ) -> None: sdata = sdata_query_aggregation - new_var = pd.concat((sdata.table.var, pd.DataFrame(index=["another_numerical_in_var"]))) - new_x = np.concatenate((sdata.table.X, np.ones_like(sdata.table.X[:, :1])), axis=1) - new_table = AnnData(X=new_x, obs=sdata.table.obs, var=new_var, uns=sdata.table.uns) - del sdata.table - sdata.table = new_table + new_var = pd.concat((sdata.tables["table"].var, pd.DataFrame(index=["another_numerical_in_var"]))) + new_x = np.concatenate((sdata.tables["table"].X, np.ones_like(sdata.tables["table"].X[:, :1])), axis=1) + new_table = AnnData(X=new_x, obs=sdata.tables["table"].obs, var=new_var, uns=sdata.tables["table"].uns) + del sdata.tables["table"] + sdata.tables["table"] = new_table out = aggregate( values_sdata=sdata, values="values_circles", @@ -486,9 +501,10 @@ def test_aggregate_considering_fractions_multiple_values( value_key=["numerical_in_var", "another_numerical_in_var"], agg_func="sum", fractions=True, - ).table + table_name="table", + ).tables["table"] overlaps = np.array([0.655781239649211, 1.0000000000000002, 1.0000000000000004, 0.1349639285777728]) - row0 = np.sum(sdata.table.X[[0, 1, 2, 3], :] * overlaps.reshape(-1, 1), axis=0) + row0 = np.sum(sdata.tables["table"].X[[0, 1, 2, 3], :] * overlaps.reshape(-1, 1), axis=0) assert np.all(np.isclose(out.X.A, np.array([row0, [0, 0]]))) @@ -529,19 +545,19 @@ def test_aggregate_spatialdata(sdata_blobs: SpatialData) -> None: sdata2 = sdata_blobs.aggregate(values="blobs_points", by=sdata_blobs["blobs_polygons"], agg_func="sum") sdata3 = sdata_blobs.aggregate(values=sdata_blobs["blobs_points"], by=sdata_blobs["blobs_polygons"], agg_func="sum") - assert_equal(sdata0.table, sdata1.table) - assert_equal(sdata2.table, sdata3.table) + assert_equal(sdata0.tables["table"], sdata1.tables["table"]) + assert_equal(sdata2.tables["table"], sdata3.tables["table"]) # in sdata2 the name of the "by" region was not passed, so a default one is used - assert sdata2.table.obs["region"].value_counts()["by"] == 3 + assert sdata2.tables["table"].obs["region"].value_counts()["by"] == 3 # let's change it so we can make the objects comparable - sdata2.table.obs["region"] = "blobs_polygons" - sdata2.table.obs["region"] = sdata2.table.obs["region"].astype("category") - sdata2.table.uns[TableModel.ATTRS_KEY]["region"] = "blobs_polygons" - assert_equal(sdata0.table, sdata2.table) + sdata2.tables["table"].obs["region"] = "blobs_polygons" + sdata2.tables["table"].obs["region"] = sdata2.tables["table"].obs["region"].astype("category") + sdata2.tables["table"].uns[TableModel.ATTRS_KEY]["region"] = "blobs_polygons" + assert_equal(sdata0.tables["table"], sdata2.tables["table"]) assert len(sdata0.shapes["blobs_polygons"]) == 3 - assert sdata0.table.shape == (3, 2) + assert sdata0.tables["table"].shape == (3, 2) def test_aggregate_deepcopy(sdata_blobs: SpatialData) -> None: diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 8070d170..9e4e235e 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -5,23 +5,13 @@ import numpy as np import pytest from anndata import AnnData -from dask.dataframe.core import DataFrame as DaskDataFrame -from dask.delayed import Delayed -from geopandas import GeoDataFrame -from multiscale_spatial_image import MultiscaleSpatialImage -from spatial_image import SpatialImage from spatialdata._core.concatenate import _concatenate_tables, concatenate 
from spatialdata._core.data_extent import are_extents_equal, get_extent from spatialdata._core.operations._utils import transform_to_data_extent from spatialdata._core.spatialdata import SpatialData from spatialdata.datasets import blobs -from spatialdata.models import ( - Image2DModel, - Labels2DModel, - PointsModel, - ShapesModel, - TableModel, -) +from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel, get_table_keys +from spatialdata.testing import assert_elements_dict_are_identical, assert_spatial_data_objects_are_identical from spatialdata.transformations.operations import get_transformation, set_transformation from spatialdata.transformations.transformations import ( Affine, @@ -116,42 +106,9 @@ def test_element_names_unique() -> None: assert "shapes" not in sdata._shared_keys -def _assert_elements_left_to_right_seem_identical(sdata0: SpatialData, sdata1: SpatialData) -> None: - for element_type, element_name, element in sdata0._gen_elements(): - elements = sdata1.__getattribute__(element_type) - assert element_name in elements - element1 = elements[element_name] - if isinstance(element, (AnnData, SpatialImage, GeoDataFrame)): - assert element.shape == element1.shape - elif isinstance(element, DaskDataFrame): - for s0, s1 in zip(element.shape, element1.shape): - if isinstance(s0, Delayed): - s0 = s0.compute() - if isinstance(s1, Delayed): - s1 = s1.compute() - assert s0 == s1 - elif isinstance(element, MultiscaleSpatialImage): - assert len(element) == len(element1) - else: - raise TypeError(f"Unsupported type {type(element)}") - - -def _assert_tables_seem_identical(table0: AnnData | None, table1: AnnData | None) -> None: - assert table0 is None and table1 is None or table0.shape == table1.shape - - -def _assert_spatialdata_objects_seem_identical(sdata0: SpatialData, sdata1: SpatialData) -> None: - # this is not a full comparison, but it's fine anyway - assert len(list(sdata0._gen_elements())) == len(list(sdata1._gen_elements())) - assert set(sdata0.coordinate_systems) == set(sdata1.coordinate_systems) - _assert_elements_left_to_right_seem_identical(sdata0, sdata1) - _assert_elements_left_to_right_seem_identical(sdata1, sdata0) - _assert_tables_seem_identical(sdata0.table, sdata1.table) - - def test_filter_by_coordinate_system(full_sdata: SpatialData) -> None: sdata = full_sdata.filter_by_coordinate_system(coordinate_system="global", filter_table=False) - _assert_spatialdata_objects_seem_identical(sdata, full_sdata) + assert_spatial_data_objects_are_identical(sdata, full_sdata) scale = Scale([2.0], axes=("x",)) set_transformation(full_sdata.images["image2d"], scale, "my_space0") @@ -159,13 +116,13 @@ def test_filter_by_coordinate_system(full_sdata: SpatialData) -> None: set_transformation(full_sdata.shapes["poly"], Identity(), "my_space1") sdata_my_space = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False) - assert len(list(sdata_my_space._gen_elements())) == 2 - _assert_tables_seem_identical(sdata_my_space.table, full_sdata.table) + assert len(list(sdata_my_space.gen_elements())) == 3 + assert_elements_dict_are_identical(sdata_my_space.tables, full_sdata.tables) sdata_my_space1 = full_sdata.filter_by_coordinate_system( coordinate_system=["my_space0", "my_space1", "my_space2"], filter_table=False ) - assert len(list(sdata_my_space1._gen_elements())) == 3 + assert len(list(sdata_my_space1.gen_elements())) == 4 def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None: @@ -305,7 
+262,7 @@ test_concatenate_sdatas(full_sdata: SpatialData) -> None: set_transformation(full_sdata.shapes["circles"], Identity(), "my_space0") set_transformation(full_sdata.shapes["poly"], Identity(), "my_space1") filtered = full_sdata.filter_by_coordinate_system(coordinate_system=["my_space0", "my_space1"], filter_table=False) - assert len(list(filtered._gen_elements())) == 2 + assert len(list(filtered.gen_elements())) == 3 filtered0 = filtered.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False) filtered1 = filtered.filter_by_coordinate_system(coordinate_system="my_space1", filter_table=False) # this is needed because we can't handle regions with the same name. @@ -316,8 +273,8 @@ filtered1.table = table_new filtered1.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_region filtered1.table.obs[filtered1.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY]] = new_region - concatenated = concatenate([filtered0, filtered1]) - assert len(list(concatenated._gen_elements())) == 2 + concatenated = concatenate([filtered0, filtered1], concatenate_tables=True) + assert len(list(concatenated.gen_elements())) == 3 def test_locate_spatial_element(full_sdata: SpatialData) -> None: @@ -360,11 +317,12 @@ def test_no_shared_transformations() -> None: set_transformation(sdata.images[element_name], Identity(), to_coordinate_system=test_space) gen = sdata._gen_elements() - for _, name, obj in gen: - if name != element_name: - assert test_space not in get_transformation(obj, get_all=True) - else: - assert test_space in get_transformation(obj, get_all=True) + for element_type, name, obj in gen: + if element_type != "tables": + if name != element_name: + assert test_space not in get_transformation(obj, get_all=True) + else: + assert test_space in get_transformation(obj, get_all=True) def test_init_from_elements(full_sdata: SpatialData) -> None: @@ -375,27 +333,29 @@ def test_subset(full_sdata: SpatialData) -> None: - element_names = ["image2d", "labels2d", "points_0", "circles", "poly"] + element_names = ["image2d", "points_0", "circles", "poly"] subset0 = full_sdata.subset(element_names) unique_names = set() - for _, k, _ in subset0._gen_elements(): + for _, k, _ in subset0.gen_spatial_elements(): unique_names.add(k) assert "image3d_xarray" in full_sdata.images assert unique_names == set(element_names) + # no table since the labels are not present in the subset assert subset0.table is None adata = AnnData( shape=(10, 0), - obs={"region": ["circles"] * 5 + ["poly"] * 5, "instance_id": [0, 1, 2, 3, 4, "a", "b", "c", "d", "e"]}, + obs={"region": ["circles"] * 5 + ["poly"] * 5, "instance_id": [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]}, ) del full_sdata.table - full_sdata.table = TableModel.parse( - adata, region=["circles", "poly"], region_key="region", instance_key="instance_id" - ) - subset1 = full_sdata.subset(["poly"]) + sdata_table = TableModel.parse(adata, region=["circles", "poly"], region_key="region", instance_key="instance_id") + full_sdata.table = sdata_table + full_sdata.tables["second_table"] = sdata_table + subset1 = full_sdata.subset(["poly", "second_table"]) assert subset1.table is not None assert len(subset1.table) == 5 assert subset1.table.obs["region"].unique().tolist() == ["poly"] + assert len(subset1["second_table"]) == 10 @pytest.mark.parametrize("maintain_positioning", [True, False]) @@ -413,7 +373,7 @@ def
test_transform_to_data_extent(full_sdata: SpatialData, maintain_positioning: scale = Scale([2.0], axes=("x",)) translation = Translation([-100.0, 200.0], axes=("x", "y")) sequence = Sequence([rotation, scale, translation]) - for el in full_sdata._gen_elements_values(): + for el in full_sdata._gen_spatial_element_values(): set_transformation(el, sequence, "global") elements = [ "image2d", @@ -429,7 +389,7 @@ def test_transform_to_data_extent(full_sdata: SpatialData, maintain_positioning: sdata = transform_to_data_extent(full_sdata, "global", target_width=1000, maintain_positioning=maintain_positioning) matrices = [] - for el in sdata._gen_elements_values(): + for el in sdata._gen_spatial_element_values(): t = get_transformation(el, to_coordinate_system="global") assert isinstance(t, BaseTransformation) a = t.to_affine_matrix(input_axes=("x", "y", "z"), output_axes=("x", "y", "z")) @@ -449,5 +409,38 @@ def test_transform_to_data_extent(full_sdata: SpatialData, maintain_positioning: data_extent_after = get_extent(after, coordinate_system="global") # huge tolerance because of the bug with pixel perfectness assert are_extents_equal( - data_extent_before, data_extent_after, atol=3 + data_extent_before, data_extent_after, atol=4 ), f"data_extent_before: {data_extent_before}, data_extent_after: {data_extent_after} for element {element}" + + +def test_validate_table_in_spatialdata(full_sdata): + table = full_sdata["table"] + region, region_key, _ = get_table_keys(table) + assert region == "labels2d" + + full_sdata.validate_table_in_spatialdata(table) + + # dtype mismatch + full_sdata.labels["labels2d"] = Labels2DModel.parse(full_sdata.labels["labels2d"].astype("int16")) + with pytest.warns(UserWarning, match="that does not match the dtype of the indices of the annotated element"): + full_sdata.validate_table_in_spatialdata(table) + + # region not found + del full_sdata.labels["labels2d"] + with pytest.warns(UserWarning, match="in the SpatialData object"): + full_sdata.validate_table_in_spatialdata(table) + + table.obs[region_key] = "points_0" + full_sdata.set_table_annotates_spatialelement("table", region="points_0") + + full_sdata.validate_table_in_spatialdata(table) + + # dtype mismatch + full_sdata.points["points_0"].index = full_sdata.points["points_0"].index.astype("int16") + with pytest.warns(UserWarning, match="that does not match the dtype of the indices of the annotated element"): + full_sdata.validate_table_in_spatialdata(table) + + # region not found + del full_sdata.points["points_0"] + with pytest.warns(UserWarning, match="in the SpatialData object"): + full_sdata.validate_table_in_spatialdata(table) diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index 953de15c..5ad61dec 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -94,8 +94,7 @@ def test_physical_units(self, tmp_path: str, shapes: SpatialData) -> None: assert new_sdata.coordinate_systems["test"]._axes[0].unit == "micrometers" -def _get_affine(small_translation: bool = True) -> Affine: - theta = math.pi / 18 +def _get_affine(small_translation: bool = True, theta: float = math.pi / 18) -> Affine: k = 10.0 if small_translation else 1.0 return Affine( [ @@ -123,7 +122,7 @@ def _unpad_rasters(sdata: SpatialData) -> SpatialData: def _postpone_transformation( sdata: SpatialData, from_coordinate_system: str, to_coordinate_system: str, transformation: BaseTransformation ): - for element in sdata._gen_elements_values(): + for element in 
sdata._gen_spatial_element_values(): d = get_transformation(element, get_all=True) assert isinstance(d, dict) assert len(d) == 1 @@ -490,7 +489,7 @@ def test_transform_elements_and_entire_spatial_data_object(full_sdata: SpatialDa scale = Scale([k], axes=("x",)) translation = Translation([k], axes=("x",)) sequence = Sequence([scale, translation]) - for element in full_sdata._gen_elements_values(): + for element in full_sdata._gen_spatial_element_values(): set_transformation(element, sequence, "my_space") transformed_element = full_sdata.transform_element_to_coordinate_system( element, "my_space", maintain_positioning=maintain_positioning @@ -524,7 +523,7 @@ def test_transform_elements_and_entire_spatial_data_object_multi_hop( ): k = 10.0 scale = Scale([k], axes=("x",)) - for element in full_sdata._gen_elements_values(): + for element in full_sdata._gen_spatial_element_values(): set_transformation(element, scale, "my_space") # testing the scenario "element1 -> cs1 <- element2 -> cs2" and transforming element1 to cs2 @@ -535,13 +534,13 @@ def test_transform_elements_and_entire_spatial_data_object_multi_hop( ) # otherwise we have multiple paths to go from my_space to multi_hop_space - for element in full_sdata._gen_elements_values(): + for element in full_sdata._gen_spatial_element_values(): d = get_transformation(element, get_all=True) assert isinstance(d, dict) if "global" in d: remove_transformation(element, "global") - for element in full_sdata._gen_elements_values(): + for element in full_sdata._gen_spatial_element_values(): transformed_element = full_sdata.transform_element_to_coordinate_system( element, "multi_hop_space", maintain_positioning=maintain_positioning ) diff --git a/tests/core/operations/test_vectorize.py b/tests/core/operations/test_vectorize.py new file mode 100644 index 00000000..a0a7306e --- /dev/null +++ b/tests/core/operations/test_vectorize.py @@ -0,0 +1,68 @@ +import numpy as np +import pytest +from geopandas import GeoDataFrame +from shapely import Point +from spatialdata._core.operations.vectorize import to_circles +from spatialdata.datasets import blobs +from spatialdata.models.models import ShapesModel +from spatialdata.testing import assert_elements_are_identical + +# each of the tests operates on different elements, hence we can initialize the data once without conflicts +sdata = blobs() + + +@pytest.mark.parametrize("is_multiscale", [False, True]) +def test_labels_2d_to_circles(is_multiscale: bool) -> None: + key = "blobs" + ("_multiscale" if is_multiscale else "") + "_labels" + element = sdata[key] + new_circles = to_circles(element) + + assert np.isclose(new_circles.loc[1].geometry.x, 330.59258152354386) + assert np.isclose(new_circles.loc[1].geometry.y, 78.85026897788404) + assert np.isclose(new_circles.loc[1].radius, 69.229993) + assert 7 not in new_circles.index + + +@pytest.mark.skip(reason="Not implemented") +# @pytest.mark.parametrize("background", [0, 1]) +# @pytest.mark.parametrize("is_multiscale", [False, True]) +def test_labels_3d_to_circles() -> None: + pass + + +def test_circles_to_circles() -> None: + element = sdata["blobs_circles"] + new_circles = to_circles(element) + assert_elements_are_identical(element, new_circles) + + +def test_polygons_to_circles() -> None: + element = sdata["blobs_polygons"].iloc[:2] + new_circles = to_circles(element) + + data = { + "geometry": [Point(315.8120722406787, 220.18894606643332), Point(270.1386975678398, 417.8747936281634)], + "radius": [16.608781, 17.541365], + } + expected = 
ShapesModel.parse(GeoDataFrame(data, geometry="geometry")) + + assert_elements_are_identical(new_circles, expected) + + +def test_multipolygons_to_circles() -> None: + element = sdata["blobs_multipolygons"] + new_circles = to_circles(element) + + data = { + "geometry": [Point(340.37951022629096, 250.76310705786318), Point(337.1680699150594, 316.39984581697314)], + "radius": [23.488363, 19.059285], + } + expected = ShapesModel.parse(GeoDataFrame(data, geometry="geometry")) + assert_elements_are_identical(new_circles, expected) + + +def test_points_images_to_circles() -> None: + with pytest.raises(RuntimeError, match=r"Cannot apply to_circles\(\) to images."): + to_circles(sdata["blobs_image"]) + with pytest.raises(RuntimeError, match="Unsupported type"): + to_circles(sdata["blobs_points"]) diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index 6cb4daec..db4cef3c 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -3,17 +3,16 @@ import pytest from anndata import AnnData from spatialdata import get_values, match_table_to_element -from spatialdata._core.query.relational_query import _locate_value, _ValueOrigin +from spatialdata._core.query.relational_query import ( + _get_element_annotators, + _locate_value, + _ValueOrigin, + join_sdata_spatialelement_table, +) from spatialdata.models.models import TableModel def test_match_table_to_element(sdata_query_aggregation): - # table can't annotate points - with pytest.raises(AssertionError): - match_table_to_element(sdata=sdata_query_aggregation, element_name="points") - # table is not annotating "by_circles" - with pytest.raises(AssertionError, match="No row matches in the table annotates the element"): - match_table_to_element(sdata=sdata_query_aggregation, element_name="by_circles") matched_table = match_table_to_element(sdata=sdata_query_aggregation, element_name="values_circles") arr = np.array(list(reversed(sdata_query_aggregation["values_circles"].index))) sdata_query_aggregation["values_circles"].index = arr @@ -23,6 +22,180 @@ def test_match_table_to_element(sdata_query_aggregation): # TODO: add tests for labels +def test_join_using_string_instance_id_and_index(sdata_query_aggregation): + sdata_query_aggregation["table"].obs["instance_id"] = [ + f"string_{i}" for i in sdata_query_aggregation["table"].obs["instance_id"] + ] + sdata_query_aggregation["values_circles"].index = pd.Index( + [f"string_{i}" for i in sdata_query_aggregation["values_circles"].index] + ) + sdata_query_aggregation["values_polygons"].index = pd.Index( + [f"string_{i}" for i in sdata_query_aggregation["values_polygons"].index] + ) + + sdata_query_aggregation["values_polygons"] = sdata_query_aggregation["values_polygons"][:5] + sdata_query_aggregation["values_circles"] = sdata_query_aggregation["values_circles"][:5] + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "inner" + ) + # Note that we started with 21 n_obs. 
+ assert table.n_obs == 10 + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "right_exclusive" + ) + assert table.n_obs == 11 + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "right" + ) + assert table.n_obs == 21 + + +def test_left_inner_right_exclusive_join(sdata_query_aggregation): + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, "values_polygons", "table", "right_exclusive" + ) + assert table is None + assert all(element_dict[key] is None for key in element_dict) + + sdata_query_aggregation["values_polygons"] = sdata_query_aggregation["values_polygons"].drop([10, 11]) + with pytest.raises(AssertionError, match="No table with"): + join_sdata_spatialelement_table(sdata_query_aggregation, "values_polygons", "not_existing_table", "left") + + # Should we reindex before returning the table? + element_dict, table = join_sdata_spatialelement_table(sdata_query_aggregation, "values_polygons", "table", "left") + assert all(element_dict["values_polygons"].index == table.obs["instance_id"].values) + + # Check no matches in table for element not annotated by table + element_dict, table = join_sdata_spatialelement_table(sdata_query_aggregation, "by_polygons", "table", "left") + assert table is None + assert element_dict["by_polygons"] is sdata_query_aggregation["by_polygons"] + + # Check multiple elements, one of which not annotated by table + with pytest.warns(UserWarning, match="The element"): + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["by_polygons", "values_polygons"], "table", "left" + ) + assert "by_polygons" in element_dict + + # check multiple elements joined to table. + sdata_query_aggregation["values_circles"] = sdata_query_aggregation["values_circles"].drop([7, 8]) + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "left" + ) + indices = pd.concat( + [element_dict["values_circles"].index.to_series(), element_dict["values_polygons"].index.to_series()] + ) + assert all(table.obs["instance_id"] == indices.values) + + with pytest.warns(UserWarning, match="The element"): + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons", "by_polygons"], "table", "right_exclusive" + ) + assert all(element_dict[key] is None for key in element_dict) + assert all(table.obs.index == ["7", "8", "19", "20"]) + assert all(table.obs["instance_id"].values == [7, 8, 10, 11]) + assert all(table.obs["region"].values == ["values_circles", "values_circles", "values_polygons", "values_polygons"]) + + # the triggered warning is: UserWarning: The element `{name}` is not annotated by the table. 
Skipping + with pytest.warns(UserWarning, match="The element"): + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons", "by_polygons"], "table", "inner" + ) + indices = pd.concat( + [element_dict["values_circles"].index.to_series(), element_dict["values_polygons"].index.to_series()] + ) + assert all(table.obs["instance_id"] == indices.values) + assert element_dict["by_polygons"] is None + + +def test_join_spatialelement_table_fail(full_sdata): + with pytest.warns(UserWarning, match="Images:"): + join_sdata_spatialelement_table(full_sdata, ["image2d", "labels2d"], "table", "left_exclusive") + with pytest.warns(UserWarning, match="Tables:"): + join_sdata_spatialelement_table(full_sdata, ["labels2d", "table"], "table", "left_exclusive") + with pytest.raises(TypeError, match="`not_join` is not a"): + join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "not_join") + + +def test_left_exclusive_and_right_join(sdata_query_aggregation): + # Test case in which all table rows match rows in elements + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "left_exclusive" + ) + assert all(element_dict[key] is None for key in element_dict) + assert table is None + + # Dropped indices correspond to instance ids 7, 8 for 'values_circles' and 10, 11 for 'values_polygons' + sdata_query_aggregation["table"] = sdata_query_aggregation["table"][ + sdata_query_aggregation["table"].obs.index.drop(["7", "8", "19", "20"]) + ] + with pytest.warns(UserWarning, match="The element"): + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_polygons", "by_polygons"], "table", "left_exclusive" + ) + assert table is None + assert not set(element_dict["values_polygons"].index).issubset(sdata_query_aggregation["table"].obs["instance_id"]) + + # test right join + with pytest.warns(UserWarning, match="The element"): + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons", "by_polygons"], "table", "right" + ) + assert table is sdata_query_aggregation["table"] + assert not {7, 8}.issubset(element_dict["values_circles"].index) + assert not {10, 11}.issubset(element_dict["values_polygons"].index) + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "left_exclusive" + ) + assert table is None + assert not np.array_equal( + sdata_query_aggregation["table"].obs.iloc[7:9]["instance_id"].values, + element_dict["values_circles"].index.values, + ) + assert not np.array_equal( + sdata_query_aggregation["table"].obs.iloc[19:21]["instance_id"].values, + element_dict["values_polygons"].index.values, + ) + + +def test_match_rows_join(sdata_query_aggregation): + reversed_instance_id = [3, 4, 5, 6, 7, 8, 1, 2, 0] + list(reversed(range(12))) + original_instance_id = sdata_query_aggregation.table.obs["instance_id"] + sdata_query_aggregation.table.obs["instance_id"] = reversed_instance_id + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "left", match_rows="left" + ) + assert all(table.obs["instance_id"].values == original_instance_id.values) + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "right", match_rows="right" + ) + indices = 
[*element_dict["values_circles"].index, *element_dict[("values_polygons")].index] + assert all(indices == table.obs["instance_id"]) + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "inner", match_rows="left" + ) + assert all(table.obs["instance_id"].values == original_instance_id.values) + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "inner", match_rows="right" + ) + indices = [*element_dict["values_circles"].index, *element_dict[("values_polygons")].index] + assert all(indices == table.obs["instance_id"]) + + # check whether table ordering is preserved if not matching + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "left" + ) + assert all(table.obs["instance_id"] == reversed_instance_id) + + def test_locate_value(sdata_query_aggregation): def _check_location(locations: list[_ValueOrigin], origin: str, is_categorical: bool): assert len(locations) == 1 @@ -31,45 +204,74 @@ def _check_location(locations: list[_ValueOrigin], origin: str, is_categorical: # var, numerical _check_location( - _locate_value(value_key="numerical_in_var", sdata=sdata_query_aggregation, element_name="values_circles"), + _locate_value( + value_key="numerical_in_var", + sdata=sdata_query_aggregation, + element_name="values_circles", + table_name="table", + ), origin="var", is_categorical=False, ) # obs, categorical _check_location( - _locate_value(value_key="categorical_in_obs", sdata=sdata_query_aggregation, element_name="values_circles"), + _locate_value( + value_key="categorical_in_obs", + sdata=sdata_query_aggregation, + element_name="values_circles", + table_name="table", + ), origin="obs", is_categorical=True, ) # obs, numerical _check_location( - _locate_value(value_key="numerical_in_obs", sdata=sdata_query_aggregation, element_name="values_circles"), + _locate_value( + value_key="numerical_in_obs", + sdata=sdata_query_aggregation, + element_name="values_circles", + table_name="table", + ), origin="obs", is_categorical=False, ) # gdf, categorical # sdata + element_name _check_location( - _locate_value(value_key="categorical_in_gdf", sdata=sdata_query_aggregation, element_name="values_circles"), + _locate_value( + value_key="categorical_in_gdf", + sdata=sdata_query_aggregation, + element_name="values_circles", + table_name="table", + ), origin="df", is_categorical=True, ) # element _check_location( - _locate_value(value_key="categorical_in_gdf", element=sdata_query_aggregation["values_circles"]), + _locate_value( + value_key="categorical_in_gdf", element=sdata_query_aggregation["values_circles"], table_name="table" + ), origin="df", is_categorical=True, ) # gdf, numerical # sdata + element_name _check_location( - _locate_value(value_key="numerical_in_gdf", sdata=sdata_query_aggregation, element_name="values_circles"), + _locate_value( + value_key="numerical_in_gdf", + sdata=sdata_query_aggregation, + element_name="values_circles", + table_name="table", + ), origin="df", is_categorical=False, ) # element _check_location( - _locate_value(value_key="numerical_in_gdf", element=sdata_query_aggregation["values_circles"]), + _locate_value( + value_key="numerical_in_gdf", element=sdata_query_aggregation["values_circles"], table_name="table" + ), origin="df", is_categorical=False, ) @@ -103,7 +305,9 @@ def _check_location(locations: list[_ValueOrigin], origin: str, 
         origin="df",
         is_categorical=False,
     )
@@ -103,7 +305,9 @@ def _check_location(locations: list[_ValueOrigin], origin: str, is_categorical:
 def test_get_values_df(sdata_query_aggregation):
     # test with a single value, in the dataframe; using sdata + element_name
-    v = get_values(value_key="numerical_in_gdf", sdata=sdata_query_aggregation, element_name="values_circles")
+    v = get_values(
+        value_key="numerical_in_gdf", sdata=sdata_query_aggregation, element_name="values_circles", table_name="table"
+    )
     assert v.shape == (9, 1)
 
     # test with multiple values, in the dataframe; using element
@@ -114,7 +318,9 @@ def test_get_values_df(sdata_query_aggregation):
     assert v.shape == (9, 2)
 
     # test with a single value, in the obs
-    v = get_values(value_key="numerical_in_obs", sdata=sdata_query_aggregation, element_name="values_circles")
+    v = get_values(
+        value_key="numerical_in_obs", sdata=sdata_query_aggregation, element_name="values_circles", table_name="table"
+    )
     assert v.shape == (9, 1)
 
     # test with multiple values, in the obs
@@ -123,11 +329,14 @@ def test_get_values_df(sdata_query_aggregation):
         value_key=["numerical_in_obs", "another_numerical_in_obs"],
         sdata=sdata_query_aggregation,
         element_name="values_circles",
+        table_name="table",
     )
     assert v.shape == (9, 2)
 
     # test with a single value, in the var
-    v = get_values(value_key="numerical_in_var", sdata=sdata_query_aggregation, element_name="values_circles")
+    v = get_values(
+        value_key="numerical_in_var", sdata=sdata_query_aggregation, element_name="values_circles", table_name="table"
+    )
     assert v.shape == (9, 1)
 
     # test with multiple values, in the var
@@ -145,6 +354,7 @@ def test_get_values_df(sdata_query_aggregation):
         value_key=["numerical_in_var", "another_numerical_in_var"],
         sdata=sdata_query_aggregation,
         element_name="values_circles",
+        table_name="table",
     )
     assert v.shape == (9, 2)
 
@@ -152,11 +362,18 @@
     # value found in multiple locations
     sdata_query_aggregation.table.obs["another_numerical_in_gdf"] = np.zeros(21)
     with pytest.raises(ValueError):
-        get_values(value_key="another_numerical_in_gdf", sdata=sdata_query_aggregation, element_name="values_circles")
+        get_values(
+            value_key="another_numerical_in_gdf",
+            sdata=sdata_query_aggregation,
+            element_name="values_circles",
+            table_name="table",
+        )
 
     # value not found
     with pytest.raises(ValueError):
-        get_values(value_key="not_present", sdata=sdata_query_aggregation, element_name="values_circles")
+        get_values(
+            value_key="not_present", sdata=sdata_query_aggregation, element_name="values_circles", table_name="table"
+        )
 
     # mixing categorical and numerical values
     with pytest.raises(ValueError):
@@ -164,6 +381,7 @@
         value_key=["numerical_in_gdf", "categorical_in_gdf"],
         sdata=sdata_query_aggregation,
         element_name="values_circles",
+        table_name="table",
    )
 
     # multiple categorical values
@@ -173,6 +391,7 @@
         value_key=["categorical_in_gdf", "another_categorical_in_gdf"],
         sdata=sdata_query_aggregation,
         element_name="values_circles",
+        table_name="table",
     )
 
     # mixing different origins
@@ -181,6 +400,7 @@
         value_key=["numerical_in_gdf", "numerical_in_obs"],
         sdata=sdata_query_aggregation,
         element_name="values_circles",
+        table_name="table",
     )
 
 
@@ -188,7 +408,7 @@ def test_get_values_labels_bug(sdata_blobs):
     # https://github.com/scverse/spatialdata-plot/issues/165
     from spatialdata import get_values
 
-    get_values("channel_0_sum", sdata=sdata_blobs, element_name="blobs_labels")
+    get_values("channel_0_sum", sdata=sdata_blobs, element_name="blobs_labels", table_name="table")
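+    # NOTE (editorial, hedged): table_name pins the lookup above to a single table; "table" is
+    # assumed to be the table annotating blobs_labels in the sdata_blobs fixture.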
 
 
 def test_filter_table_categorical_bug(shapes):
@@ -200,3 +420,83 @@ def test_filter_table_categorical_bug(shapes):
     adata_subset = adata[adata.obs["categorical"] == "a"].copy()
     shapes.table = adata_subset
     shapes.filter_by_coordinate_system("global")
+
+
+def test_labels_table_joins(full_sdata):
+    element_dict, table = join_sdata_spatialelement_table(
+        full_sdata,
+        "labels2d",
+        "table",
+        "left",
+    )
+    assert all(table.obs["instance_id"] == range(100))
+
+    full_sdata["table"].obs["instance_id"] = list(reversed(range(100)))
+
+    element_dict, table = join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "left", match_rows="left")
+    assert all(table.obs["instance_id"] == range(100))
+
+    with pytest.warns(UserWarning, match="Element type"):
+        join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "left_exclusive")
+
+    with pytest.warns(UserWarning, match="Element type"):
+        join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "inner")
+
+    with pytest.warns(UserWarning, match="Element type"):
+        join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "right")
+
+    # all labels are present in table so should return None
+    element_dict, table = join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "right_exclusive")
+    assert element_dict["labels2d"] is None
+    assert table is None
+
+
+def test_points_table_joins(full_sdata):
+    full_sdata["table"].uns["spatialdata_attrs"]["region"] = "points_0"
+    full_sdata["table"].obs["region"] = ["points_0"] * 100
+
+    element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "left")
+
+    # points should have the same number of rows as before and table as well
+    assert len(element_dict["points_0"]) == 300
+    assert all(table.obs["instance_id"] == range(100))
+
+    full_sdata["table"].obs["instance_id"] = list(reversed(range(100)))
+
+    element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "left", match_rows="left")
+    assert len(element_dict["points_0"]) == 300
+    assert all(table.obs["instance_id"] == range(100))
+
+    # We have 100 table instances so resulting length of points should be 200 as we started with 300
+    element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "left_exclusive")
+    assert len(element_dict["points_0"]) == 200
+    assert table is None
+
+    element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "inner")
+    assert len(element_dict["points_0"]) == 100
+    assert all(table.obs["instance_id"] == list(reversed(range(100))))
+
+    element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "right")
+    assert len(element_dict["points_0"]) == 100
+    assert all(table.obs["instance_id"] == list(reversed(range(100))))
+
+    element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "right", match_rows="right")
+    assert all(element_dict["points_0"].index.values.compute() == list(reversed(range(100))))
+    assert all(table.obs["instance_id"] == list(reversed(range(100))))
+
+    element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "right_exclusive")
+    assert element_dict["points_0"] is None
+    assert table is None
+
+
+def test_get_element_annotators(full_sdata):
+    names = _get_element_annotators(full_sdata, "points_0")
+    assert len(names) == 0
+
+    names = _get_element_annotators(full_sdata, "labels2d")
+    assert names == {"table"}
+
+    another_table = full_sdata.tables["table"].copy()
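+    # NOTE (editorial, hedged): _get_element_annotators is assumed to return the set of table
+    # names whose annotation metadata targets the given element, so registering the copy under
+    # a new name below should make both tables annotate "labels2d".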
+    full_sdata.tables["another_table"] = another_table
+    names = _get_element_annotators(full_sdata, "labels2d")
+    assert names == {"another_table", "table"}
diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py
index 26b4c472..a043cab0 100644
--- a/tests/core/query/test_spatial_query.py
+++ b/tests/core/query/test_spatial_query.py
@@ -27,12 +27,10 @@
     ShapesModel,
     TableModel,
 )
+from spatialdata.testing import assert_spatial_data_objects_are_identical
 from spatialdata.transformations import Identity, set_transformation
 
 from tests.conftest import _make_points, _make_squares
-from tests.core.operations.test_spatialdata_operations import (
-    _assert_spatialdata_objects_seem_identical,
-)
 
 
 def test_bounding_box_request_immutable():
@@ -186,8 +184,6 @@ def test_query_points_no_points():
     assert request is None
 
 
-# TODO: more tests can be added for spatial queries after the cases 2, 3, 4 are implemented
-# (see https://github.com/scverse/spatialdata/pull/151, also for details on more tests)
 @pytest.mark.parametrize("n_channels", [1, 2, 3])
 @pytest.mark.parametrize("is_labels", [True, False])
 @pytest.mark.parametrize("is_3d", [True, False])
@@ -360,15 +356,15 @@ def test_query_spatial_data(full_sdata):
     result1 = full_sdata.query(request, filter_table=True)
     result2 = full_sdata.query.bounding_box(**request.to_dict(), filter_table=True)
 
-    _assert_spatialdata_objects_seem_identical(result0, result1)
-    _assert_spatialdata_objects_seem_identical(result0, result2)
+    assert_spatial_data_objects_are_identical(result0, result1)
+    assert_spatial_data_objects_are_identical(result0, result2)
 
     polygon = Polygon([(1, 2), (60, 2), (60, 40), (1, 40)])
     result3 = polygon_query(full_sdata, polygon=polygon, target_coordinate_system="global", filter_table=True)
     result4 = full_sdata.query.polygon(polygon=polygon, target_coordinate_system="global", filter_table=True)
 
-    _assert_spatialdata_objects_seem_identical(result0, result3)
-    _assert_spatialdata_objects_seem_identical(result0, result4)
+    assert_spatial_data_objects_are_identical(result0, result3, check_transformations=False)
+    assert_spatial_data_objects_are_identical(result0, result4, check_transformations=False)
 
 
 @pytest.mark.parametrize("with_polygon_query", [True, False])
@@ -421,7 +417,7 @@ def test_polygon_query_with_multipolygon(sdata_query_aggregation):
     sdata = sdata_query_aggregation
     values_sdata = SpatialData(
         shapes={"values_polygons": sdata["values_polygons"], "values_circles": sdata["values_circles"]},
-        table=sdata.table,
+        tables=sdata.table,
     )
     polygon = sdata["by_polygons"].geometry.iloc[0]
     circle = sdata["by_circles"].geometry.iloc[0]
diff --git a/tests/core/test_data_extent.py b/tests/core/test_data_extent.py
index d7304ddf..94a1216f 100644
--- a/tests/core/test_data_extent.py
+++ b/tests/core/test_data_extent.py
@@ -7,7 +7,7 @@
 from numpy.random import default_rng
 from shapely.geometry import MultiPolygon, Point, Polygon
 from spatialdata import SpatialData, get_extent, transform
-from spatialdata._utils import _deepcopy_geodataframe
+from spatialdata._core._deepcopy import deepcopy as _deepcopy
 from spatialdata.datasets import blobs
 from spatialdata.models import Image2DModel, PointsModel, ShapesModel
 from spatialdata.transformations import Affine, Translation, remove_transformation, set_transformation
@@ -237,7 +237,7 @@ def test_get_extent_affine_circles():
     affine = _get_affine(small_translation=True)
 
     # let's do a deepcopy of the circles since we don't want to modify the original data
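+    # NOTE (editorial, hedged): _deepcopy is the new public spatialdata.deepcopy (see the
+    # import change above); unlike the old _deepcopy_geodataframe helper it is assumed to
+    # handle any SpatialElement, not just GeoDataFrames.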
-    circles = _deepcopy_geodataframe(sdata["blobs_circles"])
+    circles = _deepcopy(sdata["blobs_circles"])
 
     set_transformation(element=circles, transformation=affine, to_coordinate_system="transformed")
 
@@ -304,8 +304,8 @@ def test_get_extent_affine_sdata():
     # let's make a copy since we don't want to modify the original data
     sdata2 = SpatialData(
         shapes={
-            "circles": _deepcopy_geodataframe(sdata["blobs_circles"]),
-            "polygons": _deepcopy_geodataframe(sdata["blobs_polygons"]),
+            "circles": _deepcopy(sdata["blobs_circles"]),
+            "polygons": _deepcopy(sdata["blobs_polygons"]),
         }
     )
     translation0 = Translation([10], axes=("x",))
diff --git a/tests/core/test_deepcopy.py b/tests/core/test_deepcopy.py
new file mode 100644
index 00000000..7c3bcae5
--- /dev/null
+++ b/tests/core/test_deepcopy.py
@@ -0,0 +1,46 @@
+from pandas.testing import assert_frame_equal
+from spatialdata._core._deepcopy import deepcopy as _deepcopy
+from spatialdata.testing import assert_spatial_data_objects_are_identical
+
+
+def test_deepcopy(full_sdata):
+    to_delete = []
+    for element_type, element_name in to_delete:
+        del getattr(full_sdata, element_type)[element_name]
+
+    copied = _deepcopy(full_sdata)
+    # we first compute() the data in-place, then deepcopy and then we make the data lazy again; if the last step is
+    # missing, calling _deepcopy() again on the original data would fail. Here we check for that.
+    copied_again = _deepcopy(full_sdata)
+
+    # workaround for https://github.com/scverse/spatialdata/issues/486
+    for _, element_name, _ in full_sdata.gen_elements():
+        assert full_sdata[element_name] is not copied[element_name]
+        assert full_sdata[element_name] is not copied_again[element_name]
+        assert copied[element_name] is not copied_again[element_name]
+
+    p0_0 = full_sdata["points_0"].compute()
+    columns = list(p0_0.columns)
+    p0_1 = full_sdata["points_0_1"].compute()[columns]
+
+    p1_0 = copied["points_0"].compute()[columns]
+    p1_1 = copied["points_0_1"].compute()[columns]
+
+    p2_0 = copied_again["points_0"].compute()[columns]
+    p2_1 = copied_again["points_0_1"].compute()[columns]
+
+    assert_frame_equal(p0_0, p1_0)
+    assert_frame_equal(p0_1, p1_1)
+    assert_frame_equal(p0_0, p2_0)
+    assert_frame_equal(p0_1, p2_1)
+
+    del full_sdata.points["points_0"]
+    del full_sdata.points["points_0_1"]
+    del copied.points["points_0"]
+    del copied.points["points_0_1"]
+    del copied_again.points["points_0"]
+    del copied_again.points["points_0_1"]
+    # end workaround
+
+    assert_spatial_data_objects_are_identical(full_sdata, copied)
+    assert_spatial_data_objects_are_identical(full_sdata, copied_again)
diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py
index dac01e80..9c8e7c60 100644
--- a/tests/dataloader/test_datasets.py
+++ b/tests/dataloader/test_datasets.py
@@ -106,7 +106,7 @@ def test_return_annot(self, sdata_blobs, regions_element, return_annot):
     # TODO: consider adding this logic to blobs, to generate blobs with arbitrary table annotation
     def _annotate_shapes(self, sdata: SpatialData, shape: str) -> SpatialData:
         new_table = AnnData(
-            X=np.random.default_rng().random((len(sdata[shape]), 10)),
+            X=np.random.default_rng(0).random((len(sdata[shape]), 10)),
             obs=pd.DataFrame({"region": shape, "instance_id": sdata[shape].index.values}),
         )
         new_table = TableModel.parse(new_table, region=shape, region_key="region", instance_key="instance_id")
diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py
new file mode 100644
index 00000000..50d02677
--- /dev/null
+++ b/tests/io/test_multi_table.py
@@ -0,0 +1,248 @@
+from pathlib import Path
+
+import pytest
+from anndata import AnnData
+from anndata.tests.helpers import assert_equal
+from spatialdata import SpatialData, concatenate
+from spatialdata.models import TableModel
+
+from tests.conftest import _get_shapes, _get_table
+
+# notes on paths: https://github.com/orgs/scverse/projects/17/views/1?pane=issue&itemId=44066734
+test_shapes = _get_shapes()
+
+
+class TestMultiTable:
+    def test_set_get_tables_from_spatialdata(self, full_sdata: SpatialData, tmp_path: str):
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+        adata0 = _get_table(region="polygon")
+        adata1 = _get_table(region="multipolygon")
+        full_sdata["adata0"] = adata0
+        full_sdata["adata1"] = adata1
+
+        adata2 = adata0.copy()
+        del adata2.obs["region"]
+        # fails because either none or all three of 'region', 'region_key' and 'instance_key' must be specified
+        with pytest.raises(ValueError):
+            full_sdata["not_added_table"] = adata2
+
+        assert len(full_sdata.tables) == 3
+        assert "adata0" in full_sdata.tables and "adata1" in full_sdata.tables
+        full_sdata.write(tmpdir)
+
+        full_sdata = SpatialData.read(tmpdir)
+        assert_equal(adata0, full_sdata["adata0"])
+        assert_equal(adata1, full_sdata["adata1"])
+        assert "adata0" in full_sdata.tables and "adata1" in full_sdata.tables
+
+    @pytest.mark.parametrize(
+        "region_key, instance_key, error_msg",
+        [
+            (
+                None,
+                None,
+                "Specified instance_key in table.uns 'instance_id' is not present as column in table.obs. "
+                "Please specify instance_key.",
+            ),
+            (
+                "region",
+                None,
+                "Specified instance_key in table.uns 'instance_id' is not present as column in table.obs. "
+                "Please specify instance_key.",
+            ),
+            ("region", "instance_id", "Instance key column 'instance_id' not found in table.obs."),
+            (None, "instance_id", "Instance key column 'instance_id' not found in table.obs."),
+        ],
+    )
+    def test_change_annotation_target(self, full_sdata, region_key, instance_key, error_msg):
+        n_obs = full_sdata["table"].n_obs
+
+        with pytest.raises(
+            ValueError, match=r"Mismatch\(es\) found between regions in region column in obs and target element: "
+        ):
+            # ValueError: Mismatch(es) found between regions in region column in obs and target element: labels2d, poly
+            full_sdata.set_table_annotates_spatialelement("table", "poly")
+
+        del full_sdata["table"].obs["region"]
+        with pytest.raises(
+            ValueError,
+            match="Specified region_key in table.uns 'region' is not present as column in table.obs. "
+            "Please specify region_key.",
+        ):
+            full_sdata.set_table_annotates_spatialelement("table", "poly")
+
+        del full_sdata["table"].obs["instance_id"]
+        full_sdata["table"].obs["region"] = ["poly"] * n_obs
+        with pytest.raises(ValueError, match=error_msg):
+            full_sdata.set_table_annotates_spatialelement(
+                "table", "poly", region_key=region_key, instance_key=instance_key
+            )
+
+        full_sdata["table"].obs["instance_id"] = range(n_obs)
+        full_sdata.set_table_annotates_spatialelement(
+            "table", "poly", instance_key="instance_id", region_key=region_key
+        )
+
+        with pytest.raises(ValueError, match="'not_existing' column not present in table.obs"):
+            full_sdata.set_table_annotates_spatialelement("table", "circles", region_key="not_existing")
+
+    def test_set_table_nonexisting_target(self, full_sdata):
+        with pytest.raises(
+            ValueError,
+            match="Annotation target 'non_existing' not present as SpatialElement in SpatialData object.",
+        ):
+            full_sdata.set_table_annotates_spatialelement("table", "non_existing")
+
+    def test_set_table_annotates_spatialelement(self, full_sdata):
+        del full_sdata["table"].uns[TableModel.ATTRS_KEY]
+        with pytest.raises(
+            TypeError, match="No current annotation metadata found. Please specify both region_key and instance_key."
+        ):
+            full_sdata.set_table_annotates_spatialelement("table", "labels2d", region_key="non_existent")
+        with pytest.raises(ValueError, match="Instance key column 'non_existent' not found in table.obs."):
+            full_sdata.set_table_annotates_spatialelement(
+                "table", "labels2d", region_key="region", instance_key="non_existent"
+            )
+        with pytest.raises(ValueError, match="column not present"):
+            full_sdata.set_table_annotates_spatialelement(
+                "table", "labels2d", region_key="non_existing", instance_key="instance_id"
+            )
+        full_sdata.set_table_annotates_spatialelement(
+            "table", "labels2d", region_key="region", instance_key="instance_id"
+        )
+
+    def test_old_accessor_deprecation(self, full_sdata, tmp_path):
+        # To test self._backed
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+        full_sdata.write(tmpdir)
+        adata0 = _get_table(region="polygon")
+
+        with pytest.warns(DeprecationWarning):
+            _ = full_sdata.table
+        with pytest.raises(ValueError):
+            full_sdata.table = adata0
+        with pytest.warns(DeprecationWarning):
+            del full_sdata.table
+        with pytest.raises(KeyError):
+            del full_sdata.table
+        with pytest.warns(DeprecationWarning):
+            full_sdata.table = adata0  # this gets placed in sdata['table']
+
+        assert_equal(adata0, full_sdata.table)
+
+        del full_sdata.table
+
+        full_sdata.tables["my_new_table0"] = adata0
+        assert full_sdata.table is None
+
+    @pytest.mark.parametrize("region", ["test_shapes", "non_existing"])
+    def test_single_table(self, tmp_path: str, region: str):
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+        table = _get_table(region=region)
+
+        # Create shapes dictionary
+        shapes_dict = {
+            "test_shapes": test_shapes["poly"],
+        }
+
+        if region == "non_existing":
+            # annotation target not present in the SpatialData object
+            with pytest.warns(UserWarning, match=r", which is not present in the SpatialData object"):
+                SpatialData(
+                    shapes=shapes_dict,
+                    tables={"shape_annotate": table},
+                )
+
+        test_sdata = SpatialData(
+            shapes=shapes_dict,
+            tables={"shape_annotate": table},
+        )
+
+        test_sdata.write(tmpdir)
+        sdata = SpatialData.read(tmpdir)
+        assert isinstance(sdata["shape_annotate"], AnnData)
+        assert_equal(test_sdata["shape_annotate"], sdata["shape_annotate"])
+
+    def test_paired_elements_tables(self, tmp_path: str):
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+        table = _get_table(region="poly")
+        table2 = _get_table(region="multipoly")
+        table3 = _get_table(region="non_existing")
+        # annotation target not present in the SpatialData object
+        with pytest.warns(UserWarning, match=r", which is not present in the SpatialData object"):
+            SpatialData(
+                shapes={"poly": test_shapes["poly"], "multipoly": test_shapes["multipoly"]},
+                tables={"poly_annotate": table, "multipoly_annotate": table3},
+            )
+        test_sdata = SpatialData(
+            shapes={"poly": test_shapes["poly"], "multipoly": test_shapes["multipoly"]},
+            tables={"poly_annotate": table, "multipoly_annotate": table2},
+        )
+        test_sdata.write(tmpdir)
+        test_sdata = SpatialData.read(tmpdir)
+        assert len(test_sdata.tables) == 2
+
+    def test_single_table_multiple_elements(self, tmp_path: str):
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+        table = _get_table(region=["poly", "multipoly"])
+        subset = table[table.obs.region == "multipoly"]
+        with pytest.raises(ValueError, match="Regions in"):
+            TableModel().validate(subset)
+
+        test_sdata = SpatialData(
+            shapes={
+                "poly": test_shapes["poly"],
+                "multipoly": test_shapes["multipoly"],
+            },
+            tables=table,
+        )
+        test_sdata.write(tmpdir)
+        SpatialData.read(tmpdir)
+
+    def test_multiple_table_without_element(self, tmp_path: str):
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+        table = _get_table(region=None, region_key=None, instance_key=None)
+        table_two = _get_table(region=None, region_key=None, instance_key=None)
+
+        sdata = SpatialData(
+            tables={"table": table, "table_two": table_two},
+        )
+        sdata.write(tmpdir)
+        SpatialData.read(tmpdir)
+
+    def test_multiple_tables_same_element(self, tmp_path: str):
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+        table = _get_table(region="test_shapes")
+        table2 = _get_table(region="test_shapes")
+
+        test_sdata = SpatialData(
+            shapes={
+                "test_shapes": test_shapes["poly"],
+            },
+            tables={"table": table, "table2": table2},
+        )
+        test_sdata.write(tmpdir)
+        SpatialData.read(tmpdir)
+
+
+def test_concatenate_sdata_multitables():
+    sdatas = [
+        SpatialData(
+            shapes={f"poly_{i + 1}": test_shapes["poly"], f"multipoly_{i + 1}": test_shapes["multipoly"]},
+            tables={"table": _get_table(region=f"poly_{i + 1}"), "table2": _get_table(region=f"multipoly_{i + 1}")},
+        )
+        for i in range(3)
+    ]
+
+    with pytest.warns(
+        UserWarning,
+        match="Duplicate table names found.",
+    ):
+        concatenate(sdatas)
+
+    merged_sdata = concatenate(sdatas, concatenate_tables=True)
+    assert merged_sdata.tables["table"].n_obs == 300
+    assert merged_sdata.tables["table2"].n_obs == 300
+    assert all(merged_sdata.tables["table"].obs.region.unique() == ["poly_1", "poly_2", "poly_3"])
+    assert all(merged_sdata.tables["table2"].obs.region.unique() == ["multipoly_1", "multipoly_2", "multipoly_3"])
diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py
index e629182d..81d43864 100644
--- a/tests/io/test_readwrite.py
+++ b/tests/io/test_readwrite.py
@@ -14,7 +14,7 @@
 from spatial_image import SpatialImage
 from spatialdata import SpatialData, read_zarr
 from spatialdata._io._utils import _are_directories_identical
-from spatialdata.models import TableModel
+from spatialdata.models import Image2DModel, TableModel
 from spatialdata.transformations.operations import (
     get_transformation,
     set_transformation,
@@ -23,7 +23,7 @@
 
 from tests.conftest import _get_images, _get_labels, _get_points, _get_shapes
 
-RNG = default_rng()
+RNG = default_rng(0)
 
 
 class TestReadWrite:
@@ -319,3 +319,22 @@ def test_io_table(shapes):
     shapes2.table = adata
     assert shapes2.table is not None
     assert shapes2.table.shape == (5, 10)
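+
+
+# NOTE (editorial, hedged): the regression test below assumes that writing a bounding-box
+# query of chunked single- and multi-scale images exercises the rechunking bug reported in
+# https://github.com/scverse/spatialdata-io/issues/117.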
+def test_bug_rechunking_after_queried_raster():
+    # https://github.com/scverse/spatialdata-io/issues/117
+    single_scale = Image2DModel.parse(RNG.random((100, 10, 10)), chunks=(5, 5, 5))
+    multi_scale = Image2DModel.parse(RNG.random((100, 10, 10)), scale_factors=[2, 2], chunks=(5, 5, 5))
+    images = {"single_scale": single_scale, "multi_scale": multi_scale}
+    sdata = SpatialData(images=images)
+    queried = sdata.query.bounding_box(
+        axes=("x", "y"), min_coordinate=[2, 5], max_coordinate=[12, 12], target_coordinate_system="global"
+    )
+    with tempfile.TemporaryDirectory() as tmpdir:
+        f = os.path.join(tmpdir, "data.zarr")
+        queried.write(f)
diff --git a/tests/models/test_models.py b/tests/models/test_models.py
index 4f5033f1..116bdbe4 100644
--- a/tests/models/test_models.py
+++ b/tests/models/test_models.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import os
-import pathlib
+import re
 import tempfile
 from copy import deepcopy
 from functools import partial
@@ -119,7 +119,7 @@ def _parse_transformation_from_multiple_places(self, model: Any, element: Any, *
                 str,
                 np.ndarray,
                 dask.array.core.Array,
-                pathlib.PosixPath,
+                Path,
                 pd.DataFrame,
             )
         ):
@@ -318,7 +318,15 @@ def test_table_model(
         region: str | np.ndarray,
     ) -> None:
         region_key = "reg"
-        obs = pd.DataFrame(RNG.integers(0, 100, size=(10, 3)), columns=["A", "B", "C"])
+        obs = pd.DataFrame(
+            RNG.choice(np.arange(0, 100, dtype=float), size=(10, 3), replace=False), columns=["A", "B", "C"]
+        )
+        obs[region_key] = region
+        adata = AnnData(RNG.normal(size=(10, 2)), obs=obs)
+        with pytest.raises(TypeError, match="Only np.int16"):
+            model.parse(adata, region=region, region_key=region_key, instance_key="A")
+
+        obs = pd.DataFrame(RNG.choice(np.arange(0, 100), size=(10, 3), replace=False), columns=["A", "B", "C"])
         obs[region_key] = region
         adata = AnnData(RNG.normal(size=(10, 2)), obs=obs)
         table = model.parse(adata, region=region, region_key=region_key, instance_key="A")
@@ -332,6 +340,21 @@ def test_table_model(
         assert TableModel.REGION_KEY_KEY in table.uns[TableModel.ATTRS_KEY]
         assert table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] == region
 
+    @pytest.mark.parametrize("model", [TableModel])
+    @pytest.mark.parametrize("region", [["sample_1"] * 5 + ["sample_2"] * 5])
+    def test_table_instance_key_values_not_unique(self, model: TableModel, region: str | np.ndarray):
+        region_key = "region"
+        obs = pd.DataFrame(RNG.integers(0, 100, size=(10, 3)), columns=["A", "B", "C"])
+        obs[region_key] = region
+        obs["A"] = [1] * 5 + list(range(5))
+        adata = AnnData(RNG.normal(size=(10, 2)), obs=obs)
+        with pytest.raises(ValueError, match=re.escape("Instance key column for region(s) `sample_1`")):
+            model.parse(adata, region=region, region_key=region_key, instance_key="A")
+
+        adata.obs["A"] = [1] * 10
+        with pytest.raises(ValueError, match=re.escape("Instance key column for region(s) `sample_1, sample_2`")):
+            model.parse(adata, region=region, region_key=region_key, instance_key="A")
+
 
 def test_get_schema():
     images = _get_images()