Skip to content

Commit

Permalink
Optimized two test cases related to paste1
Browse files Browse the repository at this point in the history
  • Loading branch information
anushka255 authored Sep 12, 2024
1 parent 882ef79 commit 6bc4723
Show file tree
Hide file tree
Showing 11 changed files with 1,365 additions and 886 deletions.
17 changes: 17 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import scanpy as sc
import pytest
from paste3.helper import intersect
import ot.backend

test_dir = Path(__file__).parent
input_dir = test_dir / "data/input"
Expand All @@ -25,6 +26,22 @@ def slices():
return slices


@pytest.fixture(scope="session")
def spot_distance_matrix(slices):
    """Pairwise Euclidean spot-distance matrices, one per slice.

    For each slice, computes the full pairwise distance matrix between
    its spatial coordinates (``obsm["spatial"]``) using POT's ``ot.dist``.
    Session-scoped so the matrices are computed once and shared by all
    tests that request this fixture.

    Returns:
        list: one square distance matrix per entry of ``slices``, in order.
    """
    nx = ot.backend.NumpyBackend()

    # Comprehension instead of append-in-loop; `slice_` avoids shadowing
    # the builtin `slice`.
    return [
        ot.dist(
            nx.from_numpy(slice_.obsm["spatial"]),
            nx.from_numpy(slice_.obsm["spatial"]),
            metric="euclidean",
        )
        for slice_ in slices
    ]

@pytest.fixture(scope="session")
def intersecting_slices(slices):
# Make a copy of the list
Expand Down
251 changes: 251 additions & 0 deletions tests/data/input/deltaG.csv

Large diffs are not rendered by default.

254 changes: 254 additions & 0 deletions tests/data/input/gene_distance.csv

Large diffs are not rendered by default.

32 changes: 16 additions & 16 deletions tests/data/output/H_center.csv

Large diffs are not rendered by default.

508 changes: 254 additions & 254 deletions tests/data/output/W_center.csv

Large diffs are not rendered by default.

334 changes: 167 additions & 167 deletions tests/data/output/center_slice2_pairwise.csv

Large diffs are not rendered by default.

298 changes: 149 additions & 149 deletions tests/data/output/center_slice3_pairwise.csv

Large diffs are not rendered by default.

364 changes: 182 additions & 182 deletions tests/data/output/center_slice4_pairwise.csv

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions tests/data/output/fused_gromov_wasserstein.csv
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250
0.003937007874015748,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4.705587100417314e-05,0.003889952003011575,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.003937007874015748,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Expand Down
91 changes: 24 additions & 67 deletions tests/test_paste.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numpy as np
import ot.backend
from ot.lp import emd
import pandas as pd
import tempfile

Expand All @@ -13,9 +12,6 @@
center_ot,
intersect,
center_NMF,
extract_data_matrix,
kl_divergence_backend,
to_dense_array,
my_fused_gromov_wasserstein,
solve_gromov_linesearch,
)
Expand Down Expand Up @@ -71,6 +67,7 @@ def test_center_alignment(slices):
n_components=15,
random_seed=0,
threshold=0.001,
max_iter=2,
dissimilarity="kl",
distributions=[slices[i].obsm["weights"] for i in range(len(slices))],
)
Expand All @@ -87,7 +84,7 @@ def test_center_alignment(slices):
)
assert_frame_equal(
pd.DataFrame(center_slice.uns["paste_H"], columns=center_slice.var.index),
pd.read_csv(output_dir / "H_center.csv"),
pd.read_csv(output_dir / "H_center.csv", index_col=0),
rtol=1e-05,
atol=1e-08,
)
Expand Down Expand Up @@ -173,88 +170,48 @@ def test_center_NMF(intersecting_slices):
)


def test_fused_gromov_wasserstein(slices):
def test_fused_gromov_wasserstein(slices, spot_distance_matrix):
np.random.seed(0)
temp_dir = Path(tempfile.mkdtemp())

common_genes = intersect(slices[0].var.index, slices[1].var.index)
sliceA = slices[0][:, common_genes]
sliceB = slices[1][:, common_genes]

