Skip to content

Commit

Permalink
Code modifications to work with ot.TorchBackend (#77)
Browse files Browse the repository at this point in the history
* Code modifications to work with ot.TorchBackend; changed backend for align submodule
  • Loading branch information
vineetbansal authored Oct 23, 2024
1 parent 1fbe278 commit e0c08f6
Show file tree
Hide file tree
Showing 13 changed files with 407 additions and 192 deletions.
6 changes: 3 additions & 3 deletions docs/source/notebooks/paste2_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -249,15 +249,15 @@
"):\n",
" if (\"color\" not in kwargs) and (\"c\" not in kwargs):\n",
" kwargs[\"color\"] = \"k\"\n",
" mx = G.max()\n",
" mx = G.max().item()\n",
" # idx = np.where(G/mx>=thr)\n",
" idx = largest_indices(G, top)\n",
" idx = largest_indices(G.cpu().numpy(), top)\n",
" for i in range(len(idx[0])):\n",
" plt.plot(\n",
" [xs[idx[0][i], 0], xt[idx[1][i], 0]],\n",
" [xs[idx[0][i], 1], xt[idx[1][i], 1]],\n",
" alpha=alpha * (1 - weight_alpha)\n",
" + (weight_alpha * G[idx[0][i], idx[1][i]] / mx),\n",
" + (weight_alpha * G[idx[0][i], idx[1][i]].item() / mx),\n",
" c=\"k\",\n",
" )\n",
"\n",
Expand Down
19 changes: 12 additions & 7 deletions docs/source/notebooks/paste_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"import time\n",
"import pandas as pd\n",
"import numpy as np\n",
"import torch\n",
"import scanpy as sc\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.patches as mpatches\n",
Expand Down Expand Up @@ -251,7 +252,7 @@
},
"outputs": [],
"source": [
"pd.DataFrame(pi12)"
"pd.DataFrame(pi12.cpu().numpy())"
]
},
{
Expand Down Expand Up @@ -454,7 +455,11 @@
"\n",
"b = []\n",
"for i in range(len(slices)):\n",
" b.append(match_spots_using_spatial_heuristic(slices[0].X, slices[i].X))"
" b.append(\n",
" torch.Tensor(\n",
" match_spots_using_spatial_heuristic(slices[0].X, slices[i].X)\n",
" ).double()\n",
" )"
]
},
{
Expand Down Expand Up @@ -796,9 +801,9 @@
"source": [
"start = time.time()\n",
"\n",
"pi12 = pairwise_align(slice1, slice2, backend=ot.backend.TorchBackend(), use_gpu=False)\n",
"pi23 = pairwise_align(slice2, slice3, backend=ot.backend.TorchBackend(), use_gpu=False)\n",
"pi34 = pairwise_align(slice3, slice4, backend=ot.backend.TorchBackend(), use_gpu=False)\n",
"pi12 = pairwise_align(slice1, slice2, backend=ot.backend.TorchBackend(), use_gpu=True)\n",
"pi23 = pairwise_align(slice2, slice3, backend=ot.backend.TorchBackend(), use_gpu=True)\n",
"pi34 = pairwise_align(slice3, slice4, backend=ot.backend.TorchBackend(), use_gpu=True)\n",
"\n",
"print(\"Runtime: \" + str(time.time() - start))"
]
Expand All @@ -814,7 +819,7 @@
},
"outputs": [],
"source": [
"pd.DataFrame(pi12)"
"pd.DataFrame(pi12.cpu().numpy())"
]
},
{
Expand Down Expand Up @@ -864,7 +869,7 @@
" lmbda,\n",
" random_seed=5,\n",
" backend=ot.backend.TorchBackend(),\n",
" use_gpu=False,\n",
" use_gpu=True,\n",
")\n",
"\n",
"print(\"Runtime: \" + str(time.time() - start))"
Expand Down
145 changes: 145 additions & 0 deletions scripts/workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from typing import Optional, List, Tuple
from pathlib import Path
import scanpy as sc
import numpy as np
from anndata import AnnData
import logging
from paste3.helper import match_spots_using_spatial_heuristic
from paste3.visualization import stack_slices_pairwise, stack_slices_center
from paste3.paste import pairwise_align, center_align


logger = logging.getLogger(__name__)


class Slice:
    """Thin wrapper holding a single ``AnnData`` object in ``self.adata``.

    The data is either supplied directly via ``adata`` or loaded from an
    ``.h5ad`` file via ``scanpy.read_h5ad(filepath)``.
    """

    def __init__(
        self, filepath: Optional[Path] = None, adata: Optional[AnnData] = None
    ):
        """Store ``adata`` if given, otherwise read it from ``filepath``.

        Args:
            filepath: Path of an ``.h5ad`` file; only used when ``adata`` is
                ``None``.
            adata: Already-loaded data; takes precedence over ``filepath``.

        Raises:
            ValueError: If neither ``filepath`` nor ``adata`` is provided.
        """
        if adata is not None:
            self.adata = adata
        elif filepath is not None:
            self.adata = sc.read_h5ad(filepath)
        else:
            # Fail early with a clear message instead of handing ``None`` to
            # the file reader.
            raise ValueError("Either filepath or adata must be provided")

    def __str__(self):
        return f"Slice {self.adata}"


class AlignmentDataset:
    """An ordered collection of ``Slice`` objects plus alignment helpers.

    The slices are either passed in directly or loaded from every ``*.h5ad``
    file found in ``data_dir`` (in sorted filename order). Alignment methods
    return a new ``AlignmentDataset`` built from the stacked slices.
    """

    @staticmethod
    def from_csvs(gene_expression_csvs: List[Path], coordinate_csvs: List[Path]):
        """Placeholder for building a dataset from CSV files.

        Not implemented yet: both arguments are ignored and ``None`` is
        returned.
        """

    def __init__(
        self,
        data_dir: Optional[Path] = None,
        slices: Optional[List[Slice]] = None,
        max_slices: Optional[int] = None,
    ):
        """Collect the slices of this dataset.

        Args:
            data_dir: Directory searched for ``*.h5ad`` files; only used when
                ``slices`` is ``None``.
            slices: Pre-built slices; takes precedence over ``data_dir``.
            max_slices: Keep at most this many slices (``None`` keeps all).

        Raises:
            ValueError: If neither ``data_dir`` nor ``slices`` is provided.
        """
        if slices is not None:
            self.slices = slices[:max_slices]
        elif data_dir is not None:
            # Sort so the slice order is deterministic across file systems.
            h5ad_files = sorted(Path(data_dir).glob("*.h5ad"))
            self.slices = [Slice(filepath) for filepath in h5ad_files[:max_slices]]
        else:
            raise ValueError("Either data_dir or slices must be provided")

    def __str__(self):
        return f"Data with {len(self.slices)} slices"

    def __iter__(self):
        return iter(self.slices)

    def __len__(self):
        return len(self.slices)

    @property
    def slices_adata(self) -> List[AnnData]:
        """The ``adata`` attribute of every slice, in dataset order."""
        return [slice_.adata for slice_ in self.slices]

    def align(
        self,
        center_align: bool = False,
        center_slice: Optional[Slice] = None,
        pis: Optional[List[np.ndarray]] = None,
        overlap_fraction: Optional[float] = None,
        max_iters: int = 1000,
    ) -> "AlignmentDataset":
        """Dispatch to ``center_align`` or ``pairwise_align``.

        Args:
            center_align: When true, delegate to ``self.center_align``;
                otherwise delegate to ``self.pairwise_align``.
            center_slice: Forwarded as the reference slice in center mode.
            pis: Optional precomputed mappings forwarded to the chosen method.
            overlap_fraction: Required in pairwise mode; ignored (with a
                warning) in center mode.
            max_iters: Iteration cap forwarded in pairwise mode.

        Returns:
            The dataset returned by the chosen alignment method.
        """
        if center_align:
            if overlap_fraction is not None:
                logger.warning(
                    "Ignoring overlap_fraction argument (unsupported in center_align mode)"
                )
            return self.center_align(center_slice, pis)

        assert overlap_fraction is not None, "overlap_fraction must be specified"
        return self.pairwise_align(
            overlap_fraction=overlap_fraction, pis=pis, max_iters=max_iters
        )

    def find_pis(self, overlap_fraction: float, max_iters: int = 1000):
        """Run ``pairwise_align`` on every pair of consecutive slices.

        ``max_iters`` is passed as both ``numItermax`` and ``maxIter``.

        Returns:
            A list with one result per consecutive pair (``len(self) - 1``
            entries).
        """
        pis = []
        consecutive_pairs = zip(self.slices, self.slices[1:])
        for i, (slice_a, slice_b) in enumerate(consecutive_pairs):
            logger.info("Finding Pi for slices %d and %d", i, i + 1)
            pis.append(
                pairwise_align(
                    slice_a.adata,
                    slice_b.adata,
                    s=overlap_fraction,
                    numItermax=max_iters,
                    maxIter=max_iters,
                    verbose=True,
                )
            )
        return pis

    def pairwise_align(
        self,
        overlap_fraction: float,
        pis: Optional[List[np.ndarray]] = None,
        max_iters: int = 1000,
    ) -> "AlignmentDataset":
        """Stack the slices pairwise and return them as a new dataset.

        Args:
            overlap_fraction: Forwarded to ``find_pis`` when ``pis`` is
                ``None``.
            pis: Precomputed mappings for consecutive slices; computed with
                ``find_pis`` when omitted.
            max_iters: Iteration cap forwarded to ``find_pis``.
        """
        if pis is None:
            pis = self.find_pis(overlap_fraction=overlap_fraction, max_iters=max_iters)
        new_slices = stack_slices_pairwise(self.slices_adata, pis)
        return AlignmentDataset(slices=[Slice(adata=s) for s in new_slices])

    def find_center_slice(
        self,
        reference_slice: Optional[Slice] = None,
        pis: Optional[List[np.ndarray]] = None,
    ) -> Tuple[Slice, List[np.ndarray]]:
        """Compute a center slice with the library-level ``center_align``.

        Args:
            reference_slice: Slice whose ``adata`` is passed as the first
                argument; defaults to the first slice of the dataset.
            pis: Passed through as ``pis_init``.

        Returns:
            The center slice wrapped in a ``Slice`` and the mappings returned
            by ``center_align``.
        """
        if reference_slice is None:
            reference_slice = self.slices[0]
        center_slice, pis = center_align(
            reference_slice.adata, self.slices_adata, pis_init=pis
        )
        return Slice(adata=center_slice), pis

    def find_pis_init(self) -> List[np.ndarray]:
        """Match the first slice against every slice (itself included).

        Calls ``match_spots_using_spatial_heuristic`` on the ``adata.X`` of
        the first slice and of each slice in turn.
        """
        reference_slice = self.slices[0]
        return [
            match_spots_using_spatial_heuristic(reference_slice.adata.X, slice_.adata.X)
            for slice_ in self.slices
        ]

    def center_align(
        self,
        reference_slice: Optional[Slice] = None,
        pis: Optional[List[np.ndarray]] = None,
    ) -> "AlignmentDataset":
        """Stack all slices around a center slice and return a new dataset.

        Args:
            reference_slice: Slice used as the center. When omitted, one is
                computed with ``find_center_slice`` (using ``pis`` as the
                initial mappings) together with its mappings.
            pis: Mappings to use with ``reference_slice``. When a reference
                slice is given without mappings, they are derived with
                ``find_pis_init``.
        """
        if reference_slice is None:
            reference_slice, pis = self.find_center_slice(pis=pis)
        elif pis is None:
            # Only fall back to the heuristic when the caller supplied no
            # mappings; explicitly provided ``pis`` must not be overwritten.
            pis = self.find_pis_init()

        _, new_slices = stack_slices_center(
            center_slice=reference_slice.adata, slices=self.slices_adata, pis=pis
        )
        return AlignmentDataset(slices=[Slice(adata=s) for s in new_slices])


if __name__ == "__main__":
    # Example run: load up to three slices from ./data, align them pairwise
    # with a very small iteration budget, and print a summary of the result.
    raw_dataset = AlignmentDataset(data_dir="data/", max_slices=3)
    alignment_options = {
        "center_align": False,
        "overlap_fraction": 0.7,
        "max_iters": 2,
    }
    result = raw_dataset.align(**alignment_options)
    print(result)
14 changes: 9 additions & 5 deletions src/paste3/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def align(
max_iter=10,
norm=False,
numItermax=200,
use_gpu=False,
use_gpu=True,
return_obj=False,
optimizeTheta=True,
eps=1e-4,
Expand Down Expand Up @@ -99,7 +99,7 @@ def align(
b_distribution=slices[i + 1].obsm["weights"],
norm=norm,
numItermax=numItermax,
backend=ot.backend.NumpyBackend(),
backend=ot.backend.TorchBackend(),
use_gpu=use_gpu,
return_obj=return_obj,
maxIter=max_iter,
Expand All @@ -110,7 +110,9 @@ def align(
)
pis.append(pi)
pd.DataFrame(
pi, index=slices[i].obs.index, columns=slices[i + 1].obs.index
pi.cpu().numpy(),
index=slices[i].obs.index,
columns=slices[i + 1].obs.index,
).to_csv(output_directory / f"slice_{i+1}_{i+2}_pairwise.csv")

if coordinates:
Expand All @@ -135,14 +137,16 @@ def align(
random_seed=seed,
pis_init=pis_init,
distributions=[slice.obsm["weights"] for slice in slices],
backend=ot.backend.NumpyBackend(),
backend=ot.backend.TorchBackend(),
use_gpu=use_gpu,
)

center_slice.write(output_directory / "center_slice.h5ad")
for i in range(len(pis) - 1):
pd.DataFrame(
pis[i], index=center_slice.obs.index, columns=slices[i].obs.index
pis[i].cpu().numpy(),
index=center_slice.obs.index,
columns=slices[i].obs.index,
).to_csv(output_directory / f"slice_{i}_{i+1}_pairwise.csv")

if coordinates:
Expand Down
51 changes: 28 additions & 23 deletions src/paste3/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List
from anndata import AnnData
import numpy as np
import torch
import scipy
import ot

Expand All @@ -22,11 +23,11 @@ def kl_divergence(X, Y):

X = X / X.sum(axis=1, keepdims=True)
Y = Y / Y.sum(axis=1, keepdims=True)
log_X = np.log(X)
log_Y = np.log(Y)
X_log_X = np.matrix([np.dot(X[i], log_X[i].T) for i in range(X.shape[0])])
D = X_log_X.T - np.dot(X, log_Y.T)
return np.asarray(D)
log_X = X.log()
log_Y = Y.log()
X_log_X = torch.sum(X * log_X, axis=1)[torch.newaxis, :]
D = X_log_X.T - torch.matmul(X, log_Y.T)
return D


def generalized_kl_divergence(X, Y):
Expand All @@ -40,14 +41,14 @@ def generalized_kl_divergence(X, Y):
"""
assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."

log_X = np.log(X)
log_Y = np.log(Y)
X_log_X = np.matrix([np.dot(X[i], log_X[i].T) for i in range(X.shape[0])])
D = X_log_X.T - np.dot(X, log_Y.T)
sum_X = np.sum(X, axis=1)
sum_Y = np.sum(Y, axis=1)
log_X = X.log()
log_Y = Y.log()
X_log_X = torch.sum(X * log_X, axis=1)[torch.newaxis, :]
D = X_log_X.T - torch.matmul(X, log_Y.T)
sum_X = torch.sum(X, axis=1)
sum_Y = torch.sum(Y, axis=1)
D = (D.T - sum_X).T + sum_Y.T
return np.asarray(D)
return D


def glmpca_distance(
Expand All @@ -72,15 +73,15 @@ def glmpca_distance(
"""
assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."

joint_matrix = np.vstack((X, Y))
joint_matrix = torch.vstack((X, Y))
if filter:
gene_umi_counts = np.sum(joint_matrix, axis=0)
gene_umi_counts = torch.sum(joint_matrix, axis=0).cpu().numpy()
top_indices = np.sort((-gene_umi_counts).argsort(kind="stable")[:2000])
joint_matrix = joint_matrix[:, top_indices]

print("Starting GLM-PCA...")
res = glmpca(
joint_matrix.T,
joint_matrix.T.cpu().numpy(), # TODO: Use Tensors
latent_dim,
penalty=1,
verbose=verbose,
Expand Down Expand Up @@ -116,13 +117,13 @@ def high_umi_gene_distance(X, Y, n):
"""
assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."

joint_matrix = np.vstack((X, Y))
gene_umi_counts = np.sum(joint_matrix, axis=0)
joint_matrix = torch.vstack((X, Y))
gene_umi_counts = torch.sum(joint_matrix, axis=0).cpu().numpy()
top_indices = np.sort((-gene_umi_counts).argsort(kind="stable")[:n])
X = X[:, top_indices]
Y = Y[:, top_indices]
X += np.tile(0.01 * (np.sum(X, axis=1) / X.shape[1]), (X.shape[1], 1)).T
Y += np.tile(0.01 * (np.sum(Y, axis=1) / Y.shape[1]), (Y.shape[1], 1)).T
X += torch.tile(0.01 * (torch.sum(X, axis=1) / X.shape[1]), (X.shape[1], 1)).T
Y += torch.tile(0.01 * (torch.sum(Y, axis=1) / Y.shape[1]), (Y.shape[1], 1)).T
return kl_divergence(X, Y)


Expand Down Expand Up @@ -150,7 +151,8 @@ def norm_and_center_coordinates(X):

## Convert a sparse matrix into a dense matrix
def to_dense_array(X):
return np.array(X.todense()) if isinstance(X, scipy.sparse.csr.spmatrix) else X
np_array = np.array(X.todense()) if isinstance(X, scipy.sparse.csr.spmatrix) else X
return torch.Tensor(np_array).double()


def extract_data_matrix(adata, rep=None):
Expand Down Expand Up @@ -190,6 +192,7 @@ def match_spots_using_spatial_heuristic(X, Y, use_ot: bool = True) -> np.ndarray
Returns:
Mapping of spots using a spatial heuristic.
"""
# X, Y = X.todense(), Y.todense()
n1, n2 = len(X), len(Y)
X, Y = norm_and_center_coordinates(X), norm_and_center_coordinates(Y)
dist = scipy.spatial.distance_matrix(X, Y)
Expand Down Expand Up @@ -238,7 +241,7 @@ def kl_divergence_backend(X, Y):
def dissimilarity_metric(which, sliceA, sliceB, A, B, **kwargs):
match which:
case "euc" | "euclidean":
return scipy.spatial.distance.cdist(A, B)
return torch.cdist(A, B)
case "gkl":
s_A = A + 0.01
s_B = B + 0.01
Expand All @@ -254,8 +257,10 @@ def dissimilarity_metric(which, sliceA, sliceB, A, B, **kwargs):
case "selection_kl":
return high_umi_gene_distance(A, B, 2000)
case "pca":
return pca_distance(sliceA, sliceB, 2000, 20)
# TODO: Modify this function to work with Tensors
return torch.Tensor(pca_distance(sliceA, sliceB, 2000, 20)).double()
case "glmpca":
return glmpca_distance(A, B, **kwargs)
# TODO: Modify this function to work with Tensors
return torch.Tensor(glmpca_distance(A, B, **kwargs)).double()
case _:
raise RuntimeError(f"Error: Invalid dissimilarity metric {which}")
Loading

0 comments on commit e0c08f6

Please sign in to comment.