diff --git a/src/geome/transforms/subset.py b/src/geome/transforms/subset.py
index 14f99be..0e1b6ac 100644
--- a/src/geome/transforms/subset.py
+++ b/src/geome/transforms/subset.py
@@ -1,58 +1,77 @@
+from __future__ import annotations
+
 from typing import Literal
 
+import numpy as np
 from anndata import AnnData
 
 from .base.transform import Transform
 
 
 class Subset(Transform):
-    """Create a subset of adata on list of observation or features.
+    """
+    Create a subset of an AnnData object based on specified observation or feature values.
+
+    Parameters
+    ----------
+    key_value : dict
+        A dictionary where keys are column names of ``adata.obs`` (or ``adata.var``) and values are lists
+        of values to keep. An entry is kept only if it matches every key.
+    axis : Literal[0, 1, "obs", "var"], optional
+        The axis to subset on. It can be 0 or "obs" for observations and 1 or "var" for features.
+        Default is "obs".
+    copy : bool, optional
+        If True, return a copy of the subsetted AnnData object instead of a view. Default is False.
+
+    Raises
+    ------
+    TypeError
+        If ``axis`` is not one of 0, 1, "obs" or "var".
+
+    Example:
+    import anndata as ad
+    import numpy as np
+    from geome.transforms import Subset
+
+    # Create test data
+    obs_data = {"cell_type": ["B cell", "T cell", "B cell", "T cell"]}
+    var_data = {"gene": ["gene1", "gene2", "gene3", "gene4"]}
+    adata = ad.AnnData(X=np.random.rand(4, 4), obs=obs_data, var=var_data)
 