nx = ot.backend.NumpyBackend()
slice1_dist = ot.dist(
nx.from_numpy(sliceA.obsm["spatial"]),
nx.from_numpy(sliceA.obsm["spatial"]),
metric="euclidean",
)
slice2_dist = ot.dist(
nx.from_numpy(sliceB.obsm["spatial"]),
nx.from_numpy(sliceB.obsm["spatial"]),
metric="euclidean",
)
slice1_distr = nx.ones((sliceA.shape[0],)) / sliceA.shape[0]
slice2_distr = nx.ones((sliceB.shape[0],)) / sliceB.shape[0]

slice1_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceA, None)))
slice2_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceB, None)))

M = nx.from_numpy(kl_divergence_backend(slice1_X + 0.01, slice2_X + 0.01))

M = np.genfromtxt(input_dir / "gene_distance.csv", delimiter=",")
pairwise_info, log = my_fused_gromov_wasserstein(
M,
slice1_dist,
slice2_dist,
slice1_distr,
slice2_distr,
spot_distance_matrix[0],
spot_distance_matrix[1],
p=nx.ones((254,)) / 254,
q=nx.ones((251,)) / 251,
G_init=None,
loss_fun="square_loss",
alpha=0.1,
log=True,
numItermax=200,
)
pd.DataFrame(pairwise_info).to_csv(temp_dir / "fused_gromov_wasserstein.csv")
# TODO: Need to figure out where the randomness is coming from
# assert_checksum_equals(temp_dir, "fused_gromov_wasserstein.csv")


def test_gromov_linesearch(slices):
common_genes = intersect(slices[1].var.index, slices[2].var.index)
sliceA = slices[1][:, common_genes]
sliceB = slices[2][:, common_genes]

nx = ot.backend.NumpyBackend()
slice1_dist = ot.dist(
nx.from_numpy(sliceA.obsm["spatial"]),
nx.from_numpy(sliceA.obsm["spatial"]),
metric="euclidean",
)
slice2_dist = ot.dist(
nx.from_numpy(sliceB.obsm["spatial"]),
nx.from_numpy(sliceB.obsm["spatial"]),
metric="euclidean",
pd.DataFrame(pairwise_info).to_csv(
temp_dir / "fused_gromov_wasserstein.csv", index=False
)
slice1_distr = nx.ones((sliceA.shape[0],)) / sliceA.shape[0]
slice2_distr = nx.ones((sliceB.shape[0],)) / sliceB.shape[0]
assert_checksum_equals(temp_dir, "fused_gromov_wasserstein.csv")

slice1_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceA, None)))
slice2_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceB, None)))

M = nx.from_numpy(kl_divergence_backend(slice1_X + 0.01, slice2_X + 0.01))
slice1_distr, slice2_distr = ot.utils.list_to_array(slice1_distr, slice2_distr)
def test_gromov_linesearch(slices, spot_distance_matrix):

constC, hC1, hC2 = ot.gromov.init_matrix(
slice1_dist, slice2_dist, slice1_distr, slice2_distr, loss_fun="square_loss"
)
nx = ot.backend.NumpyBackend()

G = slice1_distr[:, None] * slice2_distr[None, :]
Mi = M + 0.1 + ot.gromov.gwggrad(constC, hC1, hC2, G)
Mi = Mi + nx.min(Mi)
G = 1.509115054931788e-05 * np.ones((251, 264))
deltaG = np.genfromtxt(input_dir / "deltaG.csv", delimiter=",")
costG = 6.0935270338235075

Gc = emd(slice1_distr, slice2_distr, Mi)
deltaG = Gc - G
costG = nx.sum(M * G) + 0.1 * ot.gromov.gwloss(constC, hC1, hC2, G)
alpha, fc, cost_G = solve_gromov_linesearch(
G, deltaG, costG, slice1_dist, slice2_dist, M=0.0, reg=1.0, nx=nx
G,
deltaG,
costG,
spot_distance_matrix[1],
spot_distance_matrix[2],
M=0.0,
reg=1.0,
nx=nx,
)
assert alpha == 1.0
assert fc == 1
Expand Down
101 changes: 50 additions & 51 deletions tests/test_paste_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pathlib import Path
import numpy as np
import ot.backend
import pandas as pd
from pandas.testing import assert_frame_equal
from paste3.helper import (
Expand Down Expand Up @@ -34,41 +33,44 @@ def test_intersect(slices):


def test_kl_divergence_backend(slices):
nx = ot.backend.NumpyBackend()

common_genes = intersect(slices[1].var.index, slices[2].var.index)
sliceA = slices[1][:, common_genes]
sliceB = slices[2][:, common_genes]

slice1_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceA, None)))
slice2_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceB, None)))

