Skip to content

Commit

Permalink
refactor and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed May 16, 2024
1 parent 848667f commit 30d0a84
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 32 deletions.
73 changes: 41 additions & 32 deletions src/geome/transforms/subset.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,67 @@
from __future__ import annotations

from typing import Literal

import numpy as np
from anndata import AnnData

from .base.transform import Transform


class Subset(Transform):
"""Create a subset of adata on list of observation or features.
"""
Create a subset of an AnnData object based on specified observation or feature values.
Parameters
----------
key_value : dict
A dictionary where keys are observation or feature names, and values are lists of values to keep.
axis : Literal[0, 1, "obs", "var"], optional
The axis to subset on. It can be 0 or "obs" for observations and 1 or "var" for features.
Default is "obs".
copy : bool, optional
If True, return a copy of the subsetted AnnData object instead of a view. Default is False.
Example:
import anndata as ad
import numpy as np
from geome.transforms import Subset
# Create test data
obs_data = {"cell_type": ["B cell", "T cell", "B cell", "T cell"]}
var_data = {"gene": ["gene1", "gene2", "gene3", "gene4"]}
adata = ad.AnnData(X=np.random.rand(4, 4), obs=obs_data, var=var_data)
Input:
key_value : dict{str(obs_name): list[str(Value1), str(Value2),...],...}
The dictionary with observation columns and list of values to keep for each observation.
# Subset by observations
adata_subset_obs = Subset(key_value={"cell_type": ["B cell"]}, axis="obs")(adata)
print(adata_subset_obs)
# View of AnnData object with n_obs x n_vars = 2 x 4
# obs: 'cell_type'
# var: 'gene'
"""

def __init__(self, key_value: dict, axis: Literal[0, 1, "obs", "var"] = "obs"):
def __init__(self, key_value: dict, axis: Literal[0, 1, "obs", "var"] = "obs", copy: bool = False):
self.key_value = key_value
if axis not in (0, 1, "obs", "var"):
raise TypeError("axis needs to be one of obs, var, 0 or 1")
if isinstance(axis, int):
axis = ("obs", "var")[axis]
self.axis = axis
self.copy = copy

def __call__(self, adata: AnnData):
"""Converts the given list of observation columns in the AnnData object to categorical.
Args:
----
adata: The AnnData object.
obs_list (list[str]): The list of observation columns to convert to categorical.
"""
"""Subset the AnnData object based on the provided key_value and axis."""
subset_mask = self._generate_subset_mask(adata, self.axis)
if self.axis == "obs":
subset_mask = self._generate_obs_subset_mask(adata)
sub_adata = adata[subset_mask]
elif self.axis == "var":
subset_mask = self._generate_var_subset_mask(adata)
sub_adata = adata[:, subset_mask]
return sub_adata

def _generate_obs_subset_mask(self, adata: AnnData):
"""Generate a boolean mask for selecting observations."""
subset_mask = None
for key, values in self.key_value.items():
if subset_mask is None:
subset_mask = adata.obs[key].isin(values)
else:
subset_mask &= adata.obs[key].isin(values)
return subset_mask
return sub_adata.copy() if self.copy else sub_adata

def _generate_var_subset_mask(self, adata: AnnData):
"""Generate a boolean mask for selecting variables."""
subset_mask = None
def _generate_subset_mask(self, adata: AnnData, axis: Literal["obs", "var"]):
"""Generate a boolean mask for selecting observations or variables based on the specified axis."""
data_attr = adata.obs if axis == "obs" else adata.var
subset_mask = np.ones(data_attr.shape[0], dtype=bool)
for key, values in self.key_value.items():
if subset_mask is None:
subset_mask = adata.var[key].isin(values)
else:
subset_mask &= adata.var[key].isin(values)
subset_mask &= data_attr[key].isin(values)
return subset_mask
76 changes: 76 additions & 0 deletions tests/transforms/test_subset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import anndata as ad
import numpy as np
import pytest

from geome.transforms import Subset


# copy value true or false as a fixture
@pytest.fixture(params=[True, False])
def copy(request):
return request.param


@pytest.fixture
def adata():
# Create test data for obs and var
obs_data = {
"cell_type": ["B cell", "T cell", "B cell", "T cell"],
"condition": ["healthy", "disease", "disease", "healthy"],
}
var_data = {"gene": ["gene1", "gene2", "gene3", "gene4"], "chromosome": ["chr1", "chr2", "chr1", "chr2"]}
adata = ad.AnnData(X=np.random.rand(4, 4), obs=obs_data, var=var_data)
return adata


@pytest.mark.parametrize(
"axis,key_value,expected_shape,expected_values",
[
("obs", {"cell_type": ["B cell"]}, (2, 4), {"cell_type": "B cell"}),
("obs", {"condition": ["healthy"]}, (2, 4), {"condition": "healthy"}),
("var", {"chromosome": ["chr1"]}, (4, 2), {"chromosome": "chr1"}),
("var", {"gene": ["gene1", "gene3"]}, (4, 2), {"gene": ["gene1", "gene3"]}),
],
)
def test_subset(adata, axis, key_value, expected_shape, expected_values, copy):
# Apply the Subset transformation
subset_transform = Subset(key_value, axis=axis, copy=copy)
adata_subset = subset_transform(adata)

assert adata_subset.is_view != copy, "The subsetted AnnData object is not a view as expected."

# Assert the expected shape
assert adata_subset.shape == expected_shape

# Check the integrity of the subset data
for key, expected_value in expected_values.items():
if axis == "var":
assert all(adata_subset.var[key] == expected_value)
else:
assert all(adata_subset.obs[key] == expected_value)

# Check how X is changed
if axis == "obs":
mask = adata.obs[list(key_value.keys())[0]].isin(list(key_value.values())[0])
assert np.all(adata_subset.X == adata.X[mask])
else:
mask = adata.var[list(key_value.keys())[0]].isin(list(key_value.values())[0])
assert np.all(adata_subset.X == adata.X[:, mask])


def test_subset_multiple_keys(adata):
# Apply the Subset transformation with multiple keys
key_value = {"cell_type": ["B cell"], "condition": ["healthy"]}
subset_transform = Subset(key_value, axis="obs")
adata_subset = subset_transform(adata)

# Assert the expected shape
assert adata_subset.shape == (1, 4)

# Check the integrity of the subset data
assert all(adata_subset.obs["cell_type"] == "B cell")
assert all(adata_subset.obs["condition"] == "healthy")

# Check how X is changed
mask = adata.obs["cell_type"].isin(key_value["cell_type"]) & adata.obs["condition"].isin(key_value["condition"])
assert np.all(adata_subset.X == adata.X[mask])

0 comments on commit 30d0a84

Please sign in to comment.