Skip to content

Commit

Permalink
added some tests and refactored a line
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed Mar 13, 2024
1 parent 6ea9a16 commit 63de968
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 24 deletions.
14 changes: 14 additions & 0 deletions tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ def _make_grid(grid_size: int) -> ArrayLike:
return np.vstack([X1.ravel(), X2.ravel()]).T


def _assert_marginals_set(adata_time, problem, key, marginal_keys):
"""Helper function to check if marginals are set correctly"""
adata_time0 = adata_time[key[0] == adata_time.obs["time"]]
adata_time1 = adata_time[key[1] == adata_time.obs["time"]]
if marginal_keys[0] is not None: # check if marginal keys are set
a = adata_time0.obs[marginal_keys[0]].values
b = adata_time1.obs[marginal_keys[1]].values
assert np.allclose(problem[key].a, a)
assert np.allclose(problem[key].b, b)
else: # otherwise check if marginals are uniform
assert np.allclose(problem[key].a, 1.0 / adata_time0.shape[0])
assert np.allclose(problem[key].b, 1.0 / adata_time1.shape[0])


class Problem(CompoundProblem[Any, OTProblem]):
@property
def _base_problem_type(self) -> Type[B]:
Expand Down
36 changes: 20 additions & 16 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,32 @@ def adata_y(y: Geom_t) -> AnnData:
return AnnData(X=np.asarray(y, dtype=float), obsm={"X_pca": pc})


def creat_prob(n: int, *, uniform: bool = False, seed: Optional[int] = None) -> Geom_t:
rng = np.random.RandomState(seed)
a = np.ones((n,)) if uniform else np.abs(rng.normal(size=(n,)))
a /= np.sum(a)
return jnp.asarray(a)


@pytest.fixture()
def adata_time() -> AnnData:
rng = np.random.RandomState(42)
adatas = [AnnData(X=csr_matrix(rng.normal(size=(96, 60)))) for _ in range(3)]

adatas = [
AnnData(
X=csr_matrix(rng.normal(size=(96, 60))),
obs={
"left_marginals_balanced": creat_prob(96, seed=42),
"right_marginals_balanced": creat_prob(96, seed=42),
},
)
for _ in range(3)
]
adata = ad.concat(adatas, label="time", index_unique="-")
adata.obs["time"] = pd.to_numeric(adata.obs["time"]).astype("category")
adata.obs["batch"] = rng.choice((0, 1, 2), len(adata))
adata.obs["left_marginals"] = np.ones(len(adata))
adata.obs["right_marginals"] = np.ones(len(adata))
adata.obs["left_marginals_unbalanced"] = np.ones(len(adata))
adata.obs["right_marginals_unbalanced"] = np.ones(len(adata))
adata.obs["celltype"] = rng.choice(["A", "B", "C"], size=len(adata))
# genes from mouse/human proliferation/apoptosis
genes = ["ANLN", "ANP32E", "ATAD2", "Mcm4", "Smc4", "Gtse1", "ADD1", "AIFM3", "ANKH", "Ercc5", "Serpinb5", "Inhbb"]
Expand All @@ -139,19 +156,6 @@ def adata_time() -> AnnData:
return adata


def create_marginals(n: int, m: int, *, uniform: bool = False, seed: Optional[int] = None) -> Geom_t:
rng = np.random.RandomState(seed)
if uniform:
a, b = np.ones((n,)), np.ones((m,))
else:
a = np.abs(rng.normal(size=(n,)))
b = np.abs(rng.normal(size=(m,)))
a /= np.sum(a)
b /= np.sum(b)

return jnp.asarray(a), jnp.asarray(b)


@pytest.fixture()
def gt_temporal_adata() -> AnnData:
adata = _gt_temporal_adata.copy()
Expand Down
12 changes: 12 additions & 0 deletions tests/problems/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ def adata_time_with_tmap(adata_time: AnnData) -> AnnData:
return adata


# keys for marginals
@pytest.fixture(
params=[
(None, None),
("left_marginals_balanced", "right_marginals_balanced"),
],
ids=["default", "balanced"],
)
def marginal_keys(request):
return request.param