-    Input:
-    key_value : dict{str(obs_name): list[str(Value1), str(Value2),...],...}
-    The dictionary with observation columns and list of values to keep for each observation.
+    # Subset by observations
+    adata_subset_obs = Subset(key_value={"cell_type": ["B cell"]}, axis="obs")(adata)
+    print(adata_subset_obs)
+    # View of AnnData object with n_obs x n_vars = 2 x 4
+    # obs: 'cell_type'
+    # var: 'gene'
     """
 
-    def __init__(self, key_value: dict, axis: Literal[0, 1, "obs", "var"] = "obs"):
+    def __init__(self, key_value: dict, axis: Literal[0, 1, "obs", "var"] = "obs", copy: bool = False):
         self.key_value = key_value
         if axis not in (0, 1, "obs", "var"):
             raise TypeError("axis needs to be one of obs, var, 0 or 1")
-        if isinstance(axis, int):
-            axis = ("obs", "var")[axis]
+        if not isinstance(axis, str):
+            # Also normalise int-like aliases (e.g. NumPy integers) that pass the check above.
+            axis = ("obs", "var")[int(axis)]
         self.axis = axis
+        self.copy = copy
 
     def __call__(self, adata: AnnData):
-        """Converts the given list of observation columns in the AnnData object to categorical.
-
-        Args:
-        ----
-        adata: The AnnData object.
-        obs_list (list[str]): The list of observation columns to convert to categorical.
-        """
+        """Subset the AnnData object based on the provided key_value and axis."""
+        subset_mask = self._generate_subset_mask(adata, self.axis)
         if self.axis == "obs":
-            subset_mask = self._generate_obs_subset_mask(adata)
             sub_adata = adata[subset_mask]
-        elif self.axis == "var":
-            subset_mask = self._generate_var_subset_mask(adata)
+        else:
+            # ``__init__`` normalises the axis, so anything other than "obs" is "var".
             sub_adata = adata[:, subset_mask]
-        return sub_adata
-
-    def _generate_obs_subset_mask(self, adata: AnnData):
-        """Generate a boolean mask for selecting observations."""
-        subset_mask = None
-        for key, values in self.key_value.items():
-            if subset_mask is None:
-                subset_mask = adata.obs[key].isin(values)
-            else:
-                subset_mask &= adata.obs[key].isin(values)
-        return subset_mask
+        return sub_adata.copy() if self.copy else sub_adata
 
-    def _generate_var_subset_mask(self, adata: AnnData):
-        """Generate a boolean mask for selecting variables."""
-        subset_mask = None
+    def _generate_subset_mask(self, adata: AnnData, axis: Literal["obs", "var"]):
+        """Generate a boolean mask for selecting observations or variables based on the specified axis."""
+        data_attr = adata.obs if axis == "obs" else adata.var
+        # Start from all-True so that an empty ``key_value`` selects everything.
+        subset_mask = np.ones(data_attr.shape[0], dtype=bool)
         for key, values in self.key_value.items():
-            if subset_mask is None:
-                subset_mask = adata.var[key].isin(values)
-            else:
-                subset_mask &= adata.var[key].isin(values)
+            # Use the raw boolean array so the in-place AND stays a pure NumPy operation.
+            subset_mask &= data_attr[key].isin(values).to_numpy()
         return subset_mask
diff --git a/tests/transforms/test_subset.py b/tests/transforms/test_subset.py
new file mode 100644
index 0000000..26da15a
--- /dev/null
+++ b/tests/transforms/test_subset.py
@@ -0,0 +1,82 @@
+import anndata as ad
+import numpy as np
+import pytest
+
+from geome.transforms import Subset
+
+
+# copy value true or false as a fixture
+@pytest.fixture(params=[True, False])
+def copy(request):
+    return request.param
+
+
+@pytest.fixture
+def adata():
+    # Create test data for obs and var
+    obs_data = {
+        "cell_type": ["B cell", "T cell", "B cell", "T cell"],
+        "condition": ["healthy", "disease", "disease", "healthy"],
+    }
+    var_data = {"gene": ["gene1", "gene2", "gene3", "gene4"], "chromosome": ["chr1", "chr2", "chr1", "chr2"]}
+    adata = ad.AnnData(X=np.random.rand(4, 4), obs=obs_data, var=var_data)
+    return adata
+
+
+@pytest.mark.parametrize(
+    "axis,key_value,expected_shape,expected_values",
+    [
+        ("obs", {"cell_type": ["B cell"]}, (2, 4), {"cell_type": "B cell"}),
+        ("obs", {"condition": ["healthy"]}, (2, 4), {"condition": "healthy"}),
+        ("var", {"chromosome": ["chr1"]}, (4, 2), {"chromosome": "chr1"}),
+        ("var", {"gene": ["gene1", "gene3"]}, (4, 2), {"gene": ["gene1", "gene3"]}),
+    ],
+)
+def test_subset(adata, axis, key_value, expected_shape, expected_values, copy):
+    # Apply the Subset transformation
+    subset_transform = Subset(key_value, axis=axis, copy=copy)
+    adata_subset = subset_transform(adata)
+
+    assert adata_subset.is_view != copy, f"Expected is_view to be {not copy} when copy={copy}."
+
+    # Assert the expected shape
+    assert adata_subset.shape == expected_shape
+
+    # Check the integrity of the subset data
+    for key, expected_value in expected_values.items():
+        if axis == "var":
+            assert all(adata_subset.var[key] == expected_value)
+        else:
+            assert all(adata_subset.obs[key] == expected_value)
+
+    # Check how X is changed
+    if axis == "obs":
+        mask = adata.obs[list(key_value.keys())[0]].isin(list(key_value.values())[0])
+        assert np.all(adata_subset.X == adata.X[mask])
+    else:
+        mask = adata.var[list(key_value.keys())[0]].isin(list(key_value.values())[0])
+        assert np.all(adata_subset.X == adata.X[:, mask])
+
+
+def test_subset_multiple_keys(adata):
+    # Apply the Subset transformation with multiple keys
+    key_value = {"cell_type": ["B cell"], "condition": ["healthy"]}
+    subset_transform = Subset(key_value, axis="obs")
+    adata_subset = subset_transform(adata)
+
+    # Assert the expected shape
+    assert adata_subset.shape == (1, 4)
+
+    # Check the integrity of the subset data
+    assert all(adata_subset.obs["cell_type"] == "B cell")
+    assert all(adata_subset.obs["condition"] == "healthy")
+
+    # Check how X is changed
+    mask = adata.obs["cell_type"].isin(key_value["cell_type"]) & adata.obs["condition"].isin(key_value["condition"])
+    assert np.all(adata_subset.X == adata.X[mask])
+
+
+def test_subset_integer_axis(adata):
+    # Integer aliases, including NumPy integers, are normalised to axis names
+    assert Subset({"cell_type": ["B cell"]}, axis=0)(adata).shape == (2, 4)
+    assert Subset({"chromosome": ["chr1"]}, axis=np.int64(1))(adata).shape == (4, 2)