Skip to content

Commit

Permalink
parameterized some test functions (#51)
Browse files Browse the repository at this point in the history
Co-authored-by: Vineet Bansal <[email protected]>
  • Loading branch information
anushka255 and vineetbansal authored Sep 25, 2024
1 parent c4b1594 commit e3cff96
Show file tree
Hide file tree
Showing 17 changed files with 21,415 additions and 50 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ dependencies = [
"scipy",
"scikit-learn",
"IPython",
"statsmodels"
"statsmodels",
"torch"
]
dynamic = ["version"]

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ coverage==7.6.1
coveralls==4.0.1
ruff==0.6.6
pre-commit==3.8.0
torch==2.4.1
16 changes: 16 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,28 @@
import scanpy as sc
import pytest
from paste3.helper import intersect
import torch
import ot.backend

test_dir = Path(__file__).parent
input_dir = test_dir / "data/input"


def pytest_generate_tests(metafunc):
    """Parametrize GPU-related test arguments at collection time.

    Tests that request both ``use_gpu`` and ``backend`` are run on the NumPy
    backend (``use_gpu=False``) and, when CUDA is available, additionally on
    the Torch backend (``use_gpu=True``).  Tests that request
    ``gpu_verbose`` are run with both ``True`` and ``False``.

    Args:
        metafunc: The pytest ``Metafunc`` object for the test function
            currently being collected.
    """
    fixturenames = metafunc.fixturenames

    # The two names are parametrized as a pair, so both must be requested by
    # the test.  Each membership check has to be spelled out:
    # `"use_gpu" and "backend" in names` only checks for "backend", because
    # the non-empty string literal "use_gpu" is always truthy.
    if "use_gpu" in fixturenames and "backend" in fixturenames:
        gpu_backend_pairs = []
        if torch.cuda.is_available():
            # Only exercise the GPU code path when a CUDA device is present.
            gpu_backend_pairs.append((True, ot.backend.TorchBackend()))
        gpu_backend_pairs.append((False, ot.backend.NumpyBackend()))
        metafunc.parametrize("use_gpu, backend", gpu_backend_pairs)

    if "gpu_verbose" in fixturenames:
        metafunc.parametrize("gpu_verbose", [True, False])


@pytest.fixture(scope="session")
def slices():
slices = []
Expand Down
252 changes: 252 additions & 0 deletions tests/data/output/gwloss_partial_kl_loss.csv

Large diffs are not rendered by default.

252 changes: 252 additions & 0 deletions tests/data/output/partial_fused_gromov_wasserstein_true.csv

Large diffs are not rendered by default.

2,930 changes: 2,930 additions & 0 deletions tests/data/output/partial_pairwise_align_euc.csv

Large diffs are not rendered by default.

2,930 changes: 2,930 additions & 0 deletions tests/data/output/partial_pairwise_align_gkl.csv

Large diffs are not rendered by default.

2,930 changes: 2,930 additions & 0 deletions tests/data/output/partial_pairwise_align_glmpca.csv

Large diffs are not rendered by default.

2,930 changes: 2,930 additions & 0 deletions tests/data/output/partial_pairwise_align_histology.csv

Large diffs are not rendered by default.

2,930 changes: 2,930 additions & 0 deletions tests/data/output/partial_pairwise_align_kl.csv

Large diffs are not rendered by default.

2,930 changes: 2,930 additions & 0 deletions tests/data/output/partial_pairwise_align_pca.csv

Large diffs are not rendered by default.

2,930 changes: 2,930 additions & 0 deletions tests/data/output/partial_pairwise_align_selection_kl.csv

Large diffs are not rendered by default.

255 changes: 255 additions & 0 deletions tests/data/output/spots_mapping_false.csv

Large diffs are not rendered by default.

File renamed without changes.
48 changes: 30 additions & 18 deletions tests/test_paste.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import hashlib
from pathlib import Path

import numpy as np
import ot.backend
import pandas as pd
Expand Down Expand Up @@ -36,26 +35,29 @@ def assert_checksum_equals(temp_dir, filename):
)


def test_pairwise_alignment(slices):
temp_dir = Path(tempfile.mkdtemp())
def test_pairwise_alignment(slices, use_gpu, backend, gpu_verbose):
outcome = pairwise_align(
slices[0],
slices[1],
alpha=0.1,
dissimilarity="kl",
a_distribution=slices[0].obsm["weights"],
b_distribution=slices[1].obsm["weights"],
a_distribution=slices[0].obsm["weights"].astype(slices[0].X.dtype),
b_distribution=slices[1].obsm["weights"].astype(slices[1].X.dtype),
G_init=None,
use_gpu=use_gpu,
backend=backend,
gpu_verbose=gpu_verbose,
)
pd.DataFrame(
probability_mapping = pd.DataFrame(
outcome, index=slices[0].obs.index, columns=slices[1].obs.index
).to_csv(temp_dir / "slices_1_2_pairwise.csv")
assert_checksum_equals(temp_dir, "slices_1_2_pairwise.csv")
)
true_probability_mapping = pd.read_csv(
output_dir / "slices_1_2_pairwise.csv", index_col=0
)
assert_frame_equal(probability_mapping, true_probability_mapping, check_dtype=False)