kl_divergence_matrix = kl_divergence_backend(slice1_X + 0.01, slice2_X + 0.01)
assert_frame_equal(
pd.DataFrame(kl_divergence_matrix, columns=[str(i) for i in range(264)]),
pd.read_csv(output_dir / "kl_divergence_backend_matrix.csv"),
check_names=False,
check_dtype=False,
rtol=1e-04,
X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
Y = np.array([[2, 4, 6], [8, 10, 12], [14, 16, 28]])

kl_divergence_matrix = kl_divergence_backend(X, Y)
expected_kl_divergence_matrix = np.array(
[
[0.0, 0.03323784, 0.01889736],
[0.03607688, 0.0, 0.01442773],
[0.05534049, 0.00193493, 0.02355472],
]
)
assert np.all(
np.isclose(
kl_divergence_matrix,
expected_kl_divergence_matrix,
rtol=1e-04,
)
)


def test_kl_divergence(slices):
common_genes = intersect(slices[1].var.index, slices[2].var.index)
sliceA = slices[1][:, common_genes]
sliceB = slices[2][:, common_genes]

sliceA_X, sliceB_X = to_dense_array(
extract_data_matrix(sliceA, None)
), to_dense_array(extract_data_matrix(sliceB, None))

kl_divergence_matrix = kl_divergence(sliceA_X + 0.01, sliceB_X + 0.01)
assert_frame_equal(
pd.DataFrame(kl_divergence_matrix, columns=[str(i) for i in range(264)]),
pd.read_csv(output_dir / "kl_divergence_matrix.csv"),
check_names=False,
check_dtype=False,
rtol=1e-03,
X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
Y = np.array([[2, 4, 6], [8, 10, 12], [14, 16, 28]])

kl_divergence_matrix = kl_divergence(X, Y)
expected_kl_divergence_matrix = np.array(
[
[0.0, 0.03323784, 0.01889736],
[0.03607688, 0.0, 0.01442773],
[0.05534049, 0.00193493, 0.02355472],
]
)
assert np.all(
np.isclose(
kl_divergence_matrix,
expected_kl_divergence_matrix,
rtol=1e-04,
)
)


Expand All @@ -83,25 +85,23 @@ def test_filter_for_common_genes(slices):


def test_generalized_kl_divergence(slices):
common_genes = intersect(slices[1].var.index, slices[2].var.index)
sliceA = slices[1][:, common_genes]
sliceB = slices[2][:, common_genes]

sliceA_X, sliceB_X = to_dense_array(
extract_data_matrix(sliceA, None)
), to_dense_array(extract_data_matrix(sliceB, None))

generalized_kl_divergence_matrix = generalized_kl_divergence(
sliceA_X + 0.01, sliceB_X + 0.01
X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
Y = np.array([[2, 4, 6], [8, 10, 12], [14, 16, 28]])

generalized_kl_divergence_matrix = generalized_kl_divergence(X, Y)
expected_kl_divergence_matrix = np.array(
[
[1.84111692, 14.54279955, 38.50128292],
[0.88830648, 4.60279229, 22.93052383],
[5.9637042, 0.69099319, 13.3879729],
]
)
assert_frame_equal(
pd.DataFrame(
generalized_kl_divergence_matrix, columns=[str(i) for i in range(264)]
),
pd.read_csv(output_dir / "generalized_kl_divergence_matrix.csv"),
check_names=False,
check_dtype=False,
rtol=1e-01,
assert np.all(
np.isclose(
generalized_kl_divergence_matrix,
expected_kl_divergence_matrix,
rtol=1e-04,
)
)


Expand All @@ -126,7 +126,6 @@ def test_glmpca_distance():
rtol=1e-04,
)


def test_pca_distance(slices2):
common_genes = intersect(slices2[1].var.index, slices2[2].var.index)
sliceA = slices2[1][:, common_genes]
Expand Down

0 comments on commit 6bc4723

Please sign in to comment.