sinkhorn_args_1 = {
"epsilon": 0.7,
"tau_a": 1.0,
Expand Down
16 changes: 16 additions & 0 deletions tests/problems/generic/test_fgw_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from moscot.base.output import BaseSolverOutput
from moscot.base.problems import OTProblem
from moscot.problems.generic import FGWProblem
from tests._utils import _assert_marginals_set
from tests.problems.conftest import (
fgw_args_1,
fgw_args_2,
Expand Down Expand Up @@ -67,6 +68,21 @@ def test_prepare(self, adata_space_rotate: AnnData, policy):
assert key in expected_keys[policy]
assert isinstance(problem[key], OTProblem)

@pytest.mark.fast()
def test_prepare_marginals(self, adata_time: AnnData, marginal_keys):
problem = FGWProblem(adata=adata_time)
problem = problem.prepare(
key="time",
policy="sequential",
joint_attr="X_pca",
x_attr="X_pca",
y_attr="X_pca",
a=marginal_keys[0],
b=marginal_keys[1],
)
for key in problem:
_assert_marginals_set(adata_time, problem, key, marginal_keys)

def test_solve_balanced(self, adata_space_rotate: AnnData):
eps = 0.5
adata_space_rotate = adata_space_rotate[adata_space_rotate.obs["batch"].isin(("0", "1"))].copy()
Expand Down
15 changes: 14 additions & 1 deletion tests/problems/generic/test_gw_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from moscot.base.output import BaseSolverOutput
from moscot.base.problems import OTProblem
from moscot.problems.generic import GWProblem
from tests._utils import _assert_marginals_set
from tests.problems.conftest import (
geometry_args,
gw_args_1,
Expand All @@ -37,7 +38,10 @@

class TestGWProblem:
@pytest.mark.fast()
@pytest.mark.parametrize("policy", ["sequential", "star"])
@pytest.mark.parametrize(
"policy",
["sequential", "star"],
)
def test_prepare(self, adata_space_rotate: AnnData, policy):
expected_keys = {
"sequential": [("0", "1"), ("1", "2")],
Expand Down Expand Up @@ -181,6 +185,15 @@ def test_prepare_costs(self, adata_time: AnnData, cost_str: str, cost_inst: Any,

problem = problem.solve(max_iterations=2)

@pytest.mark.fast()
def test_prepare_marginals(self, adata_time: AnnData, marginal_keys):
problem = GWProblem(adata=adata_time)
problem = problem.prepare(
a=marginal_keys[0], b=marginal_keys[1], key="time", policy="sequential", x_attr="X_pca", y_attr="X_pca"
)
for key in problem:
_assert_marginals_set(adata_time, problem, key, marginal_keys)

@pytest.mark.fast()
@pytest.mark.parametrize(
("cost_str", "cost_inst", "cost_kwargs"),
Expand Down
15 changes: 10 additions & 5 deletions tests/problems/generic/test_sinkhorn_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from moscot.base.output import BaseSolverOutput
from moscot.base.problems import OTProblem
from moscot.problems.generic import SinkhornProblem
from tests._utils import _assert_marginals_set
from tests.problems.conftest import (
geometry_args,
lin_prob_args,
Expand All @@ -37,15 +38,14 @@
class TestSinkhornProblem:
@pytest.mark.fast()
@pytest.mark.parametrize("policy", ["sequential", "star"])
def test_prepare(self, adata_time: AnnData, policy):
def test_prepare(self, adata_time: AnnData, policy, marginal_keys):
expected_keys = {"sequential": [(0, 1), (1, 2)], "star": [(1, 0), (2, 0)]}
problem = SinkhornProblem(adata=adata_time)

assert len(problem) == 0
assert problem.problems == {}
assert problem.solutions == {}

problem = problem.prepare(key="time", policy=policy, reference=0)
problem = problem.prepare(key="time", policy=policy, reference=0, a=marginal_keys[0], b=marginal_keys[1])

assert isinstance(problem.problems, dict)
assert len(problem.problems) == len(expected_keys[policy])
Expand All @@ -54,16 +54,21 @@ def test_prepare(self, adata_time: AnnData, policy):
assert key in expected_keys[policy]
assert isinstance(problem[key], OTProblem)

def test_solve_balanced(self, adata_time: AnnData):
_assert_marginals_set(adata_time, problem, key, marginal_keys)

def test_solve_balanced(self, adata_time: AnnData, marginal_keys):
eps = 0.5
expected_keys = [(0, 1), (1, 2)]
problem = SinkhornProblem(adata=adata_time)
problem = problem.prepare(key="time")
problem = problem.prepare(key="time", a=marginal_keys[0], b=marginal_keys[1])
problem = problem.solve(epsilon=eps)

for key, subsol in problem.solutions.items():
assert isinstance(subsol, BaseSolverOutput)
assert key in expected_keys
assert subsol.converged
assert np.allclose(subsol.a, problem[key].a)
assert np.allclose(subsol.b, problem[key].b)

@pytest.mark.fast()
@pytest.mark.parametrize(
Expand Down
3 changes: 1 addition & 2 deletions tests/problems/time/test_lineage_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@ def test_solve_balanced(self, adata_time_barcodes: AnnData):
)
problem = problem.solve(epsilon=eps)

for key, subsol in problem.solutions.items():
for _, subsol in problem.solutions.items():
assert isinstance(subsol, BaseSolverOutput)
assert key == key

def test_solve_unbalanced(self, adata_time_barcodes: AnnData):
taus = [9e-1, 1e-2]
Expand Down

0 comments on commit 63de968

Please sign in to comment.