def test_center_alignment(slices):
temp_dir = Path(tempfile.mkdtemp())

def test_center_alignment(slices, use_gpu, backend, gpu_verbose):
# Make a copy of the list
slices = list(slices)
n_slices = len(slices)
Expand All @@ -69,8 +71,13 @@ def test_center_alignment(slices):
threshold=0.001,
max_iter=2,
dissimilarity="kl",
use_gpu=True,
distributions=[slices[i].obsm["weights"] for i in range(len(slices))],
use_gpu=use_gpu,
backend=backend,
gpu_verbose=gpu_verbose,
distributions=[
slices[i].obsm["weights"].astype(slices[i].X.dtype)
for i in range(len(slices))
],
)
assert_frame_equal(
pd.DataFrame(
Expand All @@ -81,20 +88,25 @@ def test_center_alignment(slices):
pd.read_csv(output_dir / "W_center.csv", index_col=0),
check_names=False,
rtol=1e-05,
atol=1e-08,
atol=1e-04,
check_dtype=False,
)
assert_frame_equal(
pd.DataFrame(center_slice.uns["paste_H"], columns=center_slice.var.index),
pd.read_csv(output_dir / "H_center.csv", index_col=0),
rtol=1e-05,
atol=1e-08,
atol=1e-04,
check_dtype=False,
)

for i, pi in enumerate(pairwise_info):
pd.DataFrame(
pairwise_mapping = pd.DataFrame(
pi, index=center_slice.obs.index, columns=slices[i].obs.index
).to_csv(temp_dir / f"center_slice{i + 1}_pairwise.csv")
assert_checksum_equals(temp_dir, f"center_slice{i + 1}_pairwise.csv")
)
true_pairwise_mapping = pd.read_csv(
output_dir / f"center_slice{i + 1}_pairwise.csv", index_col=0
)
assert_frame_equal(pairwise_mapping, true_pairwise_mapping, check_dtype=False)


def test_center_ot(slices):
Expand Down
117 changes: 89 additions & 28 deletions tests/test_paste2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,85 @@
output_dir = test_dir / "data/output"


def test_partial_pairwise_align(slices2):
pi_BC = partial_pairwise_align(slices2[0], slices2[1], s=0.7)
def pytest_generate_tests(metafunc):
    """Attach this module's parameter tables to the tests that ask for them.

    The argument names requested by the collected test decide which table
    is applied:

    * ``loss_fun``      -> ``(loss_fun, filename)`` pairs
    * ``dissimilarity`` -> ``(dissimilarity, filename)`` pairs
    * ``armijo``        -> ``(armijo, expected_log, filename)`` triples

    In every table ``filename`` is the name of the CSV file passed to the
    test alongside the other parameter values.
    """
    requested = metafunc.fixturenames

    if "loss_fun" in requested:
        loss_cases = [
            ("square_loss", "gwloss_partial.csv"),
            ("kl_loss", "gwloss_partial_kl_loss.csv"),
        ]
        metafunc.parametrize("loss_fun, filename", loss_cases)

    if "dissimilarity" in requested:
        dissimilarity_cases = [
            ("euc", "partial_pairwise_align_euc.csv"),
            ("gkl", "partial_pairwise_align_gkl.csv"),
            ("kl", "partial_pairwise_align_kl.csv"),
            ("selection_kl", "partial_pairwise_align_selection_kl.csv"),
            ("pca", "partial_pairwise_align_pca.csv"),
            ("glmpca", "partial_pairwise_align_glmpca.csv"),
        ]
        metafunc.parametrize("dissimilarity, filename", dissimilarity_cases)

    if "armijo" in requested:
        # Expected log dictionary paired with armijo=False.
        log_without_armijo = {
            "err": [0.047201842558232954],
            "loss": [
                52.31031712851437,
                35.35388862002473,
                30.84819243143108,
                30.770197475353303,
                30.7643461256797,
                30.76336403641352,
                30.76332791868975,
                30.762808654741757,
                30.762727812006336,
                30.762727812006336,
            ],
            "partial_fgw_cost": 30.762727812006336,
        }
        # Expected log dictionary paired with armijo=True.
        log_with_armijo = {
            "err": [0.047201842558232954, 9.659795787581263e-08],
            "loss": [
                53.40351168112148,
                35.56234792074653,
                30.897730857089122,
                30.77217881677637,
                30.764588004718373,
                30.763380009717963,
                30.76332859918154,
                30.762818343959903,
                30.762728863994322,
                30.76272782254089,
                30.76272781211168,
            ],
            "partial_fgw_cost": 30.76272781211168,
        }
        armijo_cases = [
            (False, log_without_armijo, "partial_fused_gromov_wasserstein.csv"),
            (True, log_with_armijo, "partial_fused_gromov_wasserstein_true.csv"),
        ]
        metafunc.parametrize("armijo, expected_log, filename", armijo_cases)


def test_partial_pairwise_align(slices2, dissimilarity, filename):
    """Check ``partial_pairwise_align`` against the stored reference mapping.

    ``dissimilarity`` and ``filename`` arrive as a pair from this module's
    ``pytest_generate_tests`` hook; ``filename`` names the reference CSV in
    ``output_dir`` that the computed mapping is compared with.
    """
    pi_BC = partial_pairwise_align(
        slices2[0], slices2[1], s=0.7, dissimilarity=dissimilarity
    )

    # The reference CSV is treated as read-only test data.  Writing the
    # freshly computed result to it here would make the comparison below
    # check the output against itself and would overwrite the committed
    # fixture on every run, so the test could no longer detect regressions.
    assert_frame_equal(
        pd.DataFrame(pi_BC, columns=[str(i) for i in range(pi_BC.shape[1])]),
        pd.read_csv(output_dir / filename),
        rtol=1e-03,
        atol=1e-03,
    )
Expand Down Expand Up @@ -61,16 +134,19 @@ def test_partial_pairwise_align_given_cost_matrix(slices):
assert log == pytest.approx(expected_log)


@pytest.mark.skip
def test_partial_pairwise_align_histology(slices2):
# TODO: this function doesn't seem to be called anywhere and also seems to be incomplete

pairwise_info, log = partial_pairwise_align_histology(
slices2[0], slices2[1], return_obj=True, dissimilarity="euclidean"
slices2[0], slices2[1], s=0.7, return_obj=True, dissimilarity="euclidean"
)
assert round(log, 3) == round(78.30015827691841, 3)
assert_frame_equal(
pd.DataFrame(pairwise_info, columns=[str(i) for i in range(2877)]),
pd.read_csv(output_dir / "partial_pairwise_align_histology.csv"),
rtol=1e-05,
)


def test_partial_fused_gromov_wasserstein(slices):
def test_partial_fused_gromov_wasserstein(slices, armijo, expected_log, filename):
common_genes = intersect(slices[1].var.index, slices[2].var.index)
sliceA = slices[1][:, common_genes]
sliceB = slices[2][:, common_genes]
Expand All @@ -96,32 +172,17 @@ def test_partial_fused_gromov_wasserstein(slices):
distance_b,
np.ones((sliceA.shape[0],)) / sliceA.shape[0],
np.ones((sliceB.shape[0],)) / sliceB.shape[0],
armijo=armijo,
alpha=0.1,
m=0.7,
G0=None,
loss_fun="square_loss",
log=True,
)
expected_log = {
"err": [0.047201842558232954],
"loss": [
52.31031712851437,
35.35388862002473,
30.84819243143108,
30.770197475353303,
30.7643461256797,
30.76336403641352,
30.76332791868975,
30.762808654741757,
30.762727812006336,
30.762727812006336,
],
"partial_fgw_cost": 30.762727812006336,
}

