diff --git a/docs/source/notebooks/paste2_tutorial.ipynb b/docs/source/notebooks/paste2_tutorial.ipynb index 852aa39..b5bfd52 100644 --- a/docs/source/notebooks/paste2_tutorial.ipynb +++ b/docs/source/notebooks/paste2_tutorial.ipynb @@ -249,15 +249,15 @@ "):\n", " if (\"color\" not in kwargs) and (\"c\" not in kwargs):\n", " kwargs[\"color\"] = \"k\"\n", - " mx = G.max()\n", + " mx = G.max().item()\n", " # idx = np.where(G/mx>=thr)\n", - " idx = largest_indices(G, top)\n", + " idx = largest_indices(G.cpu().numpy(), top)\n", " for i in range(len(idx[0])):\n", " plt.plot(\n", " [xs[idx[0][i], 0], xt[idx[1][i], 0]],\n", " [xs[idx[0][i], 1], xt[idx[1][i], 1]],\n", " alpha=alpha * (1 - weight_alpha)\n", - " + (weight_alpha * G[idx[0][i], idx[1][i]] / mx),\n", + " + (weight_alpha * G[idx[0][i], idx[1][i]].item() / mx),\n", " c=\"k\",\n", " )\n", "\n", diff --git a/docs/source/notebooks/paste_tutorial.ipynb b/docs/source/notebooks/paste_tutorial.ipynb index de17acd..fd8c58f 100644 --- a/docs/source/notebooks/paste_tutorial.ipynb +++ b/docs/source/notebooks/paste_tutorial.ipynb @@ -40,6 +40,7 @@ "import time\n", "import pandas as pd\n", "import numpy as np\n", + "import torch\n", "import scanpy as sc\n", "import matplotlib.pyplot as plt\n", "import matplotlib.patches as mpatches\n", @@ -251,7 +252,7 @@ }, "outputs": [], "source": [ - "pd.DataFrame(pi12)" + "pd.DataFrame(pi12.cpu().numpy())" ] }, { @@ -454,7 +455,11 @@ "\n", "b = []\n", "for i in range(len(slices)):\n", - " b.append(match_spots_using_spatial_heuristic(slices[0].X, slices[i].X))" + " b.append(\n", + " torch.Tensor(\n", + " match_spots_using_spatial_heuristic(slices[0].X, slices[i].X)\n", + " ).double()\n", + " )" ] }, { @@ -796,9 +801,9 @@ "source": [ "start = time.time()\n", "\n", - "pi12 = pairwise_align(slice1, slice2, backend=ot.backend.TorchBackend(), use_gpu=False)\n", - "pi23 = pairwise_align(slice2, slice3, backend=ot.backend.TorchBackend(), use_gpu=False)\n", - "pi34 = pairwise_align(slice3, slice4, backend=ot.backend.TorchBackend(), use_gpu=False)\n", + "pi12 = pairwise_align(slice1, slice2, backend=ot.backend.TorchBackend(), use_gpu=True)\n", + "pi23 = pairwise_align(slice2, slice3, backend=ot.backend.TorchBackend(), use_gpu=True)\n", + "pi34 = pairwise_align(slice3, slice4, backend=ot.backend.TorchBackend(), use_gpu=True)\n", "\n", "print(\"Runtime: \" + str(time.time() - start))" ] @@ -814,7 +819,7 @@ }, "outputs": [], "source": [ - "pd.DataFrame(pi12)" + "pd.DataFrame(pi12.cpu().numpy())" ] }, { @@ -864,7 +869,7 @@ " lmbda,\n", " random_seed=5,\n", " backend=ot.backend.TorchBackend(),\n", - " use_gpu=False,\n", + " use_gpu=True,\n", ")\n", "\n", "print(\"Runtime: \" + str(time.time() - start))" diff --git a/scripts/workflow.py b/scripts/workflow.py new file mode 100644 index 0000000..a7c51bd --- /dev/null +++ b/scripts/workflow.py @@ -0,0 +1,145 @@ +from typing import Optional, List, Tuple +from pathlib import Path +import scanpy as sc +import numpy as np +from anndata import AnnData +import logging +from paste3.helper import match_spots_using_spatial_heuristic +from paste3.visualization import stack_slices_pairwise, stack_slices_center +from paste3.paste import pairwise_align, center_align + + +logger = logging.getLogger(__name__) + + +class Slice: + def __init__( + self, filepath: Optional[Path] = None, adata: Optional[AnnData] = None + ): + if adata is None: + self.adata = sc.read_h5ad(filepath) + else: + self.adata = adata + + def __str__(self): + return f"Slice {self.adata}" + + +class AlignmentDataset: + @staticmethod + def from_csvs(gene_expression_csvs: List[Path], coordinate_csvs: List[Path]): + pass + + def __init__( + self, + data_dir: Optional[Path] = None, + slices: Optional[List[Slice]] = None, + max_slices: Optional[int] = None, + ): + if slices is not None: + self.slices = slices[:max_slices] + else: + self.slices = [ + Slice(filepath) + for filepath in sorted(Path(data_dir).glob("*.h5ad"))[:max_slices] + ] + + def __str__(self): + return f"Data with {len(self.slices)} slices" + + def __iter__(self): + return iter(self.slices) + + def __len__(self): + return len(self.slices) + + @property + def slices_adata(self) -> List[AnnData]: + return [slice_.adata for slice_ in self.slices] + + def align( + self, + center_align: bool = False, + center_slice: Optional[Slice] = None, + pis: Optional[np.ndarray] = None, + overlap_fraction: Optional[float] = None, + max_iters: int = 1000, + ): + if center_align: + if overlap_fraction is not None: + logger.warning( + "Ignoring overlap_fraction argument (unsupported in center_align mode)" + ) + return self.center_align(center_slice, pis) + else: + assert overlap_fraction is not None, "overlap_fraction must be specified" + return self.pairwise_align( + overlap_fraction=overlap_fraction, pis=pis, max_iters=max_iters + ) + + def find_pis(self, overlap_fraction: float, max_iters: int = 1000): + pis = [] + for i in range(len(self) - 1): + logger.info(f"Finding Pi for slices {i} and {i+1}") + pis.append( + pairwise_align( + self.slices[i].adata, + self.slices[i + 1].adata, + s=overlap_fraction, + numItermax=max_iters, + maxIter=max_iters, + verbose=True, + ) + ) + return pis + + def pairwise_align( + self, + overlap_fraction: float, + pis: Optional[List[np.ndarray]] = None, + max_iters: int = 1000, + ): + if pis is None: + pis = self.find_pis(overlap_fraction=overlap_fraction, max_iters=max_iters) + new_slices = stack_slices_pairwise(self.slices_adata, pis) + return AlignmentDataset(slices=[Slice(adata=s) for s in new_slices]) + + def find_center_slice( + self, reference_slice: Optional[Slice] = None, pis: Optional[np.ndarray] = None + ) -> Tuple[Slice, List[np.ndarray]]: + if reference_slice is None: + reference_slice = self.slices[0] + center_slice, pis = center_align( + reference_slice.adata, self.slices_adata, pis_init=pis + ) + return Slice(adata=center_slice), pis + + def find_pis_init(self) -> List[np.ndarray]: + reference_slice = self.slices[0] + return [ + match_spots_using_spatial_heuristic(reference_slice.adata.X, slice_.adata.X) + for slice_ in self.slices + ] + + def center_align( + self, + reference_slice: Optional[Slice] = None, + pis: Optional[List[np.ndarray]] = None, + ): + if reference_slice is None: + reference_slice, pis = self.find_center_slice(pis=pis) + else: + pis = self.find_pis_init() + + _, new_slices = stack_slices_center( + center_slice=reference_slice.adata, slices=self.slices_adata, pis=pis + ) + return AlignmentDataset(slices=[Slice(adata=s) for s in new_slices]) + + +if __name__ == "__main__": + dataset = AlignmentDataset("data/", max_slices=3) + aligned_dataset = dataset.align( + center_align=False, overlap_fraction=0.7, max_iters=2 + ) + print(aligned_dataset) diff --git a/src/paste3/align.py b/src/paste3/align.py index d163005..f4d2edc 100644 --- a/src/paste3/align.py +++ b/src/paste3/align.py @@ -32,7 +32,7 @@ def align( max_iter=10, norm=False, numItermax=200, - use_gpu=False, + use_gpu=True, return_obj=False, optimizeTheta=True, eps=1e-4, @@ -99,7 +99,7 @@ def align( b_distribution=slices[i + 1].obsm["weights"], norm=norm, numItermax=numItermax, - backend=ot.backend.NumpyBackend(), + backend=ot.backend.TorchBackend(), use_gpu=use_gpu, return_obj=return_obj, maxIter=max_iter, @@ -110,7 +110,9 @@ def align( ) pis.append(pi) pd.DataFrame( - pi, index=slices[i].obs.index, columns=slices[i + 1].obs.index + pi.cpu().numpy(), + index=slices[i].obs.index, + columns=slices[i + 1].obs.index, ).to_csv(output_directory / f"slice_{i+1}_{i+2}_pairwise.csv") if coordinates: @@ -135,14 +137,16 @@ def align( random_seed=seed, pis_init=pis_init, distributions=[slice.obsm["weights"] for slice in slices], - backend=ot.backend.NumpyBackend(), + backend=ot.backend.TorchBackend(), use_gpu=use_gpu, ) center_slice.write(output_directory / "center_slice.h5ad") for i in range(len(pis) - 1): pd.DataFrame( - pis[i], index=center_slice.obs.index, columns=slices[i].obs.index + pis[i].cpu().numpy(), + index=center_slice.obs.index, + columns=slices[i].obs.index, ).to_csv(output_directory / f"slice_{i}_{i+1}_pairwise.csv") if coordinates: diff --git a/src/paste3/helper.py b/src/paste3/helper.py index 861f668..3ed2a9a 100644 --- a/src/paste3/helper.py +++ b/src/paste3/helper.py @@ -5,6 +5,7 @@ from typing import List from anndata import AnnData import numpy as np +import torch import scipy import ot @@ -22,11 +23,11 @@ def kl_divergence(X, Y): X = X / X.sum(axis=1, keepdims=True) Y = Y / Y.sum(axis=1, keepdims=True) - log_X = np.log(X) - log_Y = np.log(Y) - X_log_X = np.matrix([np.dot(X[i], log_X[i].T) for i in range(X.shape[0])]) - D = X_log_X.T - np.dot(X, log_Y.T) - return np.asarray(D) + log_X = X.log() + log_Y = Y.log() + X_log_X = torch.sum(X * log_X, axis=1)[torch.newaxis, :] + D = X_log_X.T - torch.matmul(X, log_Y.T) + return D def generalized_kl_divergence(X, Y): @@ -40,14 +41,14 @@ def generalized_kl_divergence(X, Y): """ assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features." - log_X = np.log(X) - log_Y = np.log(Y) - X_log_X = np.matrix([np.dot(X[i], log_X[i].T) for i in range(X.shape[0])]) - D = X_log_X.T - np.dot(X, log_Y.T) - sum_X = np.sum(X, axis=1) - sum_Y = np.sum(Y, axis=1) + log_X = X.log() + log_Y = Y.log() + X_log_X = torch.sum(X * log_X, axis=1)[torch.newaxis, :] + D = X_log_X.T - torch.matmul(X, log_Y.T) + sum_X = torch.sum(X, axis=1) + sum_Y = torch.sum(Y, axis=1) D = (D.T - sum_X).T + sum_Y.T - return np.asarray(D) + return D def glmpca_distance( @@ -72,15 +73,15 @@ def glmpca_distance( """ assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features." - joint_matrix = np.vstack((X, Y)) + joint_matrix = torch.vstack((X, Y)) if filter: - gene_umi_counts = np.sum(joint_matrix, axis=0) + gene_umi_counts = torch.sum(joint_matrix, axis=0).cpu().numpy() top_indices = np.sort((-gene_umi_counts).argsort(kind="stable")[:2000]) joint_matrix = joint_matrix[:, top_indices] print("Starting GLM-PCA...") res = glmpca( - joint_matrix.T, + joint_matrix.T.cpu().numpy(), # TODO: Use Tensors latent_dim, penalty=1, verbose=verbose, @@ -116,13 +117,13 @@ def high_umi_gene_distance(X, Y, n): """ assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features." - joint_matrix = np.vstack((X, Y)) - gene_umi_counts = np.sum(joint_matrix, axis=0) + joint_matrix = torch.vstack((X, Y)) + gene_umi_counts = torch.sum(joint_matrix, axis=0).cpu().numpy() top_indices = np.sort((-gene_umi_counts).argsort(kind="stable")[:n]) X = X[:, top_indices] Y = Y[:, top_indices] - X += np.tile(0.01 * (np.sum(X, axis=1) / X.shape[1]), (X.shape[1], 1)).T - Y += np.tile(0.01 * (np.sum(Y, axis=1) / Y.shape[1]), (Y.shape[1], 1)).T + X += torch.tile(0.01 * (torch.sum(X, axis=1) / X.shape[1]), (X.shape[1], 1)).T + Y += torch.tile(0.01 * (torch.sum(Y, axis=1) / Y.shape[1]), (Y.shape[1], 1)).T return kl_divergence(X, Y) @@ -150,7 +151,8 @@ def norm_and_center_coordinates(X): ## Covert a sparse matrix into a dense matrix def to_dense_array(X): - return np.array(X.todense()) if isinstance(X, scipy.sparse.csr.spmatrix) else X + np_array = np.array(X.todense()) if isinstance(X, scipy.sparse.csr.spmatrix) else X + return torch.Tensor(np_array).double() def extract_data_matrix(adata, rep=None): @@ -190,6 +192,7 @@ def match_spots_using_spatial_heuristic(X, Y, use_ot: bool = True) -> np.ndarray Returns: Mapping of spots using a spatial heuristic. """ + # X, Y = X.todense(), Y.todense() n1, n2 = len(X), len(Y) X, Y = norm_and_center_coordinates(X), norm_and_center_coordinates(Y) dist = scipy.spatial.distance_matrix(X, Y) @@ -238,7 +241,7 @@ def kl_divergence_backend(X, Y): def dissimilarity_metric(which, sliceA, sliceB, A, B, **kwargs): match which: case "euc" | "euclidean": - return scipy.spatial.distance.cdist(A, B) + return torch.cdist(A, B) case "gkl": s_A = A + 0.01 s_B = B + 0.01 @@ -254,8 +257,10 @@ def dissimilarity_metric(which, sliceA, sliceB, A, B, **kwargs): case "selection_kl": return high_umi_gene_distance(A, B, 2000) case "pca": - return pca_distance(sliceA, sliceB, 2000, 20) + # TODO: Modify this function to work with Tensors + return torch.Tensor(pca_distance(sliceA, sliceB, 2000, 20)).double() case "glmpca": - return glmpca_distance(A, B, **kwargs) + # TODO: Modify this function to work with Tensors + return torch.Tensor(glmpca_distance(A, B, **kwargs)).double() case _: raise RuntimeError(f"Error: Invalid dissimilarity metric {which}") diff --git a/src/paste3/model_selection.py b/src/paste3/model_selection.py index b46688b..563ab15 100644 --- a/src/paste3/model_selection.py +++ b/src/paste3/model_selection.py @@ -1,4 +1,5 @@ import numpy as np +import torch from scipy.spatial import ConvexHull from matplotlib.path import Path from scipy.spatial.distance import cdist @@ -88,7 +89,7 @@ def calculate_convex_hull_edge_inconsistency(sliceA, sliceB, pi): sliceA = sliceA.copy() source_split = [] - source_mass = np.sum(pi, axis=1) + source_mass = torch.sum(pi, axis=1) for i in range(len(source_mass)): if source_mass[i] > 0: source_split.append("true") @@ -97,10 +98,11 @@ def calculate_convex_hull_edge_inconsistency(sliceA, sliceB, pi): sliceA.obs["aligned"] = source_split source_mapped_points = [] - source_mass = np.sum(pi, axis=1) + source_mass = torch.sum(pi, axis=1) for i in range(len(source_mass)): if source_mass[i] > 0: source_mapped_points.append(sliceA.obsm["spatial"][i]) + source_mapped_points = np.array(source_mapped_points) source_hull = ConvexHull(source_mapped_points) source_hull_path = Path(source_mapped_points[source_hull.vertices]) @@ -116,7 +118,7 @@ def calculate_convex_hull_edge_inconsistency(sliceA, sliceB, pi): sliceB = sliceB.copy() target_split = [] - target_mass = np.sum(pi, axis=0) + target_mass = torch.sum(pi, axis=0) for i in range(len(target_mass)): if target_mass[i] > 0: target_split.append("true") @@ -125,7 +127,7 @@ def calculate_convex_hull_edge_inconsistency(sliceA, sliceB, pi): sliceB.obs["aligned"] = target_split target_mapped_points = [] - target_mass = np.sum(pi, axis=0) + target_mass = torch.sum(pi, axis=0) for i in range(len(target_mass)): if target_mass[i] > 0: target_mapped_points.append(sliceB.obsm["spatial"][i]) @@ -195,7 +197,9 @@ def select_overlap_fraction(sliceA, sliceB, alpha=0.1, show_plot=True, numIterma to_dense_array(extract_data_matrix(sliceA, None)), to_dense_array(extract_data_matrix(sliceB, None)), ) - M = glmpca_distance(A_X, B_X, latent_dim=50, filter=True, maxIter=numItermax) + M = torch.Tensor( + glmpca_distance(A_X, B_X, latent_dim=50, filter=True, maxIter=numItermax) + ).double() m_to_pi = {} for m in overlap_to_check: diff --git a/src/paste3/paste.py b/src/paste3/paste.py index 3b9b60e..6ce8840 100644 --- a/src/paste3/paste.py +++ b/src/paste3/paste.py @@ -1,9 +1,9 @@ from typing import List, Tuple, Optional +import torch import numpy as np from anndata import AnnData import ot from ot.lp import emd -from scipy.spatial import distance from sklearn.decomposition import NMF from paste3.helper import ( intersect, @@ -26,8 +26,8 @@ def pairwise_align( b_distribution=None, norm: bool = False, numItermax: int = 200, - backend=ot.backend.NumpyBackend(), - use_gpu: bool = False, + backend=ot.backend.TorchBackend(), + use_gpu: bool = True, return_obj: bool = False, verbose: bool = False, gpu_verbose: bool = True, @@ -116,14 +116,17 @@ def pairwise_align( D_A = ot.dist(coordinatesA, coordinatesA, metric="euclidean") D_B = ot.dist(coordinatesB, coordinatesB, metric="euclidean") - if isinstance(nx, ot.backend.TorchBackend) and use_gpu: + if isinstance(nx, ot.backend.TorchBackend): + D_A = D_A.double() + D_B = D_B.double() + if use_gpu: D_A = D_A.cuda() D_B = D_B.cuda() # Calculate expression dissimilarity A_X, B_X = ( - nx.from_numpy(to_dense_array(extract_data_matrix(sliceA, use_rep))), - nx.from_numpy(to_dense_array(extract_data_matrix(sliceB, use_rep))), + to_dense_array(extract_data_matrix(sliceA, use_rep)), + to_dense_array(extract_data_matrix(sliceB, use_rep)), ) if isinstance(nx, ot.backend.TorchBackend) and use_gpu: @@ -144,11 +147,17 @@ def pairwise_align( eps=eps, optimizeTheta=optimizeTheta, ) - M = nx.from_numpy(M) if is_histology: # Calculate RGB dissimilarity - M_rgb = distance.cdist(sliceA.obsm["rgb"], sliceB.obsm["rgb"]) + M_rgb = ( + torch.cdist( + torch.Tensor(sliceA.obsm["rgb"]).double(), + torch.Tensor(sliceB.obsm["rgb"]).double(), + ) + .to(M.dtype) + .to(M.device) + ) # Scale M_exp and M_rgb, obtain M by taking half from each M_rgb /= M_rgb[M_rgb > 0].max() @@ -167,6 +176,7 @@ def pairwise_align( b = nx.from_numpy(b_distribution) if isinstance(nx, ot.backend.TorchBackend): + M = M.double() a = a.double() b = b.double() if use_gpu: @@ -184,11 +194,8 @@ def pairwise_align( D_B *= M.max() # Run OT - if G_init is not None: - G_init = nx.from_numpy(G_init) - if isinstance(nx, ot.backend.TorchBackend): - if use_gpu: - G_init.cuda() + if G_init is not None and use_gpu: + G_init.cuda() pi, log = my_fused_gromov_wasserstein( M, D_A, @@ -205,10 +212,7 @@ def pairwise_align( verbose=verbose, ) if not s: - pi = nx.to_numpy(pi) - log = nx.to_numpy(log["fgw_dist"]) - if isinstance(backend, ot.backend.TorchBackend) and use_gpu: - torch.cuda.empty_cache() + log = log["fgw_dist"].item() if return_obj: return pi, log @@ -228,8 +232,8 @@ def center_align( random_seed: Optional[int] = None, pis_init: Optional[List[np.ndarray]] = None, distributions=None, - backend=ot.backend.NumpyBackend(), - use_gpu: bool = False, + backend=ot.backend.TorchBackend(), + use_gpu: bool = True, verbose: bool = False, gpu_verbose: bool = True, ) -> Tuple[AnnData, List[np.ndarray]]: @@ -394,11 +398,17 @@ def center_align( center_slice.X = np.dot(W, H) center_slice.uns["paste_W"] = W center_slice.uns["paste_H"] = H - center_slice.uns["full_rank"] = center_slice.shape[0] * sum( - [ - lmbda[i] * np.dot(pis[i], to_dense_array(slices[i].X)) - for i in range(len(slices)) - ] + center_slice.uns["full_rank"] = ( + center_slice.shape[0] + * sum( + [ + lmbda[i] + * torch.matmul(pis[i], to_dense_array(slices[i].X).to(pis[i].device)) + for i in range(len(slices)) + ] + ) + .cpu() + .numpy() ) center_slice.uns["obj"] = R return center_slice, pis @@ -469,7 +479,8 @@ def center_NMF( n = W.shape[0] B = n * sum( [ - lmbda[i] * np.dot(pis[i], to_dense_array(slices[i].X)) + lmbda[i] + * torch.matmul(pis[i], to_dense_array(slices[i].X).to(pis[i].device)) for i in range(len(slices)) ] ) @@ -489,7 +500,7 @@ def center_NMF( random_state=random_seed, verbose=verbose, ) - W_new = model.fit_transform(B) + W_new = model.fit_transform(B.cpu().numpy()) H_new = model.components_ return W_new, H_new @@ -509,7 +520,7 @@ def my_fused_gromov_wasserstein( numItermax=200, tol_rel=1e-9, tol_abs=1e-9, - use_gpu=False, + use_gpu=True, numItermaxEmd=100000, **kwargs, ): @@ -527,7 +538,7 @@ def my_fused_gromov_wasserstein( raise ValueError( "Problem infeasible. Parameter m should be greater" " than 0." ) - elif m > np.min((np.sum(p), np.sum(q))): + elif m > min(p.sum(), q.sum()): raise ValueError( "Problem infeasible. Parameter m should lower or" " equal to min(|p|_1, |q|_1)." @@ -537,8 +548,8 @@ def my_fused_gromov_wasserstein( _log = {"err": []} count = 0 dummy = 1 - _p = np.append(p, [(np.sum(q) - m) / dummy] * dummy) - _q = np.append(q, [(np.sum(q) - m) / dummy] * dummy) + _p = torch.cat([p, torch.Tensor([(q.sum() - m) / dummy] * dummy).to(p.device)]) + _q = torch.cat([q, torch.Tensor([(q.sum() - m) / dummy] * dummy).to(p.device)]) if G0 is not None: G0 = (1 / nx.sum(G0)) * G0 @@ -549,8 +560,8 @@ def f(G): constC, hC1, hC2 = ot.gromov.init_matrix( C1, C2, - nx.sum(G, axis=1).reshape(-1, 1), - nx.sum(G, axis=0).reshape(1, -1), + nx.sum(G, axis=1).reshape(-1, 1).to(C1.dtype), + nx.sum(G, axis=0).reshape(1, -1).to(C2.dtype), loss_fun, ) return ot.gromov.gwloss(constC, hC1, hC2, G) @@ -574,7 +585,7 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): if log: # keep track of error only on every 10th iteration if count % 10 == 0: - _log["err"].append(np.linalg.norm(deltaG)) + _log["err"].append(torch.norm(deltaG)) count += 1 if armijo: @@ -593,8 +604,8 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): def lp_solver(a, b, M, **kwargs): if m: - _M = np.pad(M, [(0, dummy)] * 2, mode="constant") - _M[-dummy:, -dummy:] = np.max(M) * 1e2 + _M = torch.nn.functional.pad(M, pad=(0, dummy, 0, dummy), mode="constant") + _M[-dummy:, -dummy:] = torch.max(M) * 1e2 Gc, innerlog_ = emd(_p, _q, _M, 1000000, log=True) if innerlog_.get("warning"): @@ -711,14 +722,14 @@ def line_search_partial(reg, M, G, C1, C2, deltaG, loss_fun="square_loss"): constC, hC1, hC2 = ot.gromov.init_matrix( C1, C2, - np.sum(deltaG, axis=1).reshape(-1, 1), - np.sum(deltaG, axis=0).reshape(1, -1), + torch.sum(deltaG, axis=1).reshape(-1, 1), + torch.sum(deltaG, axis=0).reshape(1, -1), loss_fun, ) - dot = np.dot(np.dot(C1, deltaG), C2.T) - a = reg * np.sum(dot * deltaG) - b = (1 - reg) * np.sum(M * deltaG) + 2 * reg * np.sum( + dot = torch.matmul(torch.matmul(C1, deltaG), C2.T) + a = reg * torch.sum(dot * deltaG) + b = (1 - reg) * torch.sum(M * deltaG) + 2 * reg * torch.sum( ot.gromov.gwggrad(constC, hC1, hC2, deltaG) * 0.5 * G ) alpha = ot.optim.solve_1d_linesearch_quad(a, b) @@ -726,9 +737,9 @@ def line_search_partial(reg, M, G, C1, C2, deltaG, loss_fun="square_loss"): constC, hC1, hC2 = ot.gromov.init_matrix( C1, C2, - np.sum(G, axis=1).reshape(-1, 1), - np.sum(G, axis=0).reshape(1, -1), + torch.sum(G, axis=1).reshape(-1, 1), + torch.sum(G, axis=0).reshape(1, -1), loss_fun, ) - cost_G = (1 - reg) * np.sum(M * G) + reg * ot.gromov.gwloss(constC, hC1, hC2, G) + cost_G = (1 - reg) * torch.sum(M * G) + reg * ot.gromov.gwloss(constC, hC1, hC2, G) return alpha, a, cost_G diff --git a/src/paste3/visualization.py b/src/paste3/visualization.py index dd3d6b1..c4984dc 100644 --- a/src/paste3/visualization.py +++ b/src/paste3/visualization.py @@ -1,6 +1,7 @@ from typing import List, Tuple, Optional from anndata import AnnData import numpy as np +import torch import seaborn as sns import matplotlib.pyplot as plt @@ -47,8 +48,8 @@ def stack_slices_pairwise( thetas = [] translations = [] result = generalized_procrustes_analysis( - slices[0].obsm["spatial"], - slices[1].obsm["spatial"], + torch.Tensor(slices[0].obsm["spatial"]).to(pis[0].dtype).to(pis[0].device), + torch.Tensor(slices[1].obsm["spatial"]).to(pis[0].dtype).to(pis[0].device), pis[0], is_partial=is_partial, output_params=output_params, @@ -65,7 +66,9 @@ def stack_slices_pairwise( for i in range(1, len(slices) - 1): result = generalized_procrustes_analysis( new_coor[i], - slices[i + 1].obsm["spatial"], + torch.Tensor(slices[i + 1].obsm["spatial"]) + .to(pis[i].dtype) + .to(pis[i].device), pis[i], is_partial=is_partial, output_params=output_params, @@ -85,7 +88,7 @@ def stack_slices_pairwise( new_slices = [] for i in range(len(slices)): s = slices[i].copy() - s.obsm["spatial"] = new_coor[i] + s.obsm["spatial"] = new_coor[i].cpu().numpy() new_slices.append(s) if not output_params: @@ -141,12 +144,22 @@ def stack_slices_center( for i in range(len(slices)): if not output_params: c, y = generalized_procrustes_analysis( - center_slice.obsm["spatial"], slices[i].obsm["spatial"], pis[i] + torch.Tensor(center_slice.obsm["spatial"]) + .to(pis[i].dtype) + .to(pis[i].device), + torch.Tensor(slices[i].obsm["spatial"]) + .to(pis[i].dtype) + .to(pis[i].device), + pis[i], ) else: c, y, theta, tX, tY = generalized_procrustes_analysis( - center_slice.obsm["spatial"], - slices[i].obsm["spatial"], + torch.Tensor(center_slice.obsm["spatial"]) + .to(pis[i].dtype) + .to(pis[i].device), + torch.Tensor(slices[i].obsm["spatial"]) + .to(pis[i].dtype) + .to(pis[i].device), pis[i], output_params=output_params, matrix=matrix, @@ -158,11 +171,11 @@ def stack_slices_center( new_slices = [] for i in range(len(slices)): s = slices[i].copy() - s.obsm["spatial"] = new_coor[i] + s.obsm["spatial"] = new_coor[i].cpu().numpy() new_slices.append(s) new_center = center_slice.copy() - new_center.obsm["spatial"] = c + new_center.obsm["spatial"] = c.cpu().numpy() if not output_params: return new_center, new_slices else: @@ -215,21 +228,21 @@ def generalized_procrustes_analysis( """ assert X.shape[1] == 2 and Y.shape[1] == 2 - tX = pi.sum(axis=1).dot(X) - tY = pi.sum(axis=0).dot(Y) + tX = pi.sum(axis=1).matmul(X) + tY = pi.sum(axis=0).matmul(Y) X = X - tX Y = Y - tY if is_partial: - m = np.sum(pi) + m = torch.sum(pi) X = X * (1.0 / m) Y = Y * (1.0 / m) - H = Y.T.dot(pi.T.dot(X)) - U, S, Vt = np.linalg.svd(H) - R = Vt.T.dot(U.T) - Y = R.dot(Y.T).T + H = Y.T.matmul(pi.T.matmul(X)) + U, S, Vt = torch.linalg.svd(H, full_matrices=True) + R = Vt.T.matmul(U.T) + Y = R.matmul(Y.T).T if output_params and not matrix: - M = np.array([[0, -1], [1, 0]]) - theta = np.arctan(np.trace(M.dot(H)) / np.trace(H)) + M = torch.Tensor([[0, -1], [1, 0]]).double() + theta = torch.arctan(torch.trace(M.matmul(H)) / torch.trace(H)) return X, Y, theta, tX, tY elif output_params and matrix: return X, Y, R, tX, tY diff --git a/tests/conftest.py b/tests/conftest.py index f0c353a..5a3a11b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,7 +29,7 @@ def slices(): @pytest.fixture(scope="session") def spot_distance_matrix(slices): - nx = ot.backend.NumpyBackend() + nx = ot.backend.TorchBackend() spot_distances = [] for slice in slices: diff --git a/tests/test_paste.py b/tests/test_paste.py index fae361a..81686ee 100644 --- a/tests/test_paste.py +++ b/tests/test_paste.py @@ -1,6 +1,7 @@ import hashlib from pathlib import Path import numpy as np +import torch import ot.backend import pandas as pd import tempfile @@ -15,24 +16,28 @@ line_search_partial, ) from pandas.testing import assert_frame_equal +import pytest test_dir = Path(__file__).parent input_dir = test_dir / "data/input" output_dir = test_dir / "data/output" -def assert_checksum_equals(temp_dir, filename): +def assert_checksum_equals(temp_dir, filename, loose=False): generated_file = temp_dir / filename oracle = output_dir / filename - assert ( - hashlib.md5( - "".join(open(generated_file, "r").readlines()).encode("utf8") - ).hexdigest() - == hashlib.md5( - "".join(open(oracle, "r").readlines()).encode("utf8") - ).hexdigest() - ) + if loose: + assert_frame_equal(pd.read_csv(generated_file), pd.read_csv(oracle)) + else: + assert ( + hashlib.md5( + "".join(open(generated_file, "r").readlines()).encode("utf8") + ).hexdigest() + == hashlib.md5( + "".join(open(oracle, "r").readlines()).encode("utf8") + ).hexdigest() + ) def test_pairwise_alignment(slices): @@ -48,7 +53,7 @@ def test_pairwise_alignment(slices): backend=ot.backend.TorchBackend(), ) probability_mapping = pd.DataFrame( - outcome, index=slices[0].obs.index, columns=slices[1].obs.index + outcome.cpu().numpy(), index=slices[0].obs.index, columns=slices[1].obs.index ) true_probability_mapping = pd.read_csv( output_dir / "slices_1_2_pairwise.csv", index_col=0 @@ -99,7 +104,7 @@ def test_center_alignment(slices): for i, pi in enumerate(pairwise_info): pairwise_mapping = pd.DataFrame( - pi, index=center_slice.obs.index, columns=slices[i].obs.index + pi.cpu().numpy(), index=center_slice.obs.index, columns=slices[i].obs.index ) true_pairwise_mapping = pd.read_csv( output_dir / f"center_slice{i + 1}_pairwise.csv", index_col=0 @@ -121,9 +126,9 @@ def test_center_ot(slices): slices=slices, center_coordinates=intersecting_slice.obsm["spatial"], common_genes=common_genes, - use_gpu=False, + use_gpu=True, alpha=0.1, - backend=ot.backend.NumpyBackend(), + backend=ot.backend.TorchBackend(), dissimilarity="kl", norm=False, G_inits=[None for _ in range(len(slices))], @@ -136,20 +141,24 @@ def test_center_ot(slices): -25.740615316378296, ] - assert np.all(np.isclose(expected_r, r, rtol=1e-05, atol=1e-08, equal_nan=True)) + assert np.allclose(expected_r, r) for i, pi in enumerate(pairwise_info): pd.DataFrame( - pi, index=intersecting_slice.obs.index, columns=slices[i].obs.index + pi.cpu().numpy(), + index=intersecting_slice.obs.index, + columns=slices[i].obs.index, ).to_csv(temp_dir / f"center_ot{i + 1}_pairwise.csv") - assert_checksum_equals(temp_dir, f"center_ot{i + 1}_pairwise.csv") + assert_checksum_equals(temp_dir, f"center_ot{i + 1}_pairwise.csv", loose=True) def test_center_NMF(intersecting_slices): n_slices = len(intersecting_slices) pairwise_info = [ - np.genfromtxt(input_dir / f"center_ot{i+1}_pairwise.csv", delimiter=",") + torch.Tensor( + np.genfromtxt(input_dir / f"center_ot{i+1}_pairwise.csv", delimiter=",") + ).double() for i in range(n_slices) ] @@ -184,15 +193,17 @@ def test_center_NMF(intersecting_slices): def test_fused_gromov_wasserstein(slices, spot_distance_matrix): temp_dir = Path(tempfile.mkdtemp()) - nx = ot.backend.NumpyBackend() + nx = ot.backend.TorchBackend() - M = np.genfromtxt(input_dir / "gene_distance.csv", delimiter=",") + M = torch.Tensor( + np.genfromtxt(input_dir / "gene_distance.csv", delimiter=",") + ).double() pairwise_info, log = my_fused_gromov_wasserstein( M, spot_distance_matrix[0], spot_distance_matrix[1], - p=nx.ones((254,)) / 254, - q=nx.ones((251,)) / 251, + p=nx.ones((254,)).double() / 254, + q=nx.ones((251,)).double() / 251, alpha=0.1, G0=None, loss_fun="square_loss", @@ -206,10 +217,12 @@ def test_fused_gromov_wasserstein(slices, spot_distance_matrix): def test_gromov_linesearch(slices, spot_distance_matrix): - nx = ot.backend.NumpyBackend() + nx = ot.backend.TorchBackend() - G = 1.509115054931788e-05 * np.ones((251, 264)) - deltaG = np.genfromtxt(input_dir / "deltaG.csv", delimiter=",") + G = 1.509115054931788e-05 * torch.ones((251, 264)).double() + deltaG = torch.Tensor( + np.genfromtxt(input_dir / "deltaG.csv", delimiter=",") + ).double() costG = 6.0935270338235075 alpha, fc, cost_G = solve_gromov_linesearch( @@ -224,13 +237,17 @@ def test_gromov_linesearch(slices, spot_distance_matrix): ) assert alpha == 1.0 assert fc == 1 - assert round(cost_G, 5) == -11.20545 + assert pytest.approx(cost_G) == -11.20545 def test_line_search_partial(slices, spot_distance_matrix): - G = 1.509115054931788e-05 * np.ones((251, 264)) - deltaG = np.genfromtxt(input_dir / "deltaG.csv", delimiter=",") - M = np.genfromtxt(input_dir / "euc_dissimilarity.csv", delimiter=",") + G = 1.509115054931788e-05 * torch.ones((251, 264)).double() + deltaG = torch.Tensor( + np.genfromtxt(input_dir / "deltaG.csv", delimiter=",") + ).double() + M = torch.Tensor( + np.genfromtxt(input_dir / "euc_dissimilarity.csv", delimiter=",") + ).double() alpha, a, cost_G = line_search_partial( reg=0.1, @@ -241,5 +258,5 @@ def test_line_search_partial(slices, spot_distance_matrix): deltaG=deltaG, ) assert alpha == 1.0 - assert a == 0.4858849047237918 - assert cost_G == 102.6333512778727 + assert pytest.approx(a) == 0.4858849047237918 + assert pytest.approx(cost_G) == 102.6333512778727 diff --git a/tests/test_paste2.py b/tests/test_paste2.py index 109da84..2797bf9 100644 --- a/tests/test_paste2.py +++ b/tests/test_paste2.py @@ -1,6 +1,7 @@ from pathlib import Path import pandas as pd import numpy as np +import torch from paste3.helper import intersect import pytest from unittest.mock import patch @@ -18,7 +19,7 @@ def test_partial_pairwise_align_glmpca(fn, slices2): # Load pre-computed dissimilarity metrics, # since it is time-consuming to compute. data = np.load(output_dir / "test_partial_pairwise_align.npz") - fn.return_value = data["glmpca"] + fn.return_value = torch.Tensor(data["glmpca"]).double() pi_BC = pairwise_align( slices2[0], @@ -30,7 +31,7 @@ def test_partial_pairwise_align_glmpca(fn, slices2): maxIter=10, ) - assert np.allclose(pi_BC, data["pi_BC"]) + assert np.allclose(pi_BC.cpu().numpy(), data["pi_BC"], atol=1e-7) def test_partial_pairwise_align_given_cost_matrix(slices): @@ -38,9 +39,11 @@ def test_partial_pairwise_align_given_cost_matrix(slices): sliceA = slices[1][:, common_genes] sliceB = slices[2][:, common_genes] - glmpca_distance_matrix = np.genfromtxt( - input_dir / "glmpca_distance_matrix.csv", delimiter=",", skip_header=1 - ) + glmpca_distance_matrix = torch.Tensor( + np.genfromtxt( + input_dir / "glmpca_distance_matrix.csv", delimiter=",", skip_header=1 + ) + ).double() pairwise_info, log = pairwise_align( sliceA, @@ -57,11 +60,11 @@ def test_partial_pairwise_align_given_cost_matrix(slices): ) assert_frame_equal( - pd.DataFrame(pairwise_info, columns=[str(i) for i in range(264)]), + pd.DataFrame(pairwise_info.cpu().numpy(), columns=[str(i) for i in range(264)]), pd.read_csv(output_dir / "align_given_cost_matrix_pairwise_info.csv"), - rtol=1e-05, + rtol=1e-04, ) - assert log["partial_fgw_cost"] == pytest.approx(40.86494022326222) + assert log["partial_fgw_cost"].cpu().numpy() == pytest.approx(40.86494022326222) def test_partial_pairwise_align_histology(slices2): @@ -78,11 +81,11 @@ def test_partial_pairwise_align_histology(slices2): maxIter=10, is_histology=True, ) - assert log["partial_fgw_cost"] == pytest.approx(88.06713721008786) - assert_frame_equal( - pd.DataFrame(pairwise_info, columns=[str(i) for i in range(2877)]), - pd.read_csv(output_dir / "partial_pairwise_align_histology.csv"), - rtol=1e-05, + assert log["partial_fgw_cost"].cpu().numpy() == pytest.approx(88.06713721008786) + assert np.allclose( + pairwise_info.cpu().numpy(), + pd.read_csv(output_dir / "partial_pairwise_align_histology.csv").to_numpy(), + atol=1e-7, ) @@ -155,11 +158,11 @@ def test_partial_fused_gromov_wasserstein(slices, armijo, expected_log, filename distance_b *= glmpca_distance_matrix.max() pairwise_info, log = my_fused_gromov_wasserstein( - glmpca_distance_matrix, - distance_a, - distance_b, - np.ones((sliceA.shape[0],)) / sliceA.shape[0], - np.ones((sliceB.shape[0],)) / sliceB.shape[0], + torch.Tensor(glmpca_distance_matrix).double(), + torch.Tensor(distance_a).double(), + torch.Tensor(distance_b).double(), + torch.ones((sliceA.shape[0],)).double() / sliceA.shape[0], + torch.ones((sliceB.shape[0],)).double() / sliceB.shape[0], alpha=0.1, m=0.7, G0=None, @@ -168,10 +171,8 @@ def test_partial_fused_gromov_wasserstein(slices, armijo, expected_log, filename log=True, ) - assert_frame_equal( - pd.DataFrame(pairwise_info, columns=[str(i) for i in range(264)]), - pd.read_csv(output_dir / filename), - rtol=1e-05, + assert np.allclose( + pd.read_csv(output_dir / filename).to_numpy(), pairwise_info, atol=1e-7 ) for k, v in expected_log.items(): diff --git a/tests/test_paste_helpers.py b/tests/test_paste_helpers.py index 010fdaf..dd4fe02 100644 --- a/tests/test_paste_helpers.py +++ b/tests/test_paste_helpers.py @@ -1,5 +1,6 @@ from pathlib import Path import numpy as np +import torch import pandas as pd import pytest from pandas.testing import assert_frame_equal @@ -56,8 +57,8 @@ def test_kl_divergence_backend(slices): def test_kl_divergence(slices): - X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - Y = np.array([[2, 4, 6], [8, 10, 12], [14, 16, 28]]) + X = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).double() + Y = torch.Tensor([[2, 4, 6], [8, 10, 12], [14, 16, 28]]).double() kl_divergence_matrix = kl_divergence(X, Y) expected_kl_divergence_matrix = np.array( @@ -67,12 +68,9 @@ def test_kl_divergence(slices): [0.05534049, 0.00193493, 0.02355472], ] ) - assert np.all( - np.isclose( - kl_divergence_matrix, - expected_kl_divergence_matrix, - rtol=1e-04, - ) + assert np.allclose( + kl_divergence_matrix, + expected_kl_divergence_matrix, ) @@ -87,8 +85,8 @@ def test_filter_for_common_genes(slices): def test_generalized_kl_divergence(slices): - X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - Y = np.array([[2, 4, 6], [8, 10, 12], [14, 16, 28]]) + X = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).double() + Y = torch.Tensor([[2, 4, 6], [8, 10, 12], [14, 16, 28]]).double() generalized_kl_divergence_matrix = generalized_kl_divergence(X, Y) expected_kl_divergence_matrix = np.array( @@ -98,22 +96,19 @@ def test_generalized_kl_divergence(slices): [5.9637042, 0.69099319, 13.3879729], ] ) - assert np.all( - np.isclose( - generalized_kl_divergence_matrix, - expected_kl_divergence_matrix, - rtol=1e-04, - ) + assert np.allclose( + generalized_kl_divergence_matrix, + expected_kl_divergence_matrix, ) def test_glmpca_distance(): - sliceA_X = np.genfromtxt(input_dir / "sliceA_X.csv", delimiter=",", skip_header=1)[ - 10:, :1000 - ] - sliceB_X = np.genfromtxt(input_dir / "sliceB_X.csv", delimiter=",", skip_header=1)[ - 10:, :1000 - ] + sliceA_X = torch.Tensor( + np.genfromtxt(input_dir / "sliceA_X.csv", delimiter=",", skip_header=1) + ).double()[10:, :1000] + sliceB_X = torch.Tensor( + np.genfromtxt(input_dir / "sliceB_X.csv", delimiter=",", skip_header=1) + ).double()[10:, :1000] glmpca_distance_matrix = glmpca_distance( sliceA_X, sliceB_X, latent_dim=10, filter=True, maxIter=10 diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 8f6326b..a6a8fe7 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -1,3 +1,4 @@ +import torch import numpy as np import pandas as pd import scanpy as sc @@ -18,7 +19,9 @@ def test_stack_slices_pairwise(slices): n_slices = len(slices) pairwise_info = [ - np.genfromtxt(input_dir / f"slices_{i}_{i + 1}_pairwise.csv", delimiter=",") + torch.Tensor( + np.genfromtxt(input_dir / f"slices_{i}_{i + 1}_pairwise.csv", delimiter=",") + ).double() for i in range(1, n_slices) ] @@ -30,6 +33,7 @@ def test_stack_slices_pairwise(slices): assert_frame_equal( pd.DataFrame(slice.obsm["spatial"], columns=["0", "1"]), pd.read_csv(output_dir / f"aligned_spatial_{i}_{i + 1}.csv"), + atol=1e-6, ) expected_thetas = [-0.25086326614894794, 0.5228805289947901, 0.02478065908672744] @@ -54,7 +58,9 @@ def test_stack_slices_center(slices): center_slice = sc.read_h5ad(input_dir / "center_slice.h5ad") pairwise_info = [ - np.genfromtxt(input_dir / f"center_slice{i}_pairwise.csv", delimiter=",") + torch.Tensor( + np.genfromtxt(input_dir / f"center_slice{i}_pairwise.csv", delimiter=",") + ).double() for i in range(1, len(slices) + 1) ] @@ -64,12 +70,14 @@ def test_stack_slices_center(slices): assert_frame_equal( pd.DataFrame(new_center.obsm["spatial"], columns=["0", "1"]), pd.read_csv(output_dir / "aligned_spatial_center.csv"), + atol=1e-6, ) for i, slice in enumerate(new_slices): assert_frame_equal( pd.DataFrame(slice.obsm["spatial"], columns=["0", "1"]), pd.read_csv(output_dir / f"slice{i}_stack_slices_center.csv"), + atol=1e-6, ) expected_thetas = [ @@ -98,14 +106,14 @@ def test_stack_slices_center(slices): def test_generalized_procrustes_analysis(slices): center_slice = sc.read_h5ad(input_dir / "center_slice.h5ad") - pairwise_info = np.genfromtxt( - input_dir / "center_slice1_pairwise.csv", delimiter="," - ) + pairwise_info = torch.Tensor( + np.genfromtxt(input_dir / "center_slice1_pairwise.csv", delimiter=",") + ).double() aligned_center, aligned_slice, theta, translation_x, translation_y = ( generalized_procrustes_analysis( - center_slice.obsm["spatial"], - slices[0].obsm["spatial"], + torch.Tensor(center_slice.obsm["spatial"]).double(), + torch.Tensor(slices[0].obsm["spatial"]).double(), pairwise_info, output_params=True, ) @@ -114,10 +122,12 @@ def test_generalized_procrustes_analysis(slices): assert_frame_equal( pd.DataFrame(aligned_center, columns=["0", "1"]), pd.read_csv(output_dir / "aligned_center.csv"), + atol=1e-6, ) assert_frame_equal( pd.DataFrame(aligned_slice, columns=["0", "1"]), pd.read_csv(output_dir / "aligned_slice.csv"), + atol=1e-6, ) expected_theta = 0.0 expected_translation_x = [16.44623228, 16.73757874] @@ -150,7 +160,9 @@ def test_partial_stack_slices_pairwise(slices): n_slices = len(slices) pairwise_info = [ - np.genfromtxt(input_dir / f"slices_{i}_{i + 1}_pairwise.csv", delimiter=",") + torch.Tensor( + np.genfromtxt(input_dir / f"slices_{i}_{i + 1}_pairwise.csv", delimiter=",") + ).double() for i in range(1, n_slices) ] @@ -160,19 +172,20 @@ def test_partial_stack_slices_pairwise(slices): assert_frame_equal( pd.DataFrame(slice.obsm["spatial"], columns=["0", "1"]), pd.read_csv(output_dir / f"aligned_spatial_{i}_{i + 1}.csv"), + atol=1e-6, ) def test_partial_procrustes_analysis(slices): center_slice = sc.read_h5ad(input_dir / "center_slice.h5ad") - pairwise_info = np.genfromtxt( - input_dir / "center_slice1_pairwise.csv", delimiter="," - ) + pairwise_info = torch.Tensor( + np.genfromtxt(input_dir / "center_slice1_pairwise.csv", delimiter=",") + ).double() aligned_center, aligned_slice = generalized_procrustes_analysis( - center_slice.obsm["spatial"], - slices[0].obsm["spatial"], + torch.Tensor(center_slice.obsm["spatial"]).double(), + torch.Tensor(slices[0].obsm["spatial"]).double(), pairwise_info, is_partial=True, ) @@ -180,8 +193,10 @@ def test_partial_procrustes_analysis(slices): assert_frame_equal( pd.DataFrame(aligned_center, columns=["0", "1"]), pd.read_csv(output_dir / "aligned_center.csv"), + atol=1e-6, ) assert_frame_equal( pd.DataFrame(aligned_slice, columns=["0", "1"]), pd.read_csv(output_dir / "aligned_slice.csv"), + atol=1e-6, )