Skip to content

Commit

Permalink
Merge pull request #48 from theislab/feature/_to_category_iterator
Browse files Browse the repository at this point in the history
Add transform operation: subset data
  • Loading branch information
selmanozleyen authored May 16, 2024
2 parents 77fbb35 + 30d0a84 commit f2a58ed
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/geome/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .base.transform import Transform
from .categorize import Categorize
from .compose import Compose
from .subset import Subset

__all__ = [
"Transform",
Expand All @@ -13,4 +14,5 @@
"Compose",
"AddEdgeIndex",
"AddEdgeIndexFromAdj",
"Subset",
]
67 changes: 67 additions & 0 deletions src/geome/transforms/subset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

from typing import Literal

import numpy as np
from anndata import AnnData

from .base.transform import Transform


class Subset(Transform):
"""
Create a subset of an AnnData object based on specified observation or feature values.
Parameters
----------
key_value : dict
A dictionary where keys are observation or feature names, and values are lists of values to keep.
axis : Literal[0, 1, "obs", "var"], optional
The axis to subset on. It can be 0 or "obs" for observations and 1 or "var" for features.
Default is "obs".
copy : bool, optional
If True, return a copy of the subsetted AnnData object instead of a view. Default is False.
Example:
import anndata as ad
import numpy as np
from geome.transforms import Subset
# Create test data
obs_data = {"cell_type": ["B cell", "T cell", "B cell", "T cell"]}
var_data = {"gene": ["gene1", "gene2", "gene3", "gene4"]}
adata = ad.AnnData(X=np.random.rand(4, 4), obs=obs_data, var=var_data)
# Subset by observations
adata_subset_obs = Subset(key_value={"cell_type": ["B cell"]}, axis="obs")(adata)
print(adata_subset_obs)
# View of AnnData object with n_obs x n_vars = 2 x 4
# obs: 'cell_type'
# var: 'gene'
"""

def __init__(self, key_value: dict, axis: Literal[0, 1, "obs", "var"] = "obs", copy: bool = False):
self.key_value = key_value
if axis not in (0, 1, "obs", "var"):
raise TypeError("axis needs to be one of obs, var, 0 or 1")
if isinstance(axis, int):
axis = ("obs", "var")[axis]
self.axis = axis
self.copy = copy

def __call__(self, adata: AnnData):
"""Subset the AnnData object based on the provided key_value and axis."""
subset_mask = self._generate_subset_mask(adata, self.axis)
if self.axis == "obs":
sub_adata = adata[subset_mask]
elif self.axis == "var":
sub_adata = adata[:, subset_mask]
return sub_adata.copy() if self.copy else sub_adata

def _generate_subset_mask(self, adata: AnnData, axis: Literal["obs", "var"]):
"""Generate a boolean mask for selecting observations or variables based on the specified axis."""
data_attr = adata.obs if axis == "obs" else adata.var
subset_mask = np.ones(data_attr.shape[0], dtype=bool)
for key, values in self.key_value.items():
subset_mask &= data_attr[key].isin(values)
return subset_mask
76 changes: 76 additions & 0 deletions tests/transforms/test_subset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import anndata as ad
import numpy as np
import pytest

from geome.transforms import Subset


# copy value true or false as a fixture
@pytest.fixture(params=[True, False])
def copy(request):
return request.param


@pytest.fixture
def adata():
# Create test data for obs and var
obs_data = {
"cell_type": ["B cell", "T cell", "B cell", "T cell"],
"condition": ["healthy", "disease", "disease", "healthy"],
}
var_data = {"gene": ["gene1", "gene2", "gene3", "gene4"], "chromosome": ["chr1", "chr2", "chr1", "chr2"]}
adata = ad.AnnData(X=np.random.rand(4, 4), obs=obs_data, var=var_data)
return adata


@pytest.mark.parametrize(
"axis,key_value,expected_shape,expected_values",
[
("obs", {"cell_type": ["B cell"]}, (2, 4), {"cell_type": "B cell"}),
("obs", {"condition": ["healthy"]}, (2, 4), {"condition": "healthy"}),
("var", {"chromosome": ["chr1"]}, (4, 2), {"chromosome": "chr1"}),
("var", {"gene": ["gene1", "gene3"]}, (4, 2), {"gene": ["gene1", "gene3"]}),
],
)
def test_subset(adata, axis, key_value, expected_shape, expected_values, copy):
# Apply the Subset transformation
subset_transform = Subset(key_value, axis=axis, copy=copy)
adata_subset = subset_transform(adata)

assert adata_subset.is_view != copy, "The subsetted AnnData object is not a view as expected."

# Assert the expected shape
assert adata_subset.shape == expected_shape

# Check the integrity of the subset data
for key, expected_value in expected_values.items():
if axis == "var":
assert all(adata_subset.var[key] == expected_value)
else:
assert all(adata_subset.obs[key] == expected_value)

# Check how X is changed
if axis == "obs":
mask = adata.obs[list(key_value.keys())[0]].isin(list(key_value.values())[0])
assert np.all(adata_subset.X == adata.X[mask])
else:
mask = adata.var[list(key_value.keys())[0]].isin(list(key_value.values())[0])
assert np.all(adata_subset.X == adata.X[:, mask])


def test_subset_multiple_keys(adata):
# Apply the Subset transformation with multiple keys
key_value = {"cell_type": ["B cell"], "condition": ["healthy"]}
subset_transform = Subset(key_value, axis="obs")
adata_subset = subset_transform(adata)

# Assert the expected shape
assert adata_subset.shape == (1, 4)

# Check the integrity of the subset data
assert all(adata_subset.obs["cell_type"] == "B cell")
assert all(adata_subset.obs["condition"] == "healthy")

# Check how X is changed
mask = adata.obs["cell_type"].isin(key_value["cell_type"]) & adata.obs["condition"].isin(key_value["condition"])
assert np.all(adata_subset.X == adata.X[mask])

0 comments on commit f2a58ed

Please sign in to comment.