assert_frame_equal(
pd.DataFrame(pairwise_info, columns=[str(i) for i in range(264)]),
pd.read_csv(output_dir / "partial_fused_gromov_wasserstein.csv"),
pd.read_csv(output_dir / filename),
rtol=1e-05,
)

Expand Down Expand Up @@ -160,7 +221,7 @@ def test_gloss_partial(slices):
assert output == expected_output


def test_gwloss_partial(slices):
def test_gwloss_partial(slices, loss_fun, filename):
common_genes = intersect(slices[1].var.index, slices[2].var.index)
sliceA = slices[1][:, common_genes]
sliceB = slices[2][:, common_genes]
Expand All @@ -185,9 +246,9 @@ def test_gwloss_partial(slices):
np.ones((sliceB.shape[0],)) / sliceB.shape[0],
)

output = gwgrad_partial(distance_a, distance_b, G0, loss_fun="square_loss")
output = gwgrad_partial(distance_a, distance_b, G0, loss_fun=loss_fun)

assert_frame_equal(
pd.DataFrame(output, columns=[str(i) for i in range(264)]),
pd.read_csv(output_dir / "gwloss_partial.csv"),
pd.read_csv(output_dir / filename),
)
11 changes: 8 additions & 3 deletions tests/test_paste_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
import numpy as np
import pandas as pd
import pytest
from pandas.testing import assert_frame_equal
from paste3.helper import (
intersect,
Expand Down Expand Up @@ -166,17 +167,21 @@ def test_high_umi_gene_distance(slices):
)


def test_match_spots_using_spatial_heuristic(slices):
@pytest.mark.parametrize(
"_use_ot, filename",
[(True, "spots_mapping_true.csv"), (False, "spots_mapping_false.csv")],
)
def test_match_spots_using_spatial_heuristic(slices, _use_ot, filename):
# creating a copy of the original list
slices = list(slices)
filter_for_common_genes(slices)

spots_mapping = match_spots_using_spatial_heuristic(
slices[0].X, slices[1].X, use_ot=True
slices[0].X, slices[1].X, use_ot=bool(_use_ot)
)
assert_frame_equal(
pd.DataFrame(spots_mapping, columns=[str(i) for i in range(251)]),
pd.read_csv(output_dir / "spots_mapping.csv"),
pd.read_csv(output_dir / filename),
check_names=False,
check_dtype=False,
rtol=1e-04,
Expand Down

0 comments on commit e3cff96

Please sign in to comment.