From 45ab2df62e9ee11e4f2fd5fb01c9e891d86f1c43 Mon Sep 17 00:00:00 2001 From: giovp Date: Thu, 15 Jun 2023 18:05:20 +0000 Subject: [PATCH 01/38] fix docs --- docs/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index d8865901..c108e49e 100644 --- a/docs/api.md +++ b/docs/api.md @@ -108,7 +108,7 @@ The transformations that can be defined between elements and coordinate systems ## DataLoader ```{eval-rst} -.. currentmodule:: spatialdata.dataloader +.. currentmodule:: spatialdata.dataloader.datasets .. autosummary:: :toctree: generated From 0a7eecb0513e15ff6c01e9508e8f5a0b498bc176 Mon Sep 17 00:00:00 2001 From: giovp Date: Thu, 15 Jun 2023 22:03:30 +0000 Subject: [PATCH 02/38] update --- src/spatialdata/dataloader/datasets.py | 157 +++++++++++++++++++++---- 1 file changed, 135 insertions(+), 22 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 4d719288..849ebf89 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional import numpy as np +import pandas as pd from geopandas import GeoDataFrame from multiscale_spatial_image import MultiscaleSpatialImage from shapely import MultiPolygon, Point, Polygon @@ -10,6 +11,7 @@ from torch.utils.data import Dataset from spatialdata._core.operations.rasterize import rasterize +from spatialdata._types import ArrayLike from spatialdata._utils import _affine_matrix_multiplication from spatialdata.models import ( Image2DModel, @@ -17,9 +19,12 @@ Labels2DModel, Labels3DModel, ShapesModel, + TableModel, get_axes_names, get_model, ) +from spatialdata.models._utils import SpatialElement +from spatialdata.transformations import get_transformation from spatialdata.transformations.operations import get_transformation from spatialdata.transformations.transformations import BaseTransformation @@ -28,26 +33,41 @@ class ImageTilesDataset(Dataset): + CS_KEY = "CS" + REGION_KEY = "REGION" + IMAGE_KEY = "IMAGE" + INSTANCE_KEY = "INSTANCE" + def __init__( self, sdata: SpatialData, regions_to_images: dict[str, str], - tile_dim_in_units: float, - tile_dim_in_pixels: int, - target_coordinate_system: str = "global", + tiel_scale: float = 1.0, + tile_dim_in_units: float | None = None, + tile_dim_in_pixels: int | None = None, + coordinate_systems: str | list[str] = None, # unused at the moment, see transform: Optional[Callable[[SpatialData], Any]] = None, ): """ - Torch Dataset that returns image tiles around regions from a SpatialData object. + :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData` object. Parameters ---------- sdata - The SpatialData object containing the regions and images from which to extract the tiles from. + The :class`spatialdata.SpatialData` object. regions_to_images - A dictionary mapping the regions element key we want to extract the tiles around to the images element key - we want to get the image data from. + A mapping betwen region and images. The regions are used to compute the tile centers, while the images are + used to get the pixel values. + tile_scale + The scale of the tiles. This is used only if the `regions` are `shapes`. + It is a scaling factor applied to either the radius (spots) or length (polygons) of the `shapes` + according to the geometry type of the `shapes` element: + + - if `shapes` are circles (spots), the radius is scaled by `tile_scale`. 
+ - if `shapes` are polygons, the length of the polygon is scaled by `tile_scale`. + + If `tile_dim_in_units` is passed, `tile_scale` is ignored. tile_dim_in_units The dimension of the requested tile in the units of the target coordinate system. This specifies the extent of the image each tile is querying. This is not related he size in pixel of each returned tile. @@ -60,28 +80,97 @@ def __init__( # TODO: we can extend this code to support: # - automatic dermination of the tile_dim_in_pixels to match the image resolution (prevent down/upscaling) # - use the bounding box query instead of the raster function if the user wants - self.sdata = sdata - self.regions_to_images = regions_to_images + coordinate_systems = [coordinate_systems] if isinstance(coordinate_systems, str) else coordinate_systems + self._validate(sdata, regions_to_images, coordinate_systems) + self.tile_dim_in_units = tile_dim_in_units self.tile_dim_in_pixels = tile_dim_in_pixels - self.transform = transform - self.target_coordinate_system = target_coordinate_system self.n_spots_dict = self._compute_n_spots_dict() self.n_spots = sum(self.n_spots_dict.values()) - def _validate_regions_to_images(self) -> None: - for region_key, image_key in self.regions_to_images.items(): - regions_element = self.sdata[region_key] - images_element = self.sdata[image_key] - # we could allow also for points - if get_model(regions_element) not in [ShapesModel, Labels2DModel, Labels3DModel]: - raise ValueError("regions_element must be a shapes element or a labels element") - if get_model(images_element) not in [Image2DModel, Image3DModel]: - raise ValueError("images_element must be an image element") - - def _compute_n_spots_dict(self) -> dict[str, int]: + def _validate( + self, + sdata: SpatialData, + regions_to_images: dict[str, str], + coordinate_systems: list[str], + ) -> None: + """Validate input parameters.""" + self._region_key = sdata.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] + self._instance_key = sdata.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] + available_regions = sdata.obs[region_key].unique() + cs_region_image = [] # list of tuples (coordinate_system, region, image) + + for region_key, image_key in regions_to_images.items(): + if region_key not in available_regions: + raise ValueError(f"region {region_key} not found in the spatialdata object.") + + # get elements + region_elem = sdata[region_key] + image_elem = sdata[image_key] + + # check that the elements are supported + if get_model(region_elem) in [Labels2DModel, Labels3DModel]: + raise NotImplementedError("labels elements are not implemented yet.") + if get_model(region_elem) not in [ShapesModel]: + raise ValueError("`regions_element` must be a shapes element.") + if get_model(image_elem) not in [Image2DModel, Image3DModel]: + raise ValueError("`images_element` must be an image element.") + + # check that the coordinate systems are valid for the elements + region_trans = get_transformation(region_elem) + image_trans = get_transformation(image_elem) + + for cs in coordinate_systems: + if cs in region_trans and cs in image_trans: + cs_region_image.append(tuple(cs, region_key, image_key)) + + self.regions = list(available_regions.keys()) # all regions for the dataloader + self.sdata = sdata + self._cs_region_image = tuple(cs_region_image) # tuple(coordinate_system, region_key, image_key) + + def _preprocess(self, tile_scale, tile_dim_in_units) -> None: + """Preprocess the dataset.""" + index_df = [] + + for cs, region, image in self._cs_region_image: + # get 
instances from region + inst = self.sdata.table.obs[self.sdata.table.obs[self._region_key] == region][self._instance_key] + # get extent for bounding box + extent = self._get_extent(self.sdata[region], tile_scale, tile_dim_in_units) + if len(extent) == 1: + extent = np.repeat(extent, len(inst)) + if len(extent) != len(inst): + raise ValueError( + f"the number of elements in the region {region} ({len(extent)}) does not match the number of " + f"instances ({len(inst)})." + ) + df = pd.DataFrame({self.INSTANCE_KEY: inst}) + df[self.CS_KEY] = cs + df[self.REGION_KEY] = region + df[self.IMAGE_KEY] = image + index_df.append(df) + + self.index = pd.concat(index_df).reset_index(inplace=True, drop=True) + + def _get_extent( + self, + elem: SpatialElement, + tile_scale: float | None = None, + tile_dim_in_units: float | None = None, + ) -> ArrayLike: + """Get the extent of the region.""" + if tile_dim_in_units is None: + if elem.iloc[0][0].geom_type == "Point": + return elem[ShapesModel.RADIUS_KEY].values * tile_scale + if elem.iloc[0][0].geom_type == "Polygon": + return elem[ShapesModel.GEOMETRY_KEY].length * tile_scale + raise ValueError("Only point and polygon shapes are supported.") + return tile_dim_in_units + + def _get_unique_instances(self) -> dict[str, Any]: n_spots_dict = {} + region_key = self.sdata.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] for region_key in self.regions_to_images: element = self.sdata[region_key] # we could allow also points @@ -186,3 +275,27 @@ def __getitem__(self, idx: int) -> Any | SpatialData: if self.transform is not None: return self.transform(tile_sdata) return tile_sdata + + @property + def regions(self) -> list[str]: + return self._regions + + @regions.setter + def regions(self, regions: list[str]) -> None: + self._regions = regions + + @property + def sdata(self) -> SpatialData: + return self._sdata + + @sdata.setter + def sdata(self, sdata: SpatialData) -> None: + self._sdata = sdata + + @property + def coordinate_systems(self) -> list[str]: + return self._coordinate_systems + + @coordinate_systems.setter + def coordinate_systems(self, coordinate_systems: list[str]) -> None: + self._coordinate_systems = coordinate_systems From 6cff536b2831311f3efc69c92da42275adcc0e21 Mon Sep 17 00:00:00 2001 From: giovp Date: Fri, 16 Jun 2023 19:30:28 +0000 Subject: [PATCH 03/38] get tile centroid and eextent outside of function --- src/spatialdata/dataloader/datasets.py | 190 ++++++++++++++----------- 1 file changed, 110 insertions(+), 80 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 849ebf89..900dcf0b 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -11,7 +11,6 @@ from torch.utils.data import Dataset from spatialdata._core.operations.rasterize import rasterize -from spatialdata._types import ArrayLike from spatialdata._utils import _affine_matrix_multiplication from spatialdata.models import ( Image2DModel, @@ -23,9 +22,7 @@ get_axes_names, get_model, ) -from spatialdata.models._utils import SpatialElement from spatialdata.transformations import get_transformation -from spatialdata.transformations.operations import get_transformation from spatialdata.transformations.transformations import BaseTransformation if TYPE_CHECKING: @@ -42,12 +39,9 @@ def __init__( self, 
sdata: SpatialData, regions_to_images: dict[str, str], - tiel_scale: float = 1.0, + regions_to_coordinate_systems: dict[str, str], + tile_scale: float = 1.0, tile_dim_in_units: float | None = None, - tile_dim_in_pixels: int | None = None, - coordinate_systems: str | list[str] = None, - # unused at the moment, see - transform: Optional[Callable[[SpatialData], Any]] = None, ): """ :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData` object. @@ -59,6 +53,9 @@ def __init__( regions_to_images A mapping betwen region and images. The regions are used to compute the tile centers, while the images are used to get the pixel values. + regions_to_coordinate_systems + A mapping between regions and coordinate systems. The coordinate systems are used to transform both + regions coordinates for tiles as well as images. tile_scale The scale of the tiles. This is used only if the `regions` are `shapes`. It is a scaling factor applied to either the radius (spots) or length (polygons) of the `shapes` @@ -74,31 +71,23 @@ def __init__( tile_dim_in_pixels The dimension of the requested tile in pixels. This specifies the size of the output tiles that we will get, independently of which extent of the image the tile is covering. - target_coordinate_system - The coordinate system in which the tile_dim_in_units is specified. """ # TODO: we can extend this code to support: # - automatic dermination of the tile_dim_in_pixels to match the image resolution (prevent down/upscaling) # - use the bounding box query instead of the raster function if the user wants - coordinate_systems = [coordinate_systems] if isinstance(coordinate_systems, str) else coordinate_systems - self._validate(sdata, regions_to_images, coordinate_systems) - - self.tile_dim_in_units = tile_dim_in_units - self.tile_dim_in_pixels = tile_dim_in_pixels - - self.n_spots_dict = self._compute_n_spots_dict() - self.n_spots = sum(self.n_spots_dict.values()) + self._validate(sdata, regions_to_images, regions_to_coordinate_systems) + self._preprocess(tile_scale, tile_dim_in_units) def _validate( self, sdata: SpatialData, regions_to_images: dict[str, str], - coordinate_systems: list[str], + regions_to_coordinate_systems: dict[str, str], ) -> None: """Validate input parameters.""" - self._region_key = sdata.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] - self._instance_key = sdata.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] - available_regions = sdata.obs[region_key].unique() + self._region_key = sdata.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] + self._instance_key = sdata.table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] + available_regions = sdata.table.obs[self._region_key].unique() cs_region_image = [] # list of tuples (coordinate_system, region, image) for region_key, image_key in regions_to_images.items(): @@ -121,78 +110,54 @@ def _validate( region_trans = get_transformation(region_elem) image_trans = get_transformation(image_elem) - for cs in coordinate_systems: - if cs in region_trans and cs in image_trans: - cs_region_image.append(tuple(cs, region_key, image_key)) + try: + cs = regions_to_coordinate_systems[region_key] + region_trans = get_transformation(region_elem, cs) + image_trans = get_transformation(image_elem, cs) + if isinstance(region_trans, BaseTransformation) and isinstance(image_trans, BaseTransformation): + cs_region_image.append((cs, region_key, image_key)) + except KeyError as e: + raise KeyError(f"region {region_key} not found in `regions_to_coordinate_systems`") 
from e self.regions = list(available_regions.keys()) # all regions for the dataloader self.sdata = sdata self._cs_region_image = tuple(cs_region_image) # tuple(coordinate_system, region_key, image_key) - def _preprocess(self, tile_scale, tile_dim_in_units) -> None: + def _preprocess( + self, + tile_scale: float = 1.0, + tile_dim_in_units: float | None = None, + ) -> None: """Preprocess the dataset.""" index_df = [] + tile_coords_df = [] for cs, region, image in self._cs_region_image: + # get dims and transformations for the region element + dims = get_axes_names(region) + t = get_transformation(region, cs) + assert isinstance(t, BaseTransformation) + + # get coordinates of centroids and extent for tiles + tile_coords = _get_tile_coords(self.sdata[region], t, dims, tile_scale, tile_dim_in_units) + tile_coords_df.append(tile_coords) + # get instances from region - inst = self.sdata.table.obs[self.sdata.table.obs[self._region_key] == region][self._instance_key] - # get extent for bounding box - extent = self._get_extent(self.sdata[region], tile_scale, tile_dim_in_units) - if len(extent) == 1: - extent = np.repeat(extent, len(inst)) - if len(extent) != len(inst): - raise ValueError( - f"the number of elements in the region {region} ({len(extent)}) does not match the number of " - f"instances ({len(inst)})." - ) + inst = self.sdata.table.obs[self.sdata.table.obs[self._region_key] == region][self._instance_key].values + # get index dictionary, with `instance_id`, `cs`, `region`, and `image` df = pd.DataFrame({self.INSTANCE_KEY: inst}) df[self.CS_KEY] = cs df[self.REGION_KEY] = region df[self.IMAGE_KEY] = image index_df.append(df) - self.index = pd.concat(index_df).reset_index(inplace=True, drop=True) - - def _get_extent( - self, - elem: SpatialElement, - tile_scale: float | None = None, - tile_dim_in_units: float | None = None, - ) -> ArrayLike: - """Get the extent of the region.""" - if tile_dim_in_units is None: - if elem.iloc[0][0].geom_type == "Point": - return elem[ShapesModel.RADIUS_KEY].values * tile_scale - if elem.iloc[0][0].geom_type == "Polygon": - return elem[ShapesModel.GEOMETRY_KEY].length * tile_scale - raise ValueError("Only point and polygon shapes are supported.") - return tile_dim_in_units - - def _get_unique_instances(self) -> dict[str, Any]: - n_spots_dict = {} - region_key = self.sdata.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] - for region_key in self.regions_to_images: - element = self.sdata[region_key] - # we could allow also points - if isinstance(element, GeoDataFrame): - n_spots_dict[region_key] = len(element) - elif isinstance(element, (SpatialImage, MultiscaleSpatialImage)): - raise NotImplementedError("labels not supported yet") - else: - raise ValueError("element must be a geodataframe or a spatial image") - return n_spots_dict - - def _get_region_info_for_index(self, index: int) -> tuple[str, int]: - # TODO: this implmenetation can be improved - i = 0 - for region_key, n_spots in self.n_spots_dict.items(): - if index < i + n_spots: - return region_key, index - i - i += n_spots - raise ValueError(f"index {index} is out of range") + # concatenate and assign to self + self.dataset_index = pd.concat(index_df).reset_index(inplace=True, drop=True) + self.tiles_coords = pd.concat(tile_coords_df).reset_index(inplace=True, drop=True) + assert len(self.tiles_coords) == len(self.dataset_index) def __len__(self) -> int: - return self.n_spots + return len(self.dataset_index) def __getitem__(self, idx: int) -> Any | SpatialData: from spatialdata import SpatialData @@ 
-278,24 +243,89 @@ def __getitem__(self, idx: int) -> Any | SpatialData: @property def regions(self) -> list[str]: + """List of regions in the dataset.""" return self._regions @regions.setter - def regions(self, regions: list[str]) -> None: + def regions(self, regions: list[str]) -> None: # D102 self._regions = regions @property def sdata(self) -> SpatialData: + """SpatialData object.""" return self._sdata @sdata.setter - def sdata(self, sdata: SpatialData) -> None: + def sdata(self, sdata: SpatialData) -> None: # D102 self._sdata = sdata @property def coordinate_systems(self) -> list[str]: + """List of coordinate systems in the dataset.""" return self._coordinate_systems @coordinate_systems.setter - def coordinate_systems(self, coordinate_systems: list[str]) -> None: + def coordinate_systems(self, coordinate_systems: list[str]) -> None: # D102 self._coordinate_systems = coordinate_systems + + @property + def tiles_coords(self) -> pd.DataFrame: + """DataFrame with the index of tiles.""" + return self._tiles_coords + + @tiles_coords.setter + def tiles_coords(self, tiles: pd.DataFrame) -> None: + self._tiles_coords = tiles + + @property + def dataset_index(self) -> pd.DataFrame: + """DataFrame with the metadata of the tiles. + + It contains the following columns: + + - INSTANCE: the name of the instance in the region. + - CS: the coordinate system of the region-image pair. + - REGION: the name of the region. + - IMAGE: the name of the image. + """ + return self._dataset_index + + @dataset_index.setter + def dataset_index(self, dataset_index: pd.DataFrame) -> None: + self._dataset_index = dataset_index + + +def _get_tile_coords( + elem: GeoDataFrame, + transformation: BaseTransformation, + dims: tuple[str, ...], + tile_scale: float | None = None, + tile_dim_in_units: float | None = None, +) -> pd.DataFrame: + """Get the (transformed) centroid of the region and the extent.""" + # get centroids and transform them + centroids = elem.centroid.get_coordinates() + aff = transformation.to_affine_matrix(input_axes=dims, output_axes=dims) + centroids = np.squeeze(_affine_matrix_multiplication(aff, centroids), 0) + centroids = pd.DataFrame(centroids, columns=dims) + + # get extent, first by checking shape defaults, then by using the `tile_dim_in_units` + if tile_dim_in_units is None: + if elem.iloc[0][0].geom_type == "Point": + extent = elem[ShapesModel.RADIUS_KEY].values * tile_scale + if elem.iloc[0][0].geom_type == "Polygon": + extent = elem[ShapesModel.GEOMETRY_KEY].length * tile_scale + raise ValueError("Only point and polygon shapes are supported.") + if tile_dim_in_units is not None: + if isinstance(tile_dim_in_units, float): + extent = np.repeat(tile_dim_in_units, len(centroids)) + if len(extent) != len(centroids): + raise ValueError( + f"the number of elements in the region ({len(extent)}) does not match" + f" the number of instances ({len(centroids)})." 
+ ) + + # transform extent + aff = transformation.to_affine_matrix(input_axes=tuple(dims[0]), output_axes=tuple(dims[0])) + centroids["extent"] = np.squeeze(_affine_matrix_multiplication(aff, extent), 0) + return centroids From 638dd334fbd5ebd0723feb049f980639817ed851 Mon Sep 17 00:00:00 2001 From: giovp Date: Sat, 17 Jun 2023 21:16:53 +0000 Subject: [PATCH 04/38] add return type --- src/spatialdata/dataloader/datasets.py | 233 ++++++++++++++----------- 1 file changed, 129 insertions(+), 104 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 900dcf0b..195f333d 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -1,15 +1,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from itertools import chain +from typing import Any, Callable, Optional import numpy as np import pandas as pd +from anndata import AnnData from geopandas import GeoDataFrame -from multiscale_spatial_image import MultiscaleSpatialImage -from shapely import MultiPolygon, Point, Polygon -from spatial_image import SpatialImage +from scipy.sparse import issparse from torch.utils.data import Dataset +from spatialdata import SpatialData, bounding_box_query from spatialdata._core.operations.rasterize import rasterize from spatialdata._utils import _affine_matrix_multiplication from spatialdata.models import ( @@ -25,15 +26,12 @@ from spatialdata.transformations import get_transformation from spatialdata.transformations.transformations import BaseTransformation -if TYPE_CHECKING: - from spatialdata import SpatialData - class ImageTilesDataset(Dataset): - CS_KEY = "CS" - REGION_KEY = "REGION" - IMAGE_KEY = "IMAGE" - INSTANCE_KEY = "INSTANCE" + INSTANCE_KEY = "instance" + CS_KEY = "cs" + REGION_KEY = "region" + IMAGE_KEY = "image" def __init__( self, @@ -42,10 +40,19 @@ def __init__( regions_to_coordinate_systems: dict[str, str], tile_scale: float = 1.0, tile_dim_in_units: float | None = None, + raster: bool = False, + return_table: str | list[str] | None = None, + *kwargs: Any, ): """ :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData` object. + By default, the dataset returns spatialdata object, but when `return_image` and `return_table` + are set, the dataset may return a tuple containing: + + - the tile image, centered in the target coordinate system of the region. + - a vector or scala value from the table. + Parameters ---------- sdata @@ -68,15 +75,39 @@ def __init__( tile_dim_in_units The dimension of the requested tile in the units of the target coordinate system. This specifies the extent of the image each tile is querying. This is not related he size in pixel of each returned tile. - tile_dim_in_pixels - The dimension of the requested tile in pixels. This specifies the size of the output tiles that we will get, - independently of which extent of the image the tile is covering. + rasterize + If True, the regions are rasterized using :func:`spatialdata.rasterize`. + If False, uses the :func:`spatialdata.bounding_box_query`. + return_table + If not None, a value from the table is returned together with the image. + Only columns in :attr:`anndata.AnnData.obs` and :attr:`anndata.AnnData.X` + can be returned. It will not be returned a spatialdata object but only a tuple + containing the image and the table value. + + Returns + ------- + :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData`. 
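        Examples
        --------
        A minimal sketch of the intended usage; ``sdata`` is assumed to be a
        :class:`spatialdata.SpatialData` object whose table annotates a shapes
        element named ``"cells"``, with an image ``"image"`` and an ``obs``
        column ``"cell_type"`` (all three names are hypothetical):

        >>> dataset = ImageTilesDataset(
        ...     sdata=sdata,
        ...     regions_to_images={"cells": "image"},
        ...     regions_to_coordinate_systems={"cells": "global"},
        ...     tile_scale=1.5,
        ...     return_table="cell_type",
        ... )
        >>> tile, annotation = dataset[0]  # image tile plus a (1, 1) array from the table
        >>> n_tiles = len(dataset)  # one tile per instance of "cells" in the table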
""" # TODO: we can extend this code to support: # - automatic dermination of the tile_dim_in_pixels to match the image resolution (prevent down/upscaling) # - use the bounding box query instead of the raster function if the user wants self._validate(sdata, regions_to_images, regions_to_coordinate_systems) self._preprocess(tile_scale, tile_dim_in_units) + self._crop_image: Callable[..., Any] = rasterize if raster else bounding_box_query + + if return_table is not None: + return_table = [return_table] if isinstance(return_table, str) else return_table + if return_table in self.dataset_table.obs: + self._return_table: Optional[Callable[[int], Any]] = ( + lambda x: self.dataset_table.obs[return_table].iloc[x].values.reshape(1, -1) + ) + if return_table in sdata.table.var_names: + if issparse(self.dataset_table.X): + self._return_table = lambda x: self.dataset_table.X[:, return_table].X[x].A + else: + self._return_table = lambda x: self.dataset_table.X[:, return_table].X[x] + else: + self._return_table = None def _validate( self, @@ -90,6 +121,14 @@ def _validate( available_regions = sdata.table.obs[self._region_key].unique() cs_region_image = [] # list of tuples (coordinate_system, region, image) + # check unique matching between regions and images and coordinate systems + assert len(set(regions_to_images.values())) == len( + regions_to_images.keys() + ), "One region cannot be paired to multiple regions." + assert len(set(regions_to_coordinate_systems.values())) == len( + regions_to_coordinate_systems.keys() + ), "One region cannot be paired to multiple coordinate systems." + for region_key, image_key in regions_to_images.items(): if region_key not in available_regions: raise ValueError(f"region {region_key} not found in the spatialdata object.") @@ -121,7 +160,8 @@ def _validate( self.regions = list(available_regions.keys()) # all regions for the dataloader self.sdata = sdata - self._cs_region_image = tuple(cs_region_image) # tuple(coordinate_system, region_key, image_key) + self.dataset_table = self.sdata.table.obs[self.sdata.table.obs[self._region_key].isin(self.regions)] + self._cs_region_image = tuple(cs_region_image) # tuple of tuples (coordinate_system, region_key, image_key) def _preprocess( self, @@ -131,10 +171,12 @@ def _preprocess( """Preprocess the dataset.""" index_df = [] tile_coords_df = [] + dims_l = [] for cs, region, image in self._cs_region_image: # get dims and transformations for the region element dims = get_axes_names(region) + dims_l.append(dims) t = get_transformation(region, cs) assert isinstance(t, BaseTransformation) @@ -144,6 +186,7 @@ def _preprocess( # get instances from region inst = self.sdata.table.obs[self.sdata.table.obs[self._region_key] == region][self._instance_key].values + # get index dictionary, with `instance_id`, `cs`, `region`, and `image` df = pd.DataFrame({self.INSTANCE_KEY: inst}) df[self.CS_KEY] = cs @@ -154,92 +197,36 @@ def _preprocess( # concatenate and assign to self self.dataset_index = pd.concat(index_df).reset_index(inplace=True, drop=True) self.tiles_coords = pd.concat(tile_coords_df).reset_index(inplace=True, drop=True) + # get table filtered by regions + self.filtered_table = self.sdata.table.obs[self.sdata.table.obs[self._region_key].isin[self._cs_region_imag[1]]] + assert len(self.tiles_coords) == len(self.dataset_index) + dims_ = set(chain(*dims_l)) + assert np.all([i in self.tiles_coords for i in dims_]) + self.dims = list(dims_) def __len__(self) -> int: return len(self.dataset_index) def __getitem__(self, idx: int) -> Any | 
SpatialData: - from spatialdata import SpatialData - - if idx >= self.n_spots: - raise IndexError() - regions_name, region_index = self._get_region_info_for_index(idx) - regions = self.sdata[regions_name] - # TODO: here we just need to compute the centroids, - # we probably want to move this functionality to a different file - if isinstance(regions, GeoDataFrame): - dims = get_axes_names(regions) - region = regions.iloc[region_index] - shape = regions.geometry.iloc[0] - if isinstance(shape, Polygon): - xy = region.geometry.centroid.coords.xy - centroid = np.array([[xy[0][0], xy[1][0]]]) - elif isinstance(shape, MultiPolygon): - raise NotImplementedError("MultiPolygon not supported yet") - elif isinstance(shape, Point): - xy = region.geometry.coords.xy - centroid = np.array([[xy[0][0], xy[1][0]]]) - else: - raise RuntimeError(f"Unsupported type: {type(shape)}") - - t = get_transformation(regions, self.target_coordinate_system) - assert isinstance(t, BaseTransformation) - aff = t.to_affine_matrix(input_axes=dims, output_axes=dims) - transformed_centroid = np.squeeze(_affine_matrix_multiplication(aff, centroid), 0) - elif isinstance(regions, (SpatialImage, MultiscaleSpatialImage)): - raise NotImplementedError("labels not supported yet") - else: - raise ValueError("element must be shapes or labels") - min_coordinate = np.array(transformed_centroid) - self.tile_dim_in_units / 2 - max_coordinate = np.array(transformed_centroid) + self.tile_dim_in_units / 2 - - raster = self.sdata[self.regions_to_images[regions_name]] - tile = rasterize( - raster, - axes=dims, - min_coordinate=min_coordinate, - max_coordinate=max_coordinate, - target_coordinate_system=self.target_coordinate_system, - target_width=self.tile_dim_in_pixels, + """Get item from the dataset.""" + # get the row from the index + row = self.dataset_index.iloc[idx] + # get the tile coordinates + t_coords = self.tiles_coords.iloc[idx] + + image = self.sdata[row["image"]] + tile = self._crop_image( + image, + axes=self.dims, + min_coordinate=t_coords[[f"min{i}" for i in self.dims]], + max_coordinate=t_coords[[f"min{i}" for i in self.dims]], + target_coordinate_system=row["cs"], ) - # TODO: as explained in the TODO in the __init__(), we want to let the - # user also use the bounding box query instaed of the rasterization - # the return function of this function would change, so we need to - # decide if instead having an extra Tile dataset class - # from spatialdata._core._spatial_query import BoundingBoxRequest - # request = BoundingBoxRequest( - # target_coordinate_system=self.target_coordinate_system, - # axes=dims, - # min_coordinate=min_coordinate, - # max_coordinate=max_coordinate, - # ) - # sdata_item = self.sdata.query.bounding_box(**request.to_dict()) - table = self.sdata.table - filter_table = False - if table is not None: - region = table.uns["spatialdata_attrs"]["region"] - region_key = table.uns["spatialdata_attrs"]["region_key"] - instance_key = table.uns["spatialdata_attrs"]["instance_key"] - if isinstance(region, str): - if regions_name == region: - filter_table = True - elif isinstance(region, list): - if regions_name in region: - filter_table = True - else: - raise ValueError("region must be a string or a list of strings") - # TODO: maybe slow, we should check if there is a better way to do this - if filter_table: - instance = self.sdata[regions_name].iloc[region_index].name - row = table[(table.obs[region_key] == regions_name) & (table.obs[instance_key] == instance)].copy() - tile_table = row - else: - tile_table = None - 
tile_sdata = SpatialData(images={self.regions_to_images[regions_name]: tile}, table=tile_table) - if self.transform is not None: - return self.transform(tile_sdata) - return tile_sdata + + if self._return_table is not None: + return tile, self.filtered_table(idx) + return SpatialData(images={t_coords[self.REGION_KEY][idx]: tile}, table=self.dataset_table[idx]) @property def regions(self) -> list[str]: @@ -252,7 +239,7 @@ def regions(self, regions: list[str]) -> None: # D102 @property def sdata(self) -> SpatialData: - """SpatialData object.""" + """The original SpatialData object.""" return self._sdata @sdata.setter @@ -270,7 +257,19 @@ def coordinate_systems(self, coordinate_systems: list[str]) -> None: # D102 @property def tiles_coords(self) -> pd.DataFrame: - """DataFrame with the index of tiles.""" + """DataFrame with the index of tiles. + + It contains axis coordinates of the centroids, and extent of the tiles. + For example, for a 2D image, it contains the following columns: + + - `x`: the x coordinate of the centroid. + - `y`: the y coordinate of the centroid. + - `extent`: the extent of the tile. + - `minx`: the minimum x coordinate of the tile. + - `miny`: the minimum y coordinate of the tile. + - `maxx`: the maximum x coordinate of the tile. + - `maxy`: the maximum y coordinate of the tile. + """ return self._tiles_coords @tiles_coords.setter @@ -283,10 +282,10 @@ def dataset_index(self) -> pd.DataFrame: It contains the following columns: - - INSTANCE: the name of the instance in the region. - - CS: the coordinate system of the region-image pair. - - REGION: the name of the region. - - IMAGE: the name of the image. + - `instance`: the name of the instance in the region. + - `cs`: the coordinate system of the region-image pair. + - `region`: the name of the region. + - `image`: the name of the image. 
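        The expected layout can be illustrated with a hand-built frame
        (hypothetical values, shown only to make the columns concrete):

        >>> import pandas as pd
        >>> example_index = pd.DataFrame(
        ...     {
        ...         "instance": [0, 1],
        ...         "cs": ["global", "global"],
        ...         "region": ["cells", "cells"],
        ...         "image": ["image", "image"],
        ...     }
        ... )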
""" return self._dataset_index @@ -294,6 +293,24 @@ def dataset_index(self) -> pd.DataFrame: def dataset_index(self, dataset_index: pd.DataFrame) -> None: self._dataset_index = dataset_index + @property + def dataset_table(self) -> AnnData: + """AnnData table filtered by the `region` and `cs` present in the dataset.""" + return self._dataset_table + + @dataset_table.setter + def dataset_table(self, dataset_table: AnnData) -> None: + self._dataset_table = dataset_table + + @property + def dims(self) -> list[str]: + """Dimensions of the dataset.""" + return self._dims + + @dims.setter + def dims(self, dims: list[str]) -> None: + self._dims = dims + def _get_tile_coords( elem: GeoDataFrame, @@ -304,10 +321,9 @@ def _get_tile_coords( ) -> pd.DataFrame: """Get the (transformed) centroid of the region and the extent.""" # get centroids and transform them - centroids = elem.centroid.get_coordinates() + centroids = elem.centroid.get_coordinates().values aff = transformation.to_affine_matrix(input_axes=dims, output_axes=dims) centroids = np.squeeze(_affine_matrix_multiplication(aff, centroids), 0) - centroids = pd.DataFrame(centroids, columns=dims) # get extent, first by checking shape defaults, then by using the `tile_dim_in_units` if tile_dim_in_units is None: @@ -327,5 +343,14 @@ def _get_tile_coords( # transform extent aff = transformation.to_affine_matrix(input_axes=tuple(dims[0]), output_axes=tuple(dims[0])) - centroids["extent"] = np.squeeze(_affine_matrix_multiplication(aff, extent), 0) - return centroids + extent = np.squeeze(_affine_matrix_multiplication(aff, extent), 0) + + # get min and max coordinates + min_coordinates = np.array(centroids.values) - extent / 2 + max_coordinates = np.array(centroids.values) + extent / 2 + + # return a dataframe with columns e.g. 
["x", "y", "extent", "minx", "miny", "maxx", "maxy"] + return pd.DataFrame( + np.hstack([centroids, extent[:, np.newaxis], min_coordinates, max_coordinates]), + columns=list(dims) + ["extent"] + ["min" + dim for dim in dims] + ["max" + dim for dim in dims], + ) From 17363771deb3d73839d5c20d6d284c90cc6ab626 Mon Sep 17 00:00:00 2001 From: giovp Date: Sat, 17 Jun 2023 21:20:05 +0000 Subject: [PATCH 05/38] move return table out of init --- src/spatialdata/dataloader/datasets.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 195f333d..b15660e4 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -94,20 +94,7 @@ def __init__( self._validate(sdata, regions_to_images, regions_to_coordinate_systems) self._preprocess(tile_scale, tile_dim_in_units) self._crop_image: Callable[..., Any] = rasterize if raster else bounding_box_query - - if return_table is not None: - return_table = [return_table] if isinstance(return_table, str) else return_table - if return_table in self.dataset_table.obs: - self._return_table: Optional[Callable[[int], Any]] = ( - lambda x: self.dataset_table.obs[return_table].iloc[x].values.reshape(1, -1) - ) - if return_table in sdata.table.var_names: - if issparse(self.dataset_table.X): - self._return_table = lambda x: self.dataset_table.X[:, return_table].X[x].A - else: - self._return_table = lambda x: self.dataset_table.X[:, return_table].X[x] - else: - self._return_table = None + self._return_table = self._get_return_table(return_table) def _validate( self, @@ -205,6 +192,17 @@ def _preprocess( assert np.all([i in self.tiles_coords for i in dims_]) self.dims = list(dims_) + def _get_return_table(self, return_table: str | list[str] | None) -> Optional[Callable[[int], Any]] | None: + if return_table is not None: + return_table = [return_table] if isinstance(return_table, str) else return_table + if return_table in self.dataset_table.obs: + return lambda x: self.dataset_table.obs[return_table].iloc[x].values.reshape(1, -1) + if return_table in self.dataset_table.var_names: + if issparse(self.dataset_table.X): + return lambda x: self.dataset_table.X[:, return_table].X[x].A + return lambda x: self.dataset_table.X[:, return_table].X[x] + return None + def __len__(self) -> int: return len(self.dataset_index) From 1ee6abb48ab8895294c6e113399a2eac77a5bc17 Mon Sep 17 00:00:00 2001 From: giovp Date: Sat, 17 Jun 2023 21:22:08 +0000 Subject: [PATCH 06/38] add comments --- src/spatialdata/dataloader/datasets.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index b15660e4..c5f88c94 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -193,8 +193,10 @@ def _preprocess( self.dims = list(dims_) def _get_return_table(self, return_table: str | list[str] | None) -> Optional[Callable[[int], Any]] | None: + """Get function to return values from the table of the dataset.""" if return_table is not None: return_table = [return_table] if isinstance(return_table, str) else return_table + # return callable that always return array of shape (1, len(return_table)) if return_table in self.dataset_table.obs: return lambda x: self.dataset_table.obs[return_table].iloc[x].values.reshape(1, -1) if return_table in self.dataset_table.var_names: From 90794906b5f43ce292c9d519b4fed486cdfa2dc5 Mon Sep 17 00:00:00 
2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Jun 2023 19:16:30 +0000 Subject: [PATCH 07/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spatialdata/dataloader/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index f5e92dae..c5f88c94 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -1,7 +1,7 @@ from __future__ import annotations from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import Any, Callable, Optional import numpy as np import pandas as pd From 38cfcff4a263aa38e7f710086ffb95afed97e47a Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 20 Jun 2023 19:17:22 +0000 Subject: [PATCH 08/38] update precommit --- src/spatialdata/dataloader/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index c5f88c94..f259538e 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -1,7 +1,7 @@ from __future__ import annotations from itertools import chain -from typing import Any, Callable, Optional +from typing import Any, Callable import numpy as np import pandas as pd @@ -192,7 +192,7 @@ def _preprocess( assert np.all([i in self.tiles_coords for i in dims_]) self.dims = list(dims_) - def _get_return_table(self, return_table: str | list[str] | None) -> Optional[Callable[[int], Any]] | None: + def _get_return_table(self, return_table: str | list[str] | None) -> Callable[[int], Any] | None: """Get function to return values from the table of the dataset.""" if return_table is not None: return_table = [return_table] if isinstance(return_table, str) else return_table From b9a5b9e8551a5d755867dc2850bc4b9db95f50e5 Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 21 Jun 2023 16:08:44 +0000 Subject: [PATCH 09/38] simplify return --- src/spatialdata/dataloader/__init__.py | 7 +--- src/spatialdata/dataloader/datasets.py | 49 +++++++++++++++----------- tests/dataloader/test_datasets.py | 41 ++++++++++----------- tests/dataloader/test_transforms.py | 0 4 files changed, 51 insertions(+), 46 deletions(-) delete mode 100644 tests/dataloader/test_transforms.py diff --git a/src/spatialdata/dataloader/__init__.py b/src/spatialdata/dataloader/__init__.py index f9262f85..297b221f 100644 --- a/src/spatialdata/dataloader/__init__.py +++ b/src/spatialdata/dataloader/__init__.py @@ -1,6 +1 @@ -import contextlib - -with contextlib.suppress(ImportError): - from spatialdata.dataloader.datasets import ImageTilesDataset - -__all__ = ["ImageTilesDataset"] +from spatialdata.dataloader.datasets import ImageTilesDataset diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index f259538e..1cd58b19 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -1,7 +1,7 @@ from __future__ import annotations from itertools import chain -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable import numpy as np import pandas as pd @@ -10,7 +10,7 @@ from scipy.sparse import issparse from torch.utils.data import Dataset -from spatialdata import SpatialData, bounding_box_query +from spatialdata import bounding_box_query from spatialdata._core.operations.rasterize import 
rasterize from spatialdata._utils import _affine_matrix_multiplication from spatialdata.models import ( @@ -26,6 +26,11 @@ from spatialdata.transformations import get_transformation from spatialdata.transformations.transformations import BaseTransformation +__all__ = ["ImageTilesDataset"] + +if TYPE_CHECKING: + from spatialdata import SpatialData + class ImageTilesDataset(Dataset): INSTANCE_KEY = "instance" @@ -42,7 +47,6 @@ def __init__( tile_dim_in_units: float | None = None, raster: bool = False, return_table: str | list[str] | None = None, - *kwargs: Any, ): """ :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData` object. @@ -76,25 +80,22 @@ def __init__( The dimension of the requested tile in the units of the target coordinate system. This specifies the extent of the image each tile is querying. This is not related he size in pixel of each returned tile. rasterize - If True, the regions are rasterized using :func:`spatialdata.rasterize`. - If False, uses the :func:`spatialdata.bounding_box_query`. + If True, the images are rasterized using :func:`spatialdata.rasterize`. + If False, they are rasterized using :func:`spatialdata.bounding_box_query`. return_table - If not None, a value from the table is returned together with the image. + If not None, a value from the table is returned together with the image tile. Only columns in :attr:`anndata.AnnData.obs` and :attr:`anndata.AnnData.X` - can be returned. It will not be returned a spatialdata object but only a tuple + can be returned. If None, it will return a spatialdata object with only the tuple containing the image and the table value. Returns ------- :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData`. """ - # TODO: we can extend this code to support: - # - automatic dermination of the tile_dim_in_pixels to match the image resolution (prevent down/upscaling) - # - use the bounding box query instead of the raster function if the user wants self._validate(sdata, regions_to_images, regions_to_coordinate_systems) self._preprocess(tile_scale, tile_dim_in_units) self._crop_image: Callable[..., Any] = rasterize if raster else bounding_box_query - self._return_table = self._get_return_table(return_table) + self._return = self._get_return(return_table) def _validate( self, @@ -192,18 +193,28 @@ def _preprocess( assert np.all([i in self.tiles_coords for i in dims_]) self.dims = list(dims_) - def _get_return_table(self, return_table: str | list[str] | None) -> Callable[[int], Any] | None: + def _get_return( + self, + return_table: str | list[str] | None, + ) -> Callable[[int, Any], tuple[Any, Any] | SpatialData]: """Get function to return values from the table of the dataset.""" + from spatialdata import SpatialData + if return_table is not None: + # table is always returned as array shape (1, len(return_table)) + # where return_table can be a single column or a list of columns return_table = [return_table] if isinstance(return_table, str) else return_table - # return callable that always return array of shape (1, len(return_table)) + # return tuple of (tile, table) if return_table in self.dataset_table.obs: - return lambda x: self.dataset_table.obs[return_table].iloc[x].values.reshape(1, -1) + return lambda x, tile: (tile, self.dataset_table.obs[return_table].iloc[x].values.reshape(1, -1)) if return_table in self.dataset_table.var_names: if issparse(self.dataset_table.X): - return lambda x: self.dataset_table.X[:, return_table].X[x].A - return lambda x: 
self.dataset_table.X[:, return_table].X[x] - return None + return lambda x, tile: (tile, self.dataset_table.X[:, return_table].X[x].A) + return lambda x, tile: (tile, self.dataset_table.X[:, return_table].X[x]) + # return spatialdata consisting of the image tile and the associated table + return lambda x, tile: SpatialData( + images={self.tiles_coords.iloc[x][[self.REGION_KEY]]: tile}, table=self.dataset_table[x] + ) def __len__(self) -> int: return len(self.dataset_index) @@ -224,9 +235,7 @@ def __getitem__(self, idx: int) -> Any | SpatialData: target_coordinate_system=row["cs"], ) - if self._return_table is not None: - return tile, self.filtered_table(idx) - return SpatialData(images={t_coords[self.REGION_KEY][idx]: tile}, table=self.dataset_table[idx]) + self._return(idx, tile) @property def regions(self) -> list[str]: diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index b5772531..66476c7f 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -8,26 +8,27 @@ from spatialdata.models import TableModel -@pytest.mark.parametrize("image_element", ["blobs_image", "blobs_multiscale_image"]) -@pytest.mark.parametrize( - "regions_element", - ["blobs_labels", "blobs_multiscale_labels", "blobs_circles", "blobs_polygons", "blobs_multipolygons"], -) -def test_tiles_dataset(sdata_blobs, image_element, regions_element): - if regions_element in ["blobs_labels", "blobs_multipolygons", "blobs_multiscale_labels"]: - cm = pytest.raises(NotImplementedError) - else: - cm = contextlib.nullcontext() - with cm: - ds = ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={regions_element: image_element}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", - ) - tile = ds[0].images.values().__iter__().__next__() - assert tile.shape == (3, 32, 32) +class TestImageTilesDataset: + @pytest.mark.parametrize("image_element", ["blobs_image", "blobs_multiscale_image"]) + @pytest.mark.parametrize( + "regions_element", + ["blobs_labels", "blobs_multiscale_labels", "blobs_circles", "blobs_polygons", "blobs_multipolygons"], + ) + def test_tiles_dataset(self, sdata_blobs, image_element, regions_element): + if regions_element in ["blobs_labels", "blobs_multipolygons", "blobs_multiscale_labels"]: + cm = pytest.raises(NotImplementedError) + else: + cm = contextlib.nullcontext() + with cm: + ds = ImageTilesDataset( + sdata=sdata_blobs, + regions_to_images={regions_element: image_element}, + tile_dim_in_units=10, + tile_dim_in_pixels=32, + target_coordinate_system="global", + ) + tile = ds[0].images.values().__iter__().__next__() + assert tile.shape == (3, 32, 32) def test_tiles_table(sdata_blobs): diff --git a/tests/dataloader/test_transforms.py b/tests/dataloader/test_transforms.py deleted file mode 100644 index e69de29b..00000000 From 1d2f3cc4fab2c10a3690f89672f3423fdc525e10 Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 21 Jun 2023 21:33:06 +0000 Subject: [PATCH 10/38] update tests and simplify --- src/spatialdata/dataloader/datasets.py | 22 +++++++++----- tests/dataloader/test_datasets.py | 41 +++++++++++++++++++++----- 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 1cd58b19..dde22451 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -7,11 +7,10 @@ import pandas as pd from anndata import AnnData from geopandas import GeoDataFrame +from multiscale_spatial_image 
import MultiscaleSpatialImage from scipy.sparse import issparse from torch.utils.data import Dataset -from spatialdata import bounding_box_query -from spatialdata._core.operations.rasterize import rasterize from spatialdata._utils import _affine_matrix_multiplication from spatialdata.models import ( Image2DModel, @@ -92,6 +91,9 @@ def __init__( ------- :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData`. """ + from spatialdata import bounding_box_query + from spatialdata._core.operations.rasterize import rasterize + self._validate(sdata, regions_to_images, regions_to_coordinate_systems) self._preprocess(tile_scale, tile_dim_in_units) self._crop_image: Callable[..., Any] = rasterize if raster else bounding_box_query @@ -118,9 +120,6 @@ def _validate( ), "One region cannot be paired to multiple coordinate systems." for region_key, image_key in regions_to_images.items(): - if region_key not in available_regions: - raise ValueError(f"region {region_key} not found in the spatialdata object.") - # get elements region_elem = sdata[region_key] image_elem = sdata[image_key] @@ -132,6 +131,11 @@ def _validate( raise ValueError("`regions_element` must be a shapes element.") if get_model(image_elem) not in [Image2DModel, Image3DModel]: raise ValueError("`images_element` must be an image element.") + if isinstance(image_elem, MultiscaleSpatialImage): + raise NotImplementedError("Multiscale images are not implemented yet.") + + if region_key not in available_regions: + raise ValueError(f"region {region_key} not found in the spatialdata object.") # check that the coordinate systems are valid for the elements region_trans = get_transformation(region_elem) @@ -148,7 +152,9 @@ def _validate( self.regions = list(available_regions.keys()) # all regions for the dataloader self.sdata = sdata - self.dataset_table = self.sdata.table.obs[self.sdata.table.obs[self._region_key].isin(self.regions)] + self.dataset_table = self.sdata.table.obs[ + self.sdata.table.obs[self._region_key].isin(self.regions) + ] # filtered table for the data loader self._cs_region_image = tuple(cs_region_image) # tuple of tuples (coordinate_system, region_key, image_key) def _preprocess( @@ -231,7 +237,7 @@ def __getitem__(self, idx: int) -> Any | SpatialData: image, axes=self.dims, min_coordinate=t_coords[[f"min{i}" for i in self.dims]], - max_coordinate=t_coords[[f"min{i}" for i in self.dims]], + max_coordinate=t_coords[[f"max{i}" for i in self.dims]], target_coordinate_system=row["cs"], ) @@ -338,7 +344,7 @@ def _get_tile_coords( if tile_dim_in_units is None: if elem.iloc[0][0].geom_type == "Point": extent = elem[ShapesModel.RADIUS_KEY].values * tile_scale - if elem.iloc[0][0].geom_type == "Polygon": + if elem.iloc[0][0].geom_type in ["Polygon", "MultiPolygon"]: extent = elem[ShapesModel.GEOMETRY_KEY].length * tile_scale raise ValueError("Only point and polygon shapes are supported.") if tile_dim_in_units is not None: diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index 66476c7f..93e1ffff 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -4,6 +4,7 @@ import pandas as pd import pytest from anndata import AnnData +from spatialdata import SpatialData from spatialdata.dataloader import ImageTilesDataset from spatialdata.models import TableModel @@ -14,21 +15,45 @@ class TestImageTilesDataset: "regions_element", ["blobs_labels", "blobs_multiscale_labels", "blobs_circles", "blobs_polygons", "blobs_multipolygons"], ) - def 
test_tiles_dataset(self, sdata_blobs, image_element, regions_element): - if regions_element in ["blobs_labels", "blobs_multipolygons", "blobs_multiscale_labels"]: + def test_validation(self, sdata_blobs, image_element, regions_element): + if regions_element in ["blobs_labels", "blobs_multiscale_labels"] or image_element == "blobs_multiscale_image": cm = pytest.raises(NotImplementedError) + elif regions_element in ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]: + cm = pytest.raises(ValueError) else: cm = contextlib.nullcontext() with cm: - ds = ImageTilesDataset( + _ = ImageTilesDataset( sdata=sdata_blobs, regions_to_images={regions_element: image_element}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", + regions_to_coordinate_systems={regions_element: "global"}, ) - tile = ds[0].images.values().__iter__().__next__() - assert tile.shape == (3, 32, 32) + + @pytest.mark.parametrize( + "regions_element", + ["blobs_circles", "blobs_polygons", "blobs_multipolygons"], + ) + def test_default(self, sdata_blobs, image_element, regions_element): + sdata = self._annotate_shapes(sdata_blobs, regions_element) + ds = ImageTilesDataset( + sdata=sdata, + regions_to_images={regions_element: "blobs_image"}, + regions_to_coordinate_systems={regions_element: "global"}, + ) + + tile = ds[0].images.values().__iter__().__next__() + assert tile.shape == (3, 32, 32) + + # TODO: consider adding this logic to blobs, to generate blobs with arbitrary table annotation + def _annotate_shapes(self, sdata: SpatialData, shape: str) -> SpatialData: + new_table = AnnData( + X=np.random.default_rng().random((len(sdata[shape]), 10)), + obs=pd.DataFrame({"region": shape, "instance_id": sdata[shape].index.values}), + ) + new_table = TableModel.parse(new_table, region=shape, region_key="region", instance_key="instance_id") + del sdata.table + sdata.table = new_table + return sdata def test_tiles_table(sdata_blobs): From dd03cdf26eb00b16ee497a774e8fd376ed32b1bd Mon Sep 17 00:00:00 2001 From: giovp Date: Thu, 22 Jun 2023 20:35:42 +0000 Subject: [PATCH 11/38] add tests --- src/spatialdata/dataloader/datasets.py | 87 ++++++++++++++--------- tests/dataloader/test_datasets.py | 98 ++++++++++++++++---------- 2 files changed, 112 insertions(+), 73 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index dde22451..e98c75bc 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -1,6 +1,9 @@ from __future__ import annotations +from collections.abc import Mapping +from functools import partial from itertools import chain +from types import MappingProxyType from typing import TYPE_CHECKING, Any, Callable import numpy as np @@ -45,12 +48,13 @@ def __init__( tile_scale: float = 1.0, tile_dim_in_units: float | None = None, raster: bool = False, - return_table: str | list[str] | None = None, + return_annot: str | list[str] | None = None, + raster_kwargs: Mapping[str, Any] = MappingProxyType({}), ): """ :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData` object. - By default, the dataset returns spatialdata object, but when `return_image` and `return_table` + By default, the dataset returns spatialdata object, but when `return_image` and `return_annot` are set, the dataset may return a tuple containing: - the tile image, centered in the target coordinate system of the region. 
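A runnable sketch of the API this patch converges on, mirroring the setup used in `tests/dataloader/test_datasets.py`; it assumes the `blobs()` demo dataset available under `spatialdata.datasets`, and the return values are the documented ones (a `SpatialData` per tile by default, a `(tile, annotation)` tuple when `return_annot` is set):

```python
import numpy as np
import pandas as pd
from anndata import AnnData
from spatialdata.dataloader import ImageTilesDataset
from spatialdata.datasets import blobs
from spatialdata.models import TableModel

sdata = blobs()

# Annotate the circles element with a small table, as the test helper `_annotate_shapes` does.
table = AnnData(
    X=np.random.default_rng().random((len(sdata["blobs_circles"]), 10)),
    obs=pd.DataFrame({"region": "blobs_circles", "instance_id": sdata["blobs_circles"].index.values}),
)
del sdata.table
sdata.table = TableModel.parse(table, region="blobs_circles", region_key="region", instance_key="instance_id")

# Default mode: each item is a SpatialData object holding one image tile (and its table row).
ds = ImageTilesDataset(
    sdata=sdata,
    regions_to_images={"blobs_circles": "blobs_image"},
    regions_to_coordinate_systems={"blobs_circles": "global"},
    tile_scale=2.0,  # tile extent is twice the circle radius
)
print(len(ds))  # one tile per annotated circle

# Annotation mode: each item is a (tile, annotation) tuple taken from the table.
ds_annot = ImageTilesDataset(
    sdata=sdata,
    regions_to_images={"blobs_circles": "blobs_image"},
    regions_to_coordinate_systems={"blobs_circles": "global"},
    return_annot="instance_id",
)
```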
@@ -78,14 +82,16 @@ def __init__( tile_dim_in_units The dimension of the requested tile in the units of the target coordinate system. This specifies the extent of the image each tile is querying. This is not related he size in pixel of each returned tile. - rasterize + raster If True, the images are rasterized using :func:`spatialdata.rasterize`. If False, they are rasterized using :func:`spatialdata.bounding_box_query`. - return_table + return_annot If not None, a value from the table is returned together with the image tile. Only columns in :attr:`anndata.AnnData.obs` and :attr:`anndata.AnnData.X` can be returned. If None, it will return a spatialdata object with only the tuple containing the image and the table value. + raster_kwargs + Keyword arguments passed to :func:`spatialdata.rasterize` if `raster` is True. Returns ------- @@ -96,8 +102,11 @@ def __init__( self._validate(sdata, regions_to_images, regions_to_coordinate_systems) self._preprocess(tile_scale, tile_dim_in_units) - self._crop_image: Callable[..., Any] = rasterize if raster else bounding_box_query - self._return = self._get_return(return_table) + + self._crop_image: Callable[..., Any] = ( + partial(rasterize, **dict(raster_kwargs)) if raster else bounding_box_query # type: ignore[assignment] + ) + self._return = self._get_return(return_annot) def _validate( self, @@ -108,7 +117,7 @@ def _validate( """Validate input parameters.""" self._region_key = sdata.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] self._instance_key = sdata.table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] - available_regions = sdata.table.obs[self._region_key].unique() + available_regions = sdata.table.obs[self._region_key].cat.categories cs_region_image = [] # list of tuples (coordinate_system, region, image) # check unique matching between regions and images and coordinate systems @@ -150,9 +159,9 @@ def _validate( except KeyError as e: raise KeyError(f"region {region_key} not found in `regions_to_coordinate_systems`") from e - self.regions = list(available_regions.keys()) # all regions for the dataloader + self.regions = list(regions_to_coordinate_systems.keys()) # all regions for the dataloader self.sdata = sdata - self.dataset_table = self.sdata.table.obs[ + self.dataset_table = self.sdata.table[ self.sdata.table.obs[self._region_key].isin(self.regions) ] # filtered table for the data loader self._cs_region_image = tuple(cs_region_image) # tuple of tuples (coordinate_system, region_key, image_key) @@ -166,18 +175,22 @@ def _preprocess( index_df = [] tile_coords_df = [] dims_l = [] + shapes_l = [] for cs, region, image in self._cs_region_image: # get dims and transformations for the region element - dims = get_axes_names(region) + dims = get_axes_names(self.sdata[region]) dims_l.append(dims) - t = get_transformation(region, cs) + t = get_transformation(self.sdata[region], cs) assert isinstance(t, BaseTransformation) # get coordinates of centroids and extent for tiles tile_coords = _get_tile_coords(self.sdata[region], t, dims, tile_scale, tile_dim_in_units) tile_coords_df.append(tile_coords) + # get shapes + shapes_l.append(self.sdata[region]) + # get instances from region inst = self.sdata.table.obs[self.sdata.table.obs[self._region_key] == region][self._instance_key].values @@ -189,10 +202,10 @@ def _preprocess( index_df.append(df) # concatenate and assign to self - self.dataset_index = pd.concat(index_df).reset_index(inplace=True, drop=True) - self.tiles_coords = pd.concat(tile_coords_df).reset_index(inplace=True, drop=True) 
+ self.dataset_index = pd.concat(index_df).reset_index(drop=True) + self.tiles_coords = pd.concat(tile_coords_df).reset_index(drop=True) # get table filtered by regions - self.filtered_table = self.sdata.table.obs[self.sdata.table.obs[self._region_key].isin[self._cs_region_imag[1]]] + self.filtered_table = self.sdata.table.obs[self.sdata.table.obs[self._region_key].isin(self.regions)] assert len(self.tiles_coords) == len(self.dataset_index) dims_ = set(chain(*dims_l)) @@ -201,25 +214,30 @@ def _preprocess( def _get_return( self, - return_table: str | list[str] | None, + return_annot: str | list[str] | None, ) -> Callable[[int, Any], tuple[Any, Any] | SpatialData]: """Get function to return values from the table of the dataset.""" from spatialdata import SpatialData - if return_table is not None: - # table is always returned as array shape (1, len(return_table)) + if return_annot is not None: + # table is always returned as array shape (1, len(return_annot)) # where return_table can be a single column or a list of columns - return_table = [return_table] if isinstance(return_table, str) else return_table + return_annot = [return_annot] if isinstance(return_annot, str) else return_annot # return tuple of (tile, table) - if return_table in self.dataset_table.obs: - return lambda x, tile: (tile, self.dataset_table.obs[return_table].iloc[x].values.reshape(1, -1)) - if return_table in self.dataset_table.var_names: + if np.all([i in self.dataset_table.obs for i in return_annot]): + return lambda x, tile: (tile, self.dataset_table.obs[return_annot].iloc[x].values.reshape(1, -1)) + if np.all([i in self.dataset_table.var_names for i in return_annot]): if issparse(self.dataset_table.X): - return lambda x, tile: (tile, self.dataset_table.X[:, return_table].X[x].A) - return lambda x, tile: (tile, self.dataset_table.X[:, return_table].X[x]) + return lambda x, tile: (tile, self.dataset_table.X[:, return_annot].X[x].A) + return lambda x, tile: (tile, self.dataset_table.X[:, return_annot].X[x]) + raise ValueError( + f"`return_annot` must be a column name in the table or a variable name in the table. " + f"Got {return_annot}." 
+ ) # return spatialdata consisting of the image tile and the associated table return lambda x, tile: SpatialData( - images={self.tiles_coords.iloc[x][[self.REGION_KEY]]: tile}, table=self.dataset_table[x] + images={self.dataset_index.iloc[x][self.IMAGE_KEY]: tile}, + table=self.dataset_table[x], ) def __len__(self) -> int: @@ -236,12 +254,12 @@ def __getitem__(self, idx: int) -> Any | SpatialData: tile = self._crop_image( image, axes=self.dims, - min_coordinate=t_coords[[f"min{i}" for i in self.dims]], - max_coordinate=t_coords[[f"max{i}" for i in self.dims]], + min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values, + max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values, target_coordinate_system=row["cs"], ) - self._return(idx, tile) + yield self._return(idx, tile) @property def regions(self) -> list[str]: @@ -338,15 +356,16 @@ def _get_tile_coords( # get centroids and transform them centroids = elem.centroid.get_coordinates().values aff = transformation.to_affine_matrix(input_axes=dims, output_axes=dims) - centroids = np.squeeze(_affine_matrix_multiplication(aff, centroids), 0) + centroids = _affine_matrix_multiplication(aff, centroids) # get extent, first by checking shape defaults, then by using the `tile_dim_in_units` if tile_dim_in_units is None: if elem.iloc[0][0].geom_type == "Point": extent = elem[ShapesModel.RADIUS_KEY].values * tile_scale - if elem.iloc[0][0].geom_type in ["Polygon", "MultiPolygon"]: + elif elem.iloc[0][0].geom_type in ["Polygon", "MultiPolygon"]: extent = elem[ShapesModel.GEOMETRY_KEY].length * tile_scale - raise ValueError("Only point and polygon shapes are supported.") + else: + raise ValueError("Only point and polygon shapes are supported.") if tile_dim_in_units is not None: if isinstance(tile_dim_in_units, float): extent = np.repeat(tile_dim_in_units, len(centroids)) @@ -358,14 +377,14 @@ def _get_tile_coords( # transform extent aff = transformation.to_affine_matrix(input_axes=tuple(dims[0]), output_axes=tuple(dims[0])) - extent = np.squeeze(_affine_matrix_multiplication(aff, extent), 0) + extent = _affine_matrix_multiplication(aff, extent[:, np.newaxis]) # get min and max coordinates - min_coordinates = np.array(centroids.values) - extent / 2 - max_coordinates = np.array(centroids.values) + extent / 2 + min_coordinates = np.array(centroids) - extent / 2 + max_coordinates = np.array(centroids) + extent / 2 # return a dataframe with columns e.g. 
["x", "y", "extent", "minx", "miny", "maxx", "maxy"] return pd.DataFrame( - np.hstack([centroids, extent[:, np.newaxis], min_coordinates, max_coordinates]), + np.hstack([centroids, extent, min_coordinates, max_coordinates]), columns=list(dims) + ["extent"] + ["min" + dim for dim in dims] + ["max" + dim for dim in dims], ) diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index 93e1ffff..819180c4 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -29,20 +29,72 @@ def test_validation(self, sdata_blobs, image_element, regions_element): regions_to_coordinate_systems={regions_element: "global"}, ) - @pytest.mark.parametrize( - "regions_element", - ["blobs_circles", "blobs_polygons", "blobs_multipolygons"], - ) - def test_default(self, sdata_blobs, image_element, regions_element): + @pytest.mark.parametrize("regions_element", ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]) + @pytest.mark.parametrize("raster", [True, False]) + def test_default(self, sdata_blobs, regions_element, raster): + raster_kwargs = {"target_unit_to_pixels": 2} if raster else {} + sdata = self._annotate_shapes(sdata_blobs, regions_element) ds = ImageTilesDataset( sdata=sdata, + raster=raster, regions_to_images={regions_element: "blobs_image"}, regions_to_coordinate_systems={regions_element: "global"}, + raster_kwargs=raster_kwargs, ) - tile = ds[0].images.values().__iter__().__next__() - assert tile.shape == (3, 32, 32) + sdata_tile = ds[0].__next__() + tile = sdata_tile.images.values().__iter__().__next__() + + if regions_element == "blobs_circles": + if raster: + assert tile.shape == (3, 50, 50) + else: + assert tile.shape == (3, 25, 25) + elif regions_element == "blobs_polygons": + if raster: + assert tile.shape == (3, 164, 164) + else: + assert tile.shape == (3, 82, 82) + elif regions_element == "blobs_multipolygons": + if raster: + assert tile.shape == (3, 329, 329) + else: + assert tile.shape == (3, 164, 164) + else: + raise ValueError(f"Unexpected regions_element: {regions_element}") + # extent has units in pixel so should be the same as tile shape + if raster: + assert round(ds.tiles_coords.extent.unique()[0] * 2) == tile.shape[1] + else: + assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1] + assert np.all(sdata_tile.table.obs.columns == ds.sdata.table.obs.columns) + assert list(sdata_tile.images.keys())[0] == "blobs_image" + + @pytest.mark.parametrize("regions_element", ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]) + @pytest.mark.parametrize("return_annot", ["region", ["region", "instance_id"]]) + def test_return_annot(self, sdata_blobs, regions_element, return_annot): + sdata = self._annotate_shapes(sdata_blobs, regions_element) + ds = ImageTilesDataset( + sdata=sdata, + regions_to_images={regions_element: "blobs_image"}, + regions_to_coordinate_systems={regions_element: "global"}, + return_annot=return_annot, + ) + + tile, annot = ds[0].__next__() + if regions_element == "blobs_circles": + assert tile.shape == (3, 25, 25) + elif regions_element == "blobs_polygons": + assert tile.shape == (3, 82, 82) + elif regions_element == "blobs_multipolygons": + assert tile.shape == (3, 164, 164) + else: + raise ValueError(f"Unexpected regions_element: {regions_element}") + # extent has units in pixel so should be the same as tile shape + assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1] + return_annot = [return_annot] if isinstance(return_annot, str) else return_annot + assert annot.shape[1] == 
len(return_annot) # TODO: consider adding this logic to blobs, to generate blobs with arbitrary table annotation def _annotate_shapes(self, sdata: SpatialData, shape: str) -> SpatialData: @@ -54,35 +106,3 @@ def _annotate_shapes(self, sdata: SpatialData, shape: str) -> SpatialData: del sdata.table sdata.table = new_table return sdata - - -def test_tiles_table(sdata_blobs): - new_table = AnnData( - X=np.random.default_rng().random((3, 10)), - obs=pd.DataFrame({"region": "blobs_circles", "instance_id": np.array([0, 1, 2])}), - ) - new_table = TableModel.parse(new_table, region="blobs_circles", region_key="region", instance_key="instance_id") - del sdata_blobs.table - sdata_blobs.table = new_table - ds = ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={"blobs_circles": "blobs_image"}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", - ) - assert len(ds) == 3 - assert len(ds[0].table) == 1 - assert np.all(ds[0].table.X == new_table[0].X) - - -def test_tiles_multiple_elements(sdata_blobs): - ds = ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={"blobs_circles": "blobs_image", "blobs_polygons": "blobs_multiscale_image"}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", - ) - assert len(ds) == 6 - _ = ds[0] From f846caf623e0bda0d92ec3f38ac840fac5217f9b Mon Sep 17 00:00:00 2001 From: giovp Date: Thu, 22 Jun 2023 20:49:42 +0000 Subject: [PATCH 12/38] fix tests for pandas --- src/spatialdata/dataloader/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index e98c75bc..e6b5f83a 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -377,7 +377,7 @@ def _get_tile_coords( # transform extent aff = transformation.to_affine_matrix(input_axes=tuple(dims[0]), output_axes=tuple(dims[0])) - extent = _affine_matrix_multiplication(aff, extent[:, np.newaxis]) + extent = _affine_matrix_multiplication(aff, np.array(extent)[:, np.newaxis]) # get min and max coordinates min_coordinates = np.array(centroids) - extent / 2 From 58b15c81de9ff551f37bdf4ea76af378a6a03365 Mon Sep 17 00:00:00 2001 From: giovp Date: Thu, 22 Jun 2023 20:55:14 +0000 Subject: [PATCH 13/38] fix import and docs --- docs/api.md | 2 +- src/spatialdata/__init__.py | 2 +- src/spatialdata/dataloader/datasets.py | 8 ++------ 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/docs/api.md b/docs/api.md index 011a86d0..d28cb738 100644 --- a/docs/api.md +++ b/docs/api.md @@ -110,7 +110,7 @@ The transformations that can be defined between elements and coordinate systems ## DataLoader ```{eval-rst} -.. currentmodule:: spatialdata.dataloader.datasets +.. currentmodule:: spatialdata.dataloader .. 
autosummary:: :toctree: generated diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index f400ea5d..f58ccfe7 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -27,7 +27,7 @@ "save_transformations", ] -from spatialdata import dataloader, models, transformations +from spatialdata import models, transformations from spatialdata._core.concatenate import concatenate from spatialdata._core.operations.aggregate import aggregate from spatialdata._core.operations.rasterize import rasterize diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index e6b5f83a..33f0bfe3 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -4,7 +4,7 @@ from functools import partial from itertools import chain from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Callable +from typing import Any, Callable import numpy as np import pandas as pd @@ -14,6 +14,7 @@ from scipy.sparse import issparse from torch.utils.data import Dataset +from spatialdata import SpatialData from spatialdata._utils import _affine_matrix_multiplication from spatialdata.models import ( Image2DModel, @@ -30,9 +31,6 @@ __all__ = ["ImageTilesDataset"] -if TYPE_CHECKING: - from spatialdata import SpatialData - class ImageTilesDataset(Dataset): INSTANCE_KEY = "instance" @@ -217,8 +215,6 @@ def _get_return( return_annot: str | list[str] | None, ) -> Callable[[int, Any], tuple[Any, Any] | SpatialData]: """Get function to return values from the table of the dataset.""" - from spatialdata import SpatialData - if return_annot is not None: # table is always returned as array shape (1, len(return_annot)) # where return_table can be a single column or a list of columns From 43065a9993e4b1c9304116b034f5437b602941bc Mon Sep 17 00:00:00 2001 From: giovp Date: Thu, 22 Jun 2023 21:15:58 +0000 Subject: [PATCH 14/38] update api --- docs/api.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index d28cb738..89fbf2bd 100644 --- a/docs/api.md +++ b/docs/api.md @@ -111,7 +111,6 @@ The transformations that can be defined between elements and coordinate systems ```{eval-rst} .. currentmodule:: spatialdata.dataloader - .. 
autosummary:: :toctree: generated From 556b87bd35d1e1e3a31c890c8568ae423ee19700 Mon Sep 17 00:00:00 2001 From: giovp Date: Thu, 22 Jun 2023 21:19:13 +0000 Subject: [PATCH 15/38] update import --- src/spatialdata/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index f58ccfe7..f400ea5d 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -27,7 +27,7 @@ "save_transformations", ] -from spatialdata import models, transformations +from spatialdata import dataloader, models, transformations from spatialdata._core.concatenate import concatenate from spatialdata._core.operations.aggregate import aggregate from spatialdata._core.operations.rasterize import rasterize From ffe73ec54337b81913d497d321ad333f1dc9c4a3 Mon Sep 17 00:00:00 2001 From: giovp Date: Fri, 23 Jun 2023 16:19:23 +0000 Subject: [PATCH 16/38] fix imports --- src/spatialdata/_core/concatenate.py | 8 ++--- src/spatialdata/_core/operations/aggregate.py | 7 ++--- src/spatialdata/_core/operations/rasterize.py | 2 -- .../_core/query/relational_query.py | 6 ++-- src/spatialdata/_core/query/spatial_query.py | 1 - src/spatialdata/_core/spatialdata.py | 30 +++++++++++-------- src/spatialdata/_io/_utils.py | 8 ++--- src/spatialdata/_io/io_zarr.py | 2 +- src/spatialdata/_utils.py | 5 +--- src/spatialdata/dataloader/datasets.py | 2 +- src/spatialdata/datasets.py | 2 +- src/spatialdata/models/_utils.py | 1 - src/spatialdata/transformations/operations.py | 2 +- tests/conftest.py | 2 +- tests/core/operations/test_aggregations.py | 3 +- .../operations/test_spatialdata_operations.py | 2 +- tests/core/operations/test_transform.py | 3 +- tests/core/query/test_spatial_query.py | 2 +- tests/dataloader/test_datasets.py | 2 +- tests/io/test_readwrite.py | 2 +- tests/models/test_models.py | 2 +- 21 files changed, 41 insertions(+), 53 deletions(-) diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index 070b6fa2..8595343b 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -2,14 +2,12 @@ from copy import copy # Should probably go up at the top from itertools import chain -from typing import TYPE_CHECKING, Any +from typing import Any import numpy as np from anndata import AnnData -if TYPE_CHECKING: - from spatialdata._core.spatialdata import SpatialData - +from spatialdata._core.spatialdata import SpatialData from spatialdata.models import TableModel __all__ = [ @@ -94,8 +92,6 @@ def concatenate( ------- The concatenated :class:`spatialdata.SpatialData` object. 
""" - from spatialdata import SpatialData - merged_images = {**{k: v for sdata in sdatas for k, v in sdata.images.items()}} if len(merged_images) != np.sum([len(sdata.images) for sdata in sdatas]): raise KeyError("Images must have unique names across the SpatialData objects to concatenate") diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index 9881dc7b..e51c848c 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any +from typing import Any import anndata as ad import dask as da @@ -20,6 +20,7 @@ from spatialdata._core.operations.transform import transform from spatialdata._core.query._utils import circles_to_polygons from spatialdata._core.query.relational_query import get_values +from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.models import ( Image2DModel, @@ -32,9 +33,6 @@ ) from spatialdata.transformations import BaseTransformation, Identity, get_transformation -if TYPE_CHECKING: - from spatialdata import SpatialData - __all__ = ["aggregate"] @@ -236,7 +234,6 @@ def _create_sdata_from_table_and_shapes( instance_key: str, deepcopy: bool, ) -> SpatialData: - from spatialdata import SpatialData from spatialdata._utils import _deepcopy_geodataframe table.obs[instance_key] = table.obs_names.copy() diff --git a/src/spatialdata/_core/operations/rasterize.py b/src/spatialdata/_core/operations/rasterize.py index 09879547..36b53904 100644 --- a/src/spatialdata/_core/operations/rasterize.py +++ b/src/spatialdata/_core/operations/rasterize.py @@ -207,8 +207,6 @@ def _( target_height: Optional[float] = None, target_depth: Optional[float] = None, ) -> SpatialData: - from spatialdata import SpatialData - min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index ff45bdd9..7cd4d2fb 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import Any import dask.array as da import numpy as np @@ -11,6 +11,7 @@ from multiscale_spatial_image import MultiscaleSpatialImage from spatial_image import SpatialImage +from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import _inplace_fix_subset_categorical_obs from spatialdata.models import ( Labels2DModel, @@ -22,9 +23,6 @@ get_model, ) -if TYPE_CHECKING: - from spatialdata import SpatialData - def _filter_table_by_coordinate_system(table: AnnData | None, coordinate_system: str | list[str]) -> AnnData | None: """ diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index cd8777fb..85b19d04 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -305,7 +305,6 @@ def _( target_coordinate_system: str, filter_table: bool = True, ) -> SpatialData: - from spatialdata import SpatialData from spatialdata._core.query.relational_query import _filter_table_by_elements min_coordinate = _parse_list_into_array(min_coordinate) diff --git a/src/spatialdata/_core/spatialdata.py 
b/src/spatialdata/_core/spatialdata.py index 1c754170..b33febf4 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -18,27 +18,17 @@ from pyarrow.parquet import read_table from spatial_image import SpatialImage -from spatialdata._io import ( - write_image, - write_labels, - write_points, - write_shapes, - write_table, -) -from spatialdata._io._utils import get_backing_files from spatialdata._logging import logger from spatialdata._types import ArrayLike -from spatialdata._utils import _natural_keys -from spatialdata.models import ( +from spatialdata.models._utils import SpatialElement, get_axes_names +from spatialdata.models.models import ( Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, ShapesModel, - SpatialElement, TableModel, - get_axes_names, get_model, ) @@ -653,6 +643,9 @@ def add_image( ----- If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. """ + from spatialdata._io._utils import get_backing_files + from spatialdata._io.io_raster import write_image + if self.is_backed(): files = get_backing_files(image) assert self.path is not None @@ -736,6 +729,9 @@ def add_labels( ----- If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. """ + from spatialdata._io._utils import get_backing_files + from spatialdata._io.io_raster import write_labels + if self.is_backed(): files = get_backing_files(labels) assert self.path is not None @@ -820,6 +816,9 @@ def add_points( ----- If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. """ + from spatialdata._io._utils import get_backing_files + from spatialdata._io.io_points import write_points + if self.is_backed(): files = get_backing_files(points) assert self.path is not None @@ -902,6 +901,8 @@ def add_shapes( ----- If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. """ + from spatialdata._io.io_shapes import write_shapes + self._add_shapes_in_memory(name=name, shapes=shapes, overwrite=overwrite) if self.is_backed(): elem_group = self._init_add_element(name=name, element_type="shapes", overwrite=overwrite) @@ -918,6 +919,8 @@ def write( storage_options: JSONDict | list[JSONDict] | None = None, overwrite: bool = False, ) -> None: + from spatialdata._io import write_image, write_labels, write_points, write_shapes, write_table + """Write the SpatialData object to Zarr.""" if isinstance(file_path, str): file_path = Path(file_path) @@ -1113,6 +1116,8 @@ def table(self, table: AnnData) -> None: The table needs to pass validation (see :class:`~spatialdata.TableModel`). If the SpatialData object is backed by a Zarr storage, the table will be written to the Zarr storage. """ + from spatialdata._io.io_table import write_table + TableModel().validate(table) if self.table is not None: raise ValueError("The table already exists. Use del sdata.table to remove it first.") @@ -1199,6 +1204,7 @@ def _gen_repr( ------- The string representation of the SpatialData object. 
""" + from spatialdata._utils import _natural_keys def rreplace(s: str, old: str, new: str, occurrence: int) -> str: li = s.rsplit(old, occurrence) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index bfa12721..c2d44114 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -8,7 +8,7 @@ from collections.abc import Generator, Mapping from contextlib import contextmanager from functools import singledispatch -from typing import TYPE_CHECKING, Any +from typing import Any import zarr from dask.dataframe.core import DataFrame as DaskDataFrame @@ -18,6 +18,7 @@ from spatial_image import SpatialImage from xarray import DataArray +from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import iterate_pyramid_levels from spatialdata.models._utils import ( MappingToCoordinateSystem_t, @@ -30,9 +31,6 @@ _get_current_output_axes, ) -if TYPE_CHECKING: - from spatialdata import SpatialData - # suppress logger debug from ome_zarr with context manager @contextmanager @@ -196,8 +194,6 @@ def _are_directories_identical( def _compare_sdata_on_disk(a: SpatialData, b: SpatialData) -> bool: - from spatialdata import SpatialData - if not isinstance(a, SpatialData) or not isinstance(b, SpatialData): return False # TODO: if the sdata object is backed on disk, don't create a new zarr file diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 7675b3f7..326424f2 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -7,7 +7,7 @@ from anndata import AnnData from anndata import read_zarr as read_anndata_zarr -from spatialdata import SpatialData +from spatialdata._core.spatialdata import SpatialData from spatialdata._io._utils import ome_zarr_logger from spatialdata._io.io_points import _read_points from spatialdata._io.io_raster import _read_multiscale diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index 7d9e0aac..bb973c48 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -3,7 +3,7 @@ import re from collections.abc import Generator from copy import deepcopy -from typing import TYPE_CHECKING, Union +from typing import Union import numpy as np import pandas as pd @@ -26,9 +26,6 @@ # I was using "from numbers import Number" but this led to mypy errors, so I switched to the following: Number = Union[int, float] -if TYPE_CHECKING: - pass - def _parse_list_into_array(array: list[Number] | ArrayLike) -> ArrayLike: if isinstance(array, list): diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 33f0bfe3..64af2324 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -14,7 +14,7 @@ from scipy.sparse import issparse from torch.utils.data import Dataset -from spatialdata import SpatialData +from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import _affine_matrix_multiplication from spatialdata.models import ( Image2DModel, diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index ab93669f..cc3068a2 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -12,8 +12,8 @@ from skimage.segmentation import slic from spatial_image import SpatialImage -from spatialdata import SpatialData from spatialdata._core.operations.aggregate import aggregate +from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.models import ( Image2DModel, diff --git 
a/src/spatialdata/models/_utils.py b/src/spatialdata/models/_utils.py index 4dacb084..5ebbc472 100644 --- a/src/spatialdata/models/_utils.py +++ b/src/spatialdata/models/_utils.py @@ -17,7 +17,6 @@ SpatialElement = Union[SpatialImage, MultiscaleSpatialImage, GeoDataFrame, DaskDataFrame] TRANSFORM_KEY = "transform" DEFAULT_COORDINATE_SYSTEM = "global" -# ValidAxis_t = Literal["c", "x", "y", "z"] ValidAxis_t = str MappingToCoordinateSystem_t = dict[str, BaseTransformation] C = "c" diff --git a/src/spatialdata/transformations/operations.py b/src/spatialdata/transformations/operations.py index 8e985432..f42452fa 100644 --- a/src/spatialdata/transformations/operations.py +++ b/src/spatialdata/transformations/operations.py @@ -16,7 +16,7 @@ ) if TYPE_CHECKING: - from spatialdata import SpatialData + from spatialdata._core.spatialdata import SpatialData from spatialdata.models import SpatialElement from spatialdata.transformations import Affine, BaseTransformation diff --git a/tests/conftest.py b/tests/conftest.py index b5d21299..5739d21f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,7 @@ from numpy.random import default_rng from shapely.geometry import MultiPolygon, Point, Polygon from spatial_image import SpatialImage -from spatialdata import SpatialData +from spatialdata._core.spatialdata import SpatialData from spatialdata.models import ( Image2DModel, Image3DModel, diff --git a/tests/core/operations/test_aggregations.py b/tests/core/operations/test_aggregations.py index 51ffa98e..36609464 100644 --- a/tests/core/operations/test_aggregations.py +++ b/tests/core/operations/test_aggregations.py @@ -9,8 +9,9 @@ from anndata.tests.helpers import assert_equal from geopandas import GeoDataFrame from numpy.random import default_rng -from spatialdata import SpatialData, aggregate +from spatialdata import aggregate from spatialdata._core.query._utils import circles_to_polygons +from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import _deepcopy_geodataframe from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, TableModel from spatialdata.transformations import Affine, Identity, set_transformation diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index f2cd2695..39628c15 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -9,8 +9,8 @@ from geopandas import GeoDataFrame from multiscale_spatial_image import MultiscaleSpatialImage from spatial_image import SpatialImage -from spatialdata import SpatialData from spatialdata._core.concatenate import _concatenate_tables, concatenate +from spatialdata._core.spatialdata import SpatialData from spatialdata.datasets import blobs from spatialdata.models import ( Image2DModel, diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index f28b345f..601159ef 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -8,7 +8,8 @@ from geopandas.testing import geom_almost_equals from multiscale_spatial_image import MultiscaleSpatialImage from spatial_image import SpatialImage -from spatialdata import SpatialData, transform +from spatialdata import transform +from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import unpad_raster from spatialdata.models import Image2DModel, PointsModel, ShapesModel, get_axes_names from 
spatialdata.transformations.operations import ( diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index ae9e2047..0166d9e4 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -5,13 +5,13 @@ from anndata import AnnData from multiscale_spatial_image import MultiscaleSpatialImage from spatial_image import SpatialImage -from spatialdata import SpatialData from spatialdata._core.query.spatial_query import ( BaseSpatialRequest, BoundingBoxRequest, bounding_box_query, polygon_query, ) +from spatialdata._core.spatialdata import SpatialData from spatialdata.models import ( Image2DModel, Image3DModel, diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index 819180c4..1a0cc959 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -4,7 +4,7 @@ import pandas as pd import pytest from anndata import AnnData -from spatialdata import SpatialData +from spatialdata._core.spatialdata import SpatialData from spatialdata.dataloader import ImageTilesDataset from spatialdata.models import TableModel diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 046baf3b..ddeefb37 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -12,7 +12,7 @@ from numpy.random import default_rng from shapely.geometry import Point from spatial_image import SpatialImage -from spatialdata import SpatialData +from spatialdata._core.spatialdata import SpatialData from spatialdata._io._utils import _are_directories_identical from spatialdata.models import TableModel from spatialdata.transformations.operations import ( diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 285c4584..a8249b2b 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -20,7 +20,7 @@ from pandas.api.types import is_categorical_dtype from shapely.io import to_ragged_array from spatial_image import SpatialImage, to_spatial_image -from spatialdata import SpatialData +from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.models import ( Image2DModel, From 05f5db7cc0c63f6ee4e49d61c02b20d315c63ec0 Mon Sep 17 00:00:00 2001 From: giovp Date: Fri, 23 Jun 2023 17:30:31 +0000 Subject: [PATCH 17/38] update api --- docs/api.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index 89fbf2bd..1805b19f 100644 --- a/docs/api.md +++ b/docs/api.md @@ -9,6 +9,7 @@ The `SpatialData` class. ```{eval-rst} +.. currentmodule:: spatialdata .. autosummary:: :toctree: generated @@ -63,6 +64,7 @@ The elements (building-blocks) that consitute `SpatialData`. ### Utilities ```{eval-rst} +.. currentmodule:: spatialdata.models .. autosummary:: :toctree: generated @@ -80,7 +82,6 @@ The transformations that can be defined between elements and coordinate systems ```{eval-rst} .. currentmodule:: spatialdata.transformations - .. autosummary:: :toctree: generated @@ -97,6 +98,7 @@ The transformations that can be defined between elements and coordinate systems ```{eval-rst} .. autosummary:: +.. 
currentmodule:: spatialdata.transformations :toctree: generated get_transformation From bba86ef0d35baae4b681df3a481a3deb8e0899ef Mon Sep 17 00:00:00 2001 From: giovp Date: Fri, 23 Jun 2023 18:08:49 +0000 Subject: [PATCH 18/38] fix api --- docs/api.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/api.md b/docs/api.md index 1805b19f..4913b70d 100644 --- a/docs/api.md +++ b/docs/api.md @@ -9,7 +9,6 @@ The `SpatialData` class. ```{eval-rst} -.. currentmodule:: spatialdata .. autosummary:: :toctree: generated @@ -49,6 +48,7 @@ The elements (building-blocks) that consitute `SpatialData`. ```{eval-rst} .. currentmodule:: spatialdata.models + .. autosummary:: :toctree: generated @@ -65,6 +65,7 @@ The elements (building-blocks) that consitute `SpatialData`. ```{eval-rst} .. currentmodule:: spatialdata.models + .. autosummary:: :toctree: generated @@ -82,6 +83,7 @@ The transformations that can be defined between elements and coordinate systems ```{eval-rst} .. currentmodule:: spatialdata.transformations + .. autosummary:: :toctree: generated @@ -97,8 +99,9 @@ The transformations that can be defined between elements and coordinate systems ### Utilities ```{eval-rst} -.. autosummary:: .. currentmodule:: spatialdata.transformations + +.. autosummary:: :toctree: generated get_transformation @@ -113,6 +116,7 @@ The transformations that can be defined between elements and coordinate systems ```{eval-rst} .. currentmodule:: spatialdata.dataloader + .. autosummary:: :toctree: generated From b37f40c624ad71586aba2668ca171e73a79d9172 Mon Sep 17 00:00:00 2001 From: giovp Date: Fri, 23 Jun 2023 18:24:47 +0000 Subject: [PATCH 19/38] try fix docs --- .readthedocs.yaml | 2 +- pyproject.toml | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 690bf115..b59dfb7b 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -6,7 +6,7 @@ build: python: "3.10" sphinx: configuration: docs/conf.py - fail_on_warning: false + fail_on_warning: true python: install: - method: pip diff --git a/pyproject.toml b/pyproject.toml index 3eea80ee..fd8b0df5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,6 @@ dev = [ docs = [ "sphinx>=4.5", "sphinx-book-theme>=1.0.0", - "sphinx_rtd_theme", "myst-nb", "sphinxcontrib-bibtex>=1.0.0", "sphinx-autodoc-typehints", From a96b6fd057c3e4c4974cdee08fd661c3ec41d5e3 Mon Sep 17 00:00:00 2001 From: giovp Date: Fri, 23 Jun 2023 18:30:20 +0000 Subject: [PATCH 20/38] try fix docs --- docs/api.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index 4913b70d..bdd0c80d 100644 --- a/docs/api.md +++ b/docs/api.md @@ -29,7 +29,6 @@ Operations on `SpatialData` objects. 
match_table_to_element concatenate rasterize - transform aggregate ``` From ceb6f5ba4cf1c517719455701d8350b75904a1a0 Mon Sep 17 00:00:00 2001 From: giovp Date: Fri, 23 Jun 2023 18:43:56 +0000 Subject: [PATCH 21/38] add optional import of dataloader --- src/spatialdata/dataloader/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/dataloader/__init__.py b/src/spatialdata/dataloader/__init__.py index 297b221f..819ab58e 100644 --- a/src/spatialdata/dataloader/__init__.py +++ b/src/spatialdata/dataloader/__init__.py @@ -1 +1,4 @@ -from spatialdata.dataloader.datasets import ImageTilesDataset +try: + from spatialdata.dataloader.datasets import ImageTilesDataset +except ImportError: + ImageTilesDataset = None # type: ignore[assignment, misc] From d290845caa8a2e3eefe98f26154702782bd89715 Mon Sep 17 00:00:00 2001 From: giovp Date: Fri, 23 Jun 2023 21:13:23 +0000 Subject: [PATCH 22/38] minor fixes --- src/spatialdata/dataloader/datasets.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 64af2324..cbb29686 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -145,9 +145,6 @@ def _validate( raise ValueError(f"region {region_key} not found in the spatialdata object.") # check that the coordinate systems are valid for the elements - region_trans = get_transformation(region_elem) - image_trans = get_transformation(image_elem) - try: cs = regions_to_coordinate_systems[region_key] region_trans = get_transformation(region_elem, cs) @@ -182,16 +179,18 @@ def _preprocess( t = get_transformation(self.sdata[region], cs) assert isinstance(t, BaseTransformation) + # get instances from region + inst = self.sdata.table.obs[self.sdata.table.obs[self._region_key] == region][self._instance_key].values + + # subset the regions by instances + subset_region = self.sdata[region].iloc[inst] # get coordinates of centroids and extent for tiles - tile_coords = _get_tile_coords(self.sdata[region], t, dims, tile_scale, tile_dim_in_units) + tile_coords = _get_tile_coords(subset_region, t, dims, tile_scale, tile_dim_in_units) tile_coords_df.append(tile_coords) # get shapes shapes_l.append(self.sdata[region]) - # get instances from region - inst = self.sdata.table.obs[self.sdata.table.obs[self._region_key] == region][self._instance_key].values - # get index dictionary, with `instance_id`, `cs`, `region`, and `image` df = pd.DataFrame({self.INSTANCE_KEY: inst}) df[self.CS_KEY] = cs @@ -224,8 +223,8 @@ def _get_return( return lambda x, tile: (tile, self.dataset_table.obs[return_annot].iloc[x].values.reshape(1, -1)) if np.all([i in self.dataset_table.var_names for i in return_annot]): if issparse(self.dataset_table.X): - return lambda x, tile: (tile, self.dataset_table.X[:, return_annot].X[x].A) - return lambda x, tile: (tile, self.dataset_table.X[:, return_annot].X[x]) + return lambda x, tile: (tile, self.dataset_table[:, return_annot].X[x].A) + return lambda x, tile: (tile, self.dataset_table[:, return_annot].X[x]) raise ValueError( f"`return_annot` must be a column name in the table or a variable name in the table. " f"Got {return_annot}." 
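For reference, the optional import added to `spatialdata/dataloader/__init__.py` above means that `ImageTilesDataset` resolves to `None` when `torch` is not installed. A minimal caller-side guard under that assumption (the error message is illustrative) could be:

```python
# Sketch: check the optional torch dependency before constructing a dataset.
from spatialdata.dataloader import ImageTilesDataset

if ImageTilesDataset is None:
    # the try/except in dataloader/__init__.py swallowed the ImportError raised by torch
    raise ImportError("ImageTilesDataset requires the optional dependency `torch`.")
```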
@@ -255,7 +254,7 @@ def __getitem__(self, idx: int) -> Any | SpatialData: target_coordinate_system=row["cs"], ) - yield self._return(idx, tile) + return self._return(idx, tile) @property def regions(self) -> list[str]: From 787837990d6ff1aaee3d33a25cb093918ef33532 Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 27 Jun 2023 19:43:38 +0000 Subject: [PATCH 23/38] fix test --- src/spatialdata/dataloader/datasets.py | 16 +++++++++++++--- tests/dataloader/test_datasets.py | 4 ++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index cbb29686..12ee3cd5 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -47,6 +47,7 @@ def __init__( tile_dim_in_units: float | None = None, raster: bool = False, return_annot: str | list[str] | None = None, + transform: Callable[[Any], Any] | None = None, raster_kwargs: Mapping[str, Any] = MappingProxyType({}), ): """ @@ -105,6 +106,7 @@ def __init__( partial(rasterize, **dict(raster_kwargs)) if raster else bounding_box_query # type: ignore[assignment] ) self._return = self._get_return(return_annot) + self.transform = transform def _validate( self, @@ -223,8 +225,8 @@ def _get_return( return lambda x, tile: (tile, self.dataset_table.obs[return_annot].iloc[x].values.reshape(1, -1)) if np.all([i in self.dataset_table.var_names for i in return_annot]): if issparse(self.dataset_table.X): - return lambda x, tile: (tile, self.dataset_table[:, return_annot].X[x].A) - return lambda x, tile: (tile, self.dataset_table[:, return_annot].X[x]) + return lambda x, tile: (tile, self.dataset_table[x, return_annot].X.A) + return lambda x, tile: (tile, self.dataset_table[x, return_annot].X) raise ValueError( f"`return_annot` must be a column name in the table or a variable name in the table. " f"Got {return_annot}." @@ -254,6 +256,9 @@ def __getitem__(self, idx: int) -> Any | SpatialData: target_coordinate_system=row["cs"], ) + if self.transform is not None: + out = self._return(idx, tile) + return self.transform(out) return self._return(idx, tile) @property @@ -362,8 +367,13 @@ def _get_tile_coords( else: raise ValueError("Only point and polygon shapes are supported.") if tile_dim_in_units is not None: - if isinstance(tile_dim_in_units, float): + if isinstance(tile_dim_in_units, (float, int)): extent = np.repeat(tile_dim_in_units, len(centroids)) + else: + raise TypeError( + f"`tile_dim_in_units` must be a `float`, `int`, `list`, `tuple` or `np.ndarray`, " + f"not {type(tile_dim_in_units)}." 
+ ) if len(extent) != len(centroids): raise ValueError( f"the number of elements in the region ({len(extent)}) does not match" diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index 1a0cc959..5650431a 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -43,7 +43,7 @@ def test_default(self, sdata_blobs, regions_element, raster): raster_kwargs=raster_kwargs, ) - sdata_tile = ds[0].__next__() + sdata_tile = ds[0] tile = sdata_tile.images.values().__iter__().__next__() if regions_element == "blobs_circles": @@ -82,7 +82,7 @@ def test_return_annot(self, sdata_blobs, regions_element, return_annot): return_annot=return_annot, ) - tile, annot = ds[0].__next__() + tile, annot = ds[0] if regions_element == "blobs_circles": assert tile.shape == (3, 25, 25) elif regions_element == "blobs_polygons": From 9b84084f7b9067c27961d2e08ad51f5e367d8d43 Mon Sep 17 00:00:00 2001 From: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Date: Thu, 29 Jun 2023 19:55:42 +0200 Subject: [PATCH 24/38] fixed typos --- src/spatialdata/dataloader/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 12ee3cd5..114dfacf 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -57,14 +57,14 @@ def __init__( are set, the dataset may return a tuple containing: - the tile image, centered in the target coordinate system of the region. - - a vector or scala value from the table. + - a vector or scalar value from the table. Parameters ---------- sdata The :class`spatialdata.SpatialData` object. regions_to_images - A mapping betwen region and images. The regions are used to compute the tile centers, while the images are + A mapping between region and images. The regions are used to compute the tile centers, while the images are used to get the pixel values. regions_to_coordinate_systems A mapping between regions and coordinate systems. The coordinate systems are used to transform both From 13da14254b9d35937d9ad547983506ca1451ad09 Mon Sep 17 00:00:00 2001 From: Giovanni Palla <25887487+giovp@users.noreply.github.com> Date: Thu, 13 Jul 2023 17:57:16 -0400 Subject: [PATCH 25/38] Update src/spatialdata/dataloader/datasets.py Co-authored-by: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> --- src/spatialdata/dataloader/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 114dfacf..38de3f8e 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -83,7 +83,7 @@ def __init__( of the image each tile is querying. This is not related he size in pixel of each returned tile. raster If True, the images are rasterized using :func:`spatialdata.rasterize`. - If False, they are rasterized using :func:`spatialdata.bounding_box_query`. + If False, they are queried using :func:`spatialdata.bounding_box_query`. return_annot If not None, a value from the table is returned together with the image tile. 
Only columns in :attr:`anndata.AnnData.obs` and :attr:`anndata.AnnData.X` From 13dc9bece5fe189cdd2e8a5bde60cdb8c4dc0a77 Mon Sep 17 00:00:00 2001 From: giovp Date: Thu, 13 Jul 2023 22:00:58 +0000 Subject: [PATCH 26/38] update --- src/spatialdata/dataloader/datasets.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 38de3f8e..4163190e 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -33,7 +33,7 @@ class ImageTilesDataset(Dataset): - INSTANCE_KEY = "instance" + INSTANCE_KEY = "instance_id" CS_KEY = "cs" REGION_KEY = "region" IMAGE_KEY = "image" @@ -79,9 +79,9 @@ def __init__( If `tile_dim_in_units` is passed, `tile_scale` is ignored. tile_dim_in_units - The dimension of the requested tile in the units of the target coordinate system. This specifies the extent - of the image each tile is querying. This is not related he size in pixel of each returned tile. - raster + The dimension of the requested tile in the units of the target coordinate system. + This specifies the extent of the tile. This is not related the size in pixel of each returned tile. + rasterize If True, the images are rasterized using :func:`spatialdata.rasterize`. If False, they are queried using :func:`spatialdata.bounding_box_query`. return_annot @@ -97,13 +97,13 @@ def __init__( :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData`. """ from spatialdata import bounding_box_query - from spatialdata._core.operations.rasterize import rasterize + from spatialdata._core.operations.rasterize import rasterize as rasterize_fn self._validate(sdata, regions_to_images, regions_to_coordinate_systems) self._preprocess(tile_scale, tile_dim_in_units) self._crop_image: Callable[..., Any] = ( - partial(rasterize, **dict(raster_kwargs)) if raster else bounding_box_query # type: ignore[assignment] + partial(rasterize_fn, **dict(raster_kwargs)) if raster else bounding_box_query # type: ignore[assignment] ) self._return = self._get_return(return_annot) self.transform = transform From 6ef34a536fa4d363413e0f4c0d670d1b1bcef1b9 Mon Sep 17 00:00:00 2001 From: giovp Date: Fri, 14 Jul 2023 04:03:18 +0200 Subject: [PATCH 27/38] update with more comments --- src/spatialdata/dataloader/datasets.py | 24 ++++++++++++++---------- tests/dataloader/test_datasets.py | 6 +++--- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 4163190e..6d787836 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -13,7 +13,7 @@ from multiscale_spatial_image import MultiscaleSpatialImage from scipy.sparse import issparse from torch.utils.data import Dataset - +from spatialdata._core.query.relational_query import match_table_to_element from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import _affine_matrix_multiplication from spatialdata.models import ( @@ -45,10 +45,10 @@ def __init__( regions_to_coordinate_systems: dict[str, str], tile_scale: float = 1.0, tile_dim_in_units: float | None = None, - raster: bool = False, - return_annot: str | list[str] | None = None, + rasterize: bool = False, + return_annotations: str | list[str] | None = None, transform: Callable[[Any], Any] | None = None, - raster_kwargs: Mapping[str, Any] = MappingProxyType({}), + rasterize_kwargs: Mapping[str, Any] = 
MappingProxyType({}), ): """ :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData` object. @@ -84,13 +84,17 @@ def __init__( rasterize If True, the images are rasterized using :func:`spatialdata.rasterize`. If False, they are queried using :func:`spatialdata.bounding_box_query`. - return_annot + return_annotations If not None, a value from the table is returned together with the image tile. Only columns in :attr:`anndata.AnnData.obs` and :attr:`anndata.AnnData.X` can be returned. If None, it will return a spatialdata object with only the tuple containing the image and the table value. - raster_kwargs - Keyword arguments passed to :func:`spatialdata.rasterize` if `raster` is True. + transform + A callable that takes as input the tuple (image, table_value) and returns a new tuple. + This can be used to apply transformations to the image and the table value. + rasterize_kwargs + Keyword arguments passed to :func:`spatialdata.rasterize` if `rasterize` is True. + This argument can be used for instance to choose the pixel dimension of the image tile. Returns ------- @@ -103,9 +107,9 @@ def __init__( self._preprocess(tile_scale, tile_dim_in_units) self._crop_image: Callable[..., Any] = ( - partial(rasterize_fn, **dict(raster_kwargs)) if raster else bounding_box_query # type: ignore[assignment] + partial(rasterize_fn, **dict(rasterize_kwargs)) if rasterize else bounding_box_query # type: ignore[assignment] ) - self._return = self._get_return(return_annot) + self._return = self._get_return(return_annotations) self.transform = transform def _validate( @@ -123,7 +127,7 @@ def _validate( # check unique matching between regions and images and coordinate systems assert len(set(regions_to_images.values())) == len( regions_to_images.keys() - ), "One region cannot be paired to multiple regions." + ), "One region cannot be paired to multiple images." assert len(set(regions_to_coordinate_systems.values())) == len( regions_to_coordinate_systems.keys() ), "One region cannot be paired to multiple coordinate systems." 
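With the renaming above (`rasterize`, `return_annotations`, `rasterize_kwargs`), the public signature is settled. A rough usage sketch (assuming an annotated `SpatialData` object such as the blobs fixture exercised in the tests below; argument values are purely illustrative) could look like:

```python
# Hypothetical sketch: `sdata` is assumed to hold a "blobs_image" image and
# "blobs_circles" shapes annotated by sdata.table (mirroring the test fixtures below).
from torch.utils.data import DataLoader
from spatialdata.dataloader import ImageTilesDataset

ds = ImageTilesDataset(
    sdata=sdata,
    regions_to_images={"blobs_circles": "blobs_image"},
    regions_to_coordinate_systems={"blobs_circles": "global"},
    rasterize=True,                                 # crop with spatialdata.rasterize ...
    rasterize_kwargs={"target_unit_to_pixels": 2},  # ... e.g. fixing the output resolution
    return_annotations="region",                    # each item becomes a (tile, annotation) tuple
)
tile, annot = ds[0]

# Tiles are xarray-based images, so a custom collate_fn is usually needed before batching.
loader = DataLoader(ds, batch_size=1, collate_fn=lambda batch: batch)
```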
diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index 5650431a..3b99b104 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -37,10 +37,10 @@ def test_default(self, sdata_blobs, regions_element, raster): sdata = self._annotate_shapes(sdata_blobs, regions_element) ds = ImageTilesDataset( sdata=sdata, - raster=raster, + rasterize=raster, regions_to_images={regions_element: "blobs_image"}, regions_to_coordinate_systems={regions_element: "global"}, - raster_kwargs=raster_kwargs, + rasterize_kwargs=raster_kwargs, ) sdata_tile = ds[0] @@ -79,7 +79,7 @@ def test_return_annot(self, sdata_blobs, regions_element, return_annot): sdata=sdata, regions_to_images={regions_element: "blobs_image"}, regions_to_coordinate_systems={regions_element: "global"}, - return_annot=return_annot, + return_annotations=return_annot, ) tile, annot = ds[0] From a65a8554b1dacbd8dd0db93b19340efa369cecad Mon Sep 17 00:00:00 2001 From: giovp Date: Fri, 14 Jul 2023 04:05:31 +0200 Subject: [PATCH 28/38] fix precommit --- src/spatialdata/dataloader/datasets.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 6d787836..2fe5924e 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -13,7 +13,7 @@ from multiscale_spatial_image import MultiscaleSpatialImage from scipy.sparse import issparse from torch.utils.data import Dataset -from spatialdata._core.query.relational_query import match_table_to_element + from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import _affine_matrix_multiplication from spatialdata.models import ( @@ -107,7 +107,12 @@ def __init__( self._preprocess(tile_scale, tile_dim_in_units) self._crop_image: Callable[..., Any] = ( - partial(rasterize_fn, **dict(rasterize_kwargs)) if rasterize else bounding_box_query # type: ignore[assignment] + partial( + rasterize_fn, + **dict(rasterize_kwargs), + ) + if rasterize + else bounding_box_query # type: ignore[assignment] ) self._return = self._get_return(return_annotations) self.transform = transform From 8ef16eb3f65119ee1d6263e77a050f1bea082af6 Mon Sep 17 00:00:00 2001 From: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Date: Fri, 14 Jul 2023 13:40:04 +0200 Subject: [PATCH 29/38] modified docstring for transform in ImageTilesDataset --- src/spatialdata/dataloader/datasets.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 2fe5924e..d20a8e0a 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -87,11 +87,14 @@ def __init__( return_annotations If not None, a value from the table is returned together with the image tile. Only columns in :attr:`anndata.AnnData.obs` and :attr:`anndata.AnnData.X` - can be returned. If None, it will return a spatialdata object with only the tuple + can be returned. If None, it will return a `SpatialData` object with only the tuple containing the image and the table value. transform - A callable that takes as input the tuple (image, table_value) and returns a new tuple. - This can be used to apply transformations to the image and the table value. 
+ A callable that takes as input the tuple (image, table_value) and returns a new tuple (when + `return_annotations` is not None); a callable that takes as input the `SpatialData` object and + returns a tuple when `return_annotations` is `None`. + This parameter can be used to apply data transformations (for instance a normalization operation) to the + image and the table value. rasterize_kwargs Keyword arguments passed to :func:`spatialdata.rasterize` if `rasterize` is True. This argument can be used for instance to choose the pixel dimension of the image tile. From d24bd7c5972d469e7bb91dfc34e78c343b361a8f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Jul 2023 11:41:13 +0000 Subject: [PATCH 30/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spatialdata/dataloader/datasets.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index d20a8e0a..ff71c8d3 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -90,10 +90,10 @@ def __init__( can be returned. If None, it will return a `SpatialData` object with only the tuple containing the image and the table value. transform - A callable that takes as input the tuple (image, table_value) and returns a new tuple (when - `return_annotations` is not None); a callable that takes as input the `SpatialData` object and - returns a tuple when `return_annotations` is `None`. - This parameter can be used to apply data transformations (for instance a normalization operation) to the + A callable that takes as input the tuple (image, table_value) and returns a new tuple (when + `return_annotations` is not None); a callable that takes as input the `SpatialData` object and + returns a tuple when `return_annotations` is `None`. + This parameter can be used to apply data transformations (for instance a normalization operation) to the image and the table value. rasterize_kwargs Keyword arguments passed to :func:`spatialdata.rasterize` if `rasterize` is True. From f32d472cb9eef15d99fbe7b43c20e0e2d29d1a38 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Jul 2023 19:16:17 +0000 Subject: [PATCH 31/38] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9b26e9ea..a1d8b8c2 100644 --- a/README.md +++ b/README.md @@ -45,9 +45,10 @@ If you found a bug, please use the [issue tracker][issue-tracker]. [L Marconato*, G Palla*, KA Yamauchi*, I Virshup*, E Heidari, T Treis, M Toth, R Shrestha, H Vöhringer, W Huber, M Gerstung, J Moore, FJ Theis, O Stegle, bioRxiv, 2023](https://www.biorxiv.org/content/10.1101/2023.05.05.539647v1). \* = equal contribution ## Sponsor + The spatialdata project is supported by the EMBL International PhD Programme and the Chan Zuckerberg Initiative. -[//]: # (numfocus-fiscal-sponsor-attribution) +[//]: # "numfocus-fiscal-sponsor-attribution" The scverse project uses a [consensus based governance model](https://scverse.org/about/roles/) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/). 
Consider making a [tax-deductible donation](https://numfocus.org/donate-to-scverse) to help the project pay for developer time, professional services, travel, workshops, and a variety of other needs. From 95b5b493d24324a87482da646d76bc399cce8682 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 27 Nov 2023 11:13:49 +0100 Subject: [PATCH 32/38] tryf ixing docs --- src/spatialdata/_core/data_extent.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/_core/data_extent.py b/src/spatialdata/_core/data_extent.py index 251a9e7b..82ff72d6 100644 --- a/src/spatialdata/_core/data_extent.py +++ b/src/spatialdata/_core/data_extent.py @@ -115,9 +115,9 @@ def get_extent( has_labels: bool = True, has_points: bool = True, has_shapes: bool = True, - # python 3.9 tests fail if we don't use Union here, see - # https://github.com/scverse/spatialdata/pull/318#issuecomment-1755714287 - elements: Union[list[str], None] = None, # noqa: UP007 + elements: Union[ # noqa: UP007 # https://github.com/scverse/spatialdata/pull/318#issuecomment-1755714287 + list[str], None + ] = None, ) -> BoundingBoxDescription: """ Get the extent (bounding box) of a SpatialData object or a SpatialElement. @@ -134,7 +134,7 @@ def get_extent( max_coordinate The maximum coordinate of the bounding box. axes - The names of the dimensions of the bounding box + The names of the dimensions of the bounding box. exact If True, the extent is computed exactly. If False, an approximation faster to compute is given. The approximation is guaranteed to contain all the data, see notes for details. From 16977e0852fa8d66ebf0d1e501084eb8d429b3b4 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 27 Nov 2023 11:59:44 +0100 Subject: [PATCH 33/38] fix tests again --- src/spatialdata/_core/data_extent.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/data_extent.py b/src/spatialdata/_core/data_extent.py index 82ff72d6..66c9d1cb 100644 --- a/src/spatialdata/_core/data_extent.py +++ b/src/spatialdata/_core/data_extent.py @@ -259,7 +259,12 @@ def _get_extent_of_shapes(e: GeoDataFrame) -> BoundingBoxDescription: @get_extent.register def _(e: GeoDataFrame, coordinate_system: str = "global", exact: bool = True) -> BoundingBoxDescription: """ - Compute the extent (bounding box) of a set of shapes. + Get the extent (bounding box) of a SpatialData object: the extent of the union of the extents of all its elements. + + Parameters + ---------- + e + The SpatialData object. 
Returns ------- From 4923ee6417cbedeeb2a13d7ebbf7bd9669979616 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 27 Nov 2023 12:02:45 +0100 Subject: [PATCH 34/38] fix tests --- tests/dataloader/test_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index 3b99b104..ab1319ab 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -60,7 +60,7 @@ def test_default(self, sdata_blobs, regions_element, raster): if raster: assert tile.shape == (3, 329, 329) else: - assert tile.shape == (3, 164, 164) + assert tile.shape == (3, 165, 164) else: raise ValueError(f"Unexpected regions_element: {regions_element}") # extent has units in pixel so should be the same as tile shape From 63066ea76e753c0b24d394a81864c49a01c42cce Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 27 Nov 2023 15:29:50 +0100 Subject: [PATCH 35/38] update --- src/spatialdata/_core/data_extent.py | 45 +++++++++++++++------------- tests/dataloader/test_datasets.py | 2 +- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/spatialdata/_core/data_extent.py b/src/spatialdata/_core/data_extent.py index 66c9d1cb..b6fa0b5f 100644 --- a/src/spatialdata/_core/data_extent.py +++ b/src/spatialdata/_core/data_extent.py @@ -136,36 +136,37 @@ def get_extent( axes The names of the dimensions of the bounding box. exact - If True, the extent is computed exactly. If False, an approximation faster to compute is given. The - approximation is guaranteed to contain all the data, see notes for details. + If `True`, the extent is computed exactly. If `False`, an approximation faster to compute is given. + The approximation is guaranteed to contain all the data, see notes for details. has_images - If True, images are included in the computation of the extent. + If `True`, images are included in the computation of the extent. has_labels - If True, labels are included in the computation of the extent. + If `True`, labels are included in the computation of the extent. has_points - If True, points are included in the computation of the extent. + If `True`, points are included in the computation of the extent. has_shapes - If True, shapes are included in the computation of the extent. + If `True`, shapes are included in the computation of the extent. elements - If not None, only the elements with the given names are included in the computation of the extent. + If not `None`, only the elements with the given names are included in the computation of the extent. Notes ----- - The extent of a SpatialData object is the extent of the union of the extents of all its elements. The extent of a - SpatialElement is the extent of the element in the coordinate system specified by the argument `coordinate_system`. + The extent of a `SpatialData` object is the extent of the union of the extents of all its elements. + The extent of a `SpatialElement` is the extent of the element in the coordinate system + specified by the argument `coordinate_system`. - If `exact` is False, first the extent of the SpatialElement before any transformation is computed. Then, the extent - is transformed to the target coordinate system. This is faster than computing the extent after the transformation, - since the transformation is applied to extent of the untransformed data, as opposed to transforming the data and - then computing the extent. + If `exact` is `False`, first the extent of the `SpatialElement` before any transformation is computed. 
+ Then, the extent is transformed to the target coordinate system. This is faster than computing the extent + after the transformation, since the transformation is applied to extent of the untransformed data, + as opposed to transforming the data and then computing the extent. - The exact and approximate extent are the same if the transformation doesn't contain any rotation or shear, or in the - case in which the transformation is affine but all the corners of the extent of the untransformed data + The exact and approximate extent are the same if the transformation does not contain any rotation or shear, or in + the case in which the transformation is affine but all the corners of the extent of the untransformed data (bounding box corners) are part of the dataset itself. Note that this is always the case for raster data. - An extreme case is a dataset composed of the two points (0, 0) and (1, 1), rotated anticlockwise by 45 degrees. The - exact extent is the bounding box [minx, miny, maxx, maxy] = [0, 0, 0, 1.414], while the approximate extent is the - box [minx, miny, maxx, maxy] = [-0.707, 0, 0.707, 1.414]. + An extreme case is a dataset composed of the two points `(0, 0)` and `(1, 1)`, rotated anticlockwise by 45 degrees. + The exact extent is the bounding box `[minx, miny, maxx, maxy] = [0, 0, 0, 1.414]`, while the approximate extent is + the box `[minx, miny, maxx, maxy] = [-0.707, 0, 0.707, 1.414]`. """ raise ValueError("The object type is not supported.") @@ -184,7 +185,9 @@ def _( elements: Union[list[str], None] = None, # noqa: UP007 ) -> BoundingBoxDescription: """ - Get the extent (bounding box) of a SpatialData object: the extent of the union of the extents of all its elements. + Get the extent (bounding box) of a SpatialData object. + + The resulting extent is the union of the extents of all its elements. Parameters ---------- @@ -259,7 +262,9 @@ def _get_extent_of_shapes(e: GeoDataFrame) -> BoundingBoxDescription: @get_extent.register def _(e: GeoDataFrame, coordinate_system: str = "global", exact: bool = True) -> BoundingBoxDescription: """ - Get the extent (bounding box) of a SpatialData object: the extent of the union of the extents of all its elements. + Get the extent (bounding box) of a SpatialData object. + + The resulting extent is the union of the extents of all its elements. Parameters ---------- diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index ab1319ab..3b99b104 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -60,7 +60,7 @@ def test_default(self, sdata_blobs, regions_element, raster): if raster: assert tile.shape == (3, 329, 329) else: - assert tile.shape == (3, 165, 164) + assert tile.shape == (3, 164, 164) else: raise ValueError(f"Unexpected regions_element: {regions_element}") # extent has units in pixel so should be the same as tile shape From 000f83b3e65abd3c99d4f9d3a65651451acf20b5 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 27 Nov 2023 15:46:52 +0100 Subject: [PATCH 36/38] fix docs --- src/spatialdata/_core/data_extent.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/data_extent.py b/src/spatialdata/_core/data_extent.py index b6fa0b5f..3947fe5f 100644 --- a/src/spatialdata/_core/data_extent.py +++ b/src/spatialdata/_core/data_extent.py @@ -129,6 +129,8 @@ def get_extent( Returns ------- + The bounding box description. + min_coordinate The minimum coordinate of the bounding box. 
max_coordinate @@ -136,7 +138,11 @@ def get_extent( axes The names of the dimensions of the bounding box. exact - If `True`, the extent is computed exactly. If `False`, an approximation faster to compute is given. + Whether the extent is computed exactly or not. + + - If `True`, the extent is computed exactly. + - If `False`, an approximation faster to compute is given. + The approximation is guaranteed to contain all the data, see notes for details. has_images If `True`, images are included in the computation of the extent. From fa539f5a0fa0d80acba2027dbd6026407b9a7639 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 27 Nov 2023 15:57:00 +0100 Subject: [PATCH 37/38] update --- docs/_templates/autosummary/class.rst | 8 -- docs/api.md | 8 +- src/spatialdata/dataloader/datasets.py | 109 +++++++++++++------------ 3 files changed, 60 insertions(+), 65 deletions(-) diff --git a/docs/_templates/autosummary/class.rst b/docs/_templates/autosummary/class.rst index d4668a41..e4665dfc 100644 --- a/docs/_templates/autosummary/class.rst +++ b/docs/_templates/autosummary/class.rst @@ -12,11 +12,8 @@ Attributes table ~~~~~~~~~~~~~~~~~~ .. autosummary:: - {% for item in attributes %} - ~{{ fullname }}.{{ item }} - {%- endfor %} {% endif %} {% endblock %} @@ -27,13 +24,10 @@ Methods table ~~~~~~~~~~~~~ .. autosummary:: - {% for item in methods %} - {%- if item != '__init__' %} ~{{ fullname }}.{{ item }} {%- endif -%} - {%- endfor %} {% endif %} {% endblock %} @@ -46,7 +40,6 @@ Attributes {% for item in attributes %} .. autoattribute:: {{ [objname, item] | join(".") }} - {%- endfor %} {% endif %} @@ -61,7 +54,6 @@ Methods {%- if item != '__init__' %} .. automethod:: {{ [objname, item] | join(".") }} - {%- endif -%} {%- endfor %} diff --git a/docs/api.md b/docs/api.md index 5eea18ce..9034b0d9 100644 --- a/docs/api.md +++ b/docs/api.md @@ -33,7 +33,7 @@ Operations on `SpatialData` objects. aggregate ``` -### Utilities +### Operations Utilities ```{eval-rst} .. autosummary:: @@ -61,7 +61,7 @@ The elements (building-blocks) that consitute `SpatialData`. TableModel ``` -### Utilities +### Models Utilities ```{eval-rst} .. currentmodule:: spatialdata.models @@ -96,7 +96,7 @@ The transformations that can be defined between elements and coordinate systems Sequence ``` -### Utilities +### Transformations Utilities ```{eval-rst} .. currentmodule:: spatialdata.transformations @@ -123,7 +123,7 @@ The transformations that can be defined between elements and coordinate systems ImageTilesDataset ``` -## Input/output +## Input/Output ```{eval-rst} .. currentmodule:: spatialdata diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index ff71c8d3..52bf04e6 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -33,6 +33,62 @@ class ImageTilesDataset(Dataset): + """ + Dataloader for SpatialData. + + :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData` object. + + By default, the dataset returns spatialdata object, but when `return_image` and `return_annot` + are set, the dataset returns a tuple containing: + + - the tile image, centered in the target coordinate system of the region. + - a vector or scalar value from the table. + + Parameters + ---------- + sdata + The SpatialData object. + regions_to_images + A mapping between region and images. The regions are used to compute the tile centers, while the images are + used to get the pixel values. 
+    regions_to_coordinate_systems
+        A mapping between regions and coordinate systems. The coordinate systems are used to transform both
+        region coordinates for tiles as well as images.
+    tile_scale
+        The scale of the tiles. This is used only if the `regions` are `shapes`.
+        It is a scaling factor applied to either the radius (spots) or length (polygons) of the `shapes`
+        according to the geometry type of the `shapes` element:
+
+        - if `shapes` are circles (spots), the radius is scaled by `tile_scale`.
+        - if `shapes` are polygons, the length of the polygon is scaled by `tile_scale`.
+
+        If `tile_dim_in_units` is passed, `tile_scale` is ignored.
+    tile_dim_in_units
+        The dimension of the requested tile in the units of the target coordinate system.
+        This specifies the extent of the tile. This is not related to the size in pixels of each returned tile.
+    rasterize
+        If True, the images are rasterized using :func:`spatialdata.rasterize`.
+        If False, they are queried using :func:`spatialdata.bounding_box_query`.
+    return_annotations
+        If not None, a value from the table is returned together with the image tile.
+        Only columns in :attr:`anndata.AnnData.obs` and :attr:`anndata.AnnData.X`
+        can be returned. If None, a `SpatialData` object containing the image tile and the
+        corresponding table row is returned.
+    transform
+        A callable that takes as input the tuple (image, table_value) and returns a new tuple (when
+        `return_annotations` is not None); a callable that takes as input the `SpatialData` object and
+        returns a tuple when `return_annotations` is `None`.
+        This parameter can be used to apply data transformations (for instance a normalization operation) to the
+        image and the table value.
+    rasterize_kwargs
+        Keyword arguments passed to :func:`spatialdata.rasterize` if `rasterize` is True.
+        This argument can be used for instance to choose the pixel dimension of the image tile.
+
+    Returns
+    -------
+    :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData`.
+    """
+
     INSTANCE_KEY = "instance_id"
     CS_KEY = "cs"
     REGION_KEY = "region"
@@ -50,59 +106,6 @@ def __init__(
         transform: Callable[[Any], Any] | None = None,
         rasterize_kwargs: Mapping[str, Any] = MappingProxyType({}),
     ):
-        """
-        :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData` object.
-
-        By default, the dataset returns spatialdata object, but when `return_image` and `return_annot`
-        are set, the dataset may return a tuple containing:
-
-        - the tile image, centered in the target coordinate system of the region.
-        - a vector or scalar value from the table.
-
-        Parameters
-        ----------
-        sdata
-            The :class`spatialdata.SpatialData` object.
-        regions_to_images
-            A mapping between region and images. The regions are used to compute the tile centers, while the images are
-            used to get the pixel values.
-        regions_to_coordinate_systems
-            A mapping between regions and coordinate systems. The coordinate systems are used to transform both
-            regions coordinates for tiles as well as images.
-        tile_scale
-            The scale of the tiles. This is used only if the `regions` are `shapes`.
-            It is a scaling factor applied to either the radius (spots) or length (polygons) of the `shapes`
-            according to the geometry type of the `shapes` element:
-
-            - if `shapes` are circles (spots), the radius is scaled by `tile_scale`.
-            - if `shapes` are polygons, the length of the polygon is scaled by `tile_scale`.
-
-            If `tile_dim_in_units` is passed, `tile_scale` is ignored.
- tile_dim_in_units - The dimension of the requested tile in the units of the target coordinate system. - This specifies the extent of the tile. This is not related the size in pixel of each returned tile. - rasterize - If True, the images are rasterized using :func:`spatialdata.rasterize`. - If False, they are queried using :func:`spatialdata.bounding_box_query`. - return_annotations - If not None, a value from the table is returned together with the image tile. - Only columns in :attr:`anndata.AnnData.obs` and :attr:`anndata.AnnData.X` - can be returned. If None, it will return a `SpatialData` object with only the tuple - containing the image and the table value. - transform - A callable that takes as input the tuple (image, table_value) and returns a new tuple (when - `return_annotations` is not None); a callable that takes as input the `SpatialData` object and - returns a tuple when `return_annotations` is `None`. - This parameter can be used to apply data transformations (for instance a normalization operation) to the - image and the table value. - rasterize_kwargs - Keyword arguments passed to :func:`spatialdata.rasterize` if `rasterize` is True. - This argument can be used for instance to choose the pixel dimension of the image tile. - - Returns - ------- - :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData`. - """ from spatialdata import bounding_box_query from spatialdata._core.operations.rasterize import rasterize as rasterize_fn From 63e761df077b998a60e44b565696a9b9b79a46d0 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 27 Nov 2023 16:27:19 +0100 Subject: [PATCH 38/38] fix tests and docs and remove --- src/spatialdata/dataloader/datasets.py | 4 ++-- src/spatialdata/models/models.py | 8 ++++---- tests/dataloader/test_datasets.py | 15 +++++++++++---- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 52bf04e6..388db612 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -375,9 +375,9 @@ def _get_tile_coords( # get extent, first by checking shape defaults, then by using the `tile_dim_in_units` if tile_dim_in_units is None: - if elem.iloc[0][0].geom_type == "Point": + if elem.iloc[0, 0].geom_type == "Point": extent = elem[ShapesModel.RADIUS_KEY].values * tile_scale - elif elem.iloc[0][0].geom_type in ["Polygon", "MultiPolygon"]: + elif elem.iloc[0, 0].geom_type in ["Polygon", "MultiPolygon"]: extent = elem[ShapesModel.GEOMETRY_KEY].length * tile_scale else: raise ValueError("Only point and polygon shapes are supported.") diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index f7155a3b..e27a08c3 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -18,7 +18,7 @@ from multiscale_spatial_image import to_multiscale from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage from multiscale_spatial_image.to_multiscale.to_multiscale import Methods -from pandas.api.types import is_categorical_dtype +from pandas import CategoricalDtype from shapely._geometry import GeometryType from shapely.geometry import MultiPolygon, Point, Polygon from shapely.geometry.collection import GeometryCollection @@ -470,7 +470,7 @@ def validate(cls, data: DaskDataFrame) -> None: raise ValueError(f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`.") if cls.ATTRS_KEY in data.attrs and "feature_key" in 
data.attrs[cls.ATTRS_KEY]:
             feature_key = data.attrs[cls.ATTRS_KEY][cls.FEATURE_KEY]
-            if not is_categorical_dtype(data[feature_key]):
+            if not isinstance(data[feature_key].dtype, CategoricalDtype):
                 logger.info(f"Feature key `{feature_key}`could be of type `pd.Categorical`. Consider casting it.")

     @singledispatchmethod
@@ -624,7 +624,7 @@ def _add_metadata_and_validate(
             # Here we are explicitly importing the categories
             # but it is a convenient way to ensure that the categories are known.
             # It also just changes the state of the series, so it is not a big deal.
-            if is_categorical_dtype(data[c]) and not data[c].cat.known:
+            if isinstance(data[c].dtype, CategoricalDtype) and not data[c].cat.known:
                 try:
                     data[c] = data[c].cat.set_categories(data[c].head(1).cat.categories)
                 except ValueError:
@@ -729,7 +729,7 @@ def parse(
             region_: list[str] = region if isinstance(region, list) else [region]
             if not adata.obs[region_key].isin(region_).all():
                 raise ValueError(f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values.")
-            if not is_categorical_dtype(adata.obs[region_key]):
+            if not isinstance(adata.obs[region_key].dtype, CategoricalDtype):
                 warnings.warn(
                     f"Converting `{cls.REGION_KEY_KEY}: {region_key}` to categorical dtype.", UserWarning, stacklevel=2
                 )
diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py
index 3b99b104..dac01e80 100644
--- a/tests/dataloader/test_datasets.py
+++ b/tests/dataloader/test_datasets.py
@@ -60,14 +60,18 @@ def test_default(self, sdata_blobs, regions_element, raster):
             if raster:
                 assert tile.shape == (3, 329, 329)
             else:
-                assert tile.shape == (3, 164, 164)
+                assert tile.shape == (3, 165, 164)
         else:
             raise ValueError(f"Unexpected regions_element: {regions_element}")
+
         # extent has units in pixel so should be the same as tile shape
         if raster:
             assert round(ds.tiles_coords.extent.unique()[0] * 2) == tile.shape[1]
         else:
-            assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1]
+            if regions_element != "blobs_multipolygons":
+                assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1]
+            else:
+                assert int(ds.tiles_coords.extent.unique()[0]) + 1 == tile.shape[1]

         assert np.all(sdata_tile.table.obs.columns == ds.sdata.table.obs.columns)
         assert list(sdata_tile.images.keys())[0] == "blobs_image"
@@ -88,11 +92,14 @@ def test_return_annot(self, sdata_blobs, regions_element, return_annot):
         elif regions_element == "blobs_polygons":
             assert tile.shape == (3, 82, 82)
         elif regions_element == "blobs_multipolygons":
-            assert tile.shape == (3, 164, 164)
+            assert tile.shape == (3, 165, 164)
         else:
             raise ValueError(f"Unexpected regions_element: {regions_element}")
         # extent has units in pixel so should be the same as tile shape
-        assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1]
+        if regions_element != "blobs_multipolygons":
+            assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1]
+        else:
+            assert round(ds.tiles_coords.extent.unique()[0]) + 1 == tile.shape[1]

         return_annot = [return_annot] if isinstance(return_annot, str) else return_annot
         assert annot.shape[1] == len(return_annot)
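
For orientation, the `ImageTilesDataset` API documented and tested in the patches above can be exercised as in the following sketch. It is illustrative and not part of the patch series: it assumes a `SpatialData` object `sdata` containing an image `"blobs_image"`, a shapes element `"blobs_circles"` registered in the `"global"` coordinate system, and a table annotating `"blobs_circles"` (the setup used by the test fixture); the `"instance_id"` column and the `target_unit_to_pixels` rasterization keyword are likewise assumptions chosen only for illustration.

```python
# Illustrative sketch only; `sdata`, the element names, and the column/keyword
# choices below are assumptions, not something prescribed by the patches above.
from torch.utils.data import DataLoader

from spatialdata.dataloader.datasets import ImageTilesDataset

# Default behaviour: each item is a SpatialData object with one image tile and
# the table row of the corresponding circle.
ds = ImageTilesDataset(
    sdata=sdata,
    regions_to_images={"blobs_circles": "blobs_image"},
    regions_to_coordinate_systems={"blobs_circles": "global"},
    tile_scale=1.0,  # for circles, the tile extent is derived from the scaled radius
    rasterize=True,  # crop tiles with spatialdata.rasterize instead of bounding_box_query
    rasterize_kwargs={"target_unit_to_pixels": 2},  # assumed kwarg controlling tile pixel resolution
)
sdata_tile = ds[0]

# With `return_annotations`, each item is a (tile, annotation) tuple that can be
# batched directly by a torch DataLoader.
ds_annot = ImageTilesDataset(
    sdata=sdata,
    regions_to_images={"blobs_circles": "blobs_image"},
    regions_to_coordinate_systems={"blobs_circles": "global"},
    return_annotations="instance_id",  # illustrative obs column; any obs/X column works
)
loader = DataLoader(ds_annot, batch_size=32, shuffle=True)
```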