diff --git a/pyproject.toml b/pyproject.toml index 50fbf51..1455ecb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,9 +17,9 @@ classifiers = [ ] dependencies = [ - "anndata", + "anndata==0.10.9", "scanpy", - "POT", + "POT>=0.9.5", "numpy<2", "scipy", "scikit-learn", diff --git a/requirements.txt b/requirements.txt index 2bfc69e..422062e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -45,7 +45,7 @@ pexpect==4.9.0 pillow==10.4.0 platformdirs==4.3.6 pluggy==1.5.0 -POT==0.9.4 +POT==0.9.5 pre-commit==3.8.0 prompt_toolkit==3.0.48 ptyprocess==0.7.0 diff --git a/src/paste3/paste.py b/src/paste3/paste.py index 9f86292..3331542 100644 --- a/src/paste3/paste.py +++ b/src/paste3/paste.py @@ -786,7 +786,7 @@ def f_gradient(pi): if loss_fun == "kl_loss": armijo = True # there is no closed form line-search with KL - def line_search(f_cost, pi, pi_diff, linearized_matrix, cost_pi, **kwargs): + def line_search(f_cost, pi, pi_diff, linearized_matrix, cost_pi, _, **kwargs): """Solve the linesearch in the fused wasserstein iterations""" if overlap_fraction: nonlocal count @@ -809,14 +809,14 @@ def line_search(f_cost, pi, pi_diff, linearized_matrix, cost_pi, **kwargs): pi_diff, loss_fun=loss_fun, ) - return solve_gromov_linesearch( - pi, - pi_diff, - cost_pi, - a_spatial_dist, - b_spatial_dist, - exp_dissim_matrix=0.0, - alpha=1.0, + return ot.gromov.solve_gromov_linesearch( + G=pi, + deltaG=pi_diff, + cost_G=cost_pi, + C1=a_spatial_dist, + C2=b_spatial_dist, + M=0.0, + reg=2 * 1.0, nx=nx, **kwargs, ) @@ -878,90 +878,6 @@ def lp_solver( return pi, info -def solve_gromov_linesearch( - pi: torch.Tensor, - pi_diff: torch.Tensor, - cost_pi: float, - a_spatial_dist: torch.Tensor, - b_spatial_dist: torch.Tensor, - exp_dissim_matrix: float, - alpha: float, - alpha_min: float | None = None, - alpha_max: float | None = None, - nx: str | None = None, -): - """ - Perform a line search to optimize the transport plan with respect to the Gromov-Wasserstein loss. - - Parameters - ---------- - pi : torch.Tensor - The transport map at a given iteration of the FW - pi_diff : torch.Tensor - Difference between the optimal map found by linearization in the fused wasserstein algorithm and the value at a given iteration - cost_pi : float - Value of the cost at `G` - a_spatial_dist : torch.Tensor - Spot distance matrix in the first slice. - b_spatial_dist : torch.Tensor - Spot distance matrix in the second slice. - exp_dissim_matrix : torch.Tensor - Expression dissimilarity matrix between two slices. - alpha : float - Regularization parameter balancing transcriptional dissimilarity and spatial distance among aligned spots. - Setting \alpha = 0 uses only transcriptional information, while \alpha = 1 uses only spatial coordinates. - alpha_min : float, Optional - Minimum value for alpha - alpha_max : float, Optional - Maximum value for alpha - nx : str, Optional - If let to its default value None, a backend test will be conducted. - - Returns - ------- - minimal_cost : float - The optimal step size of the fused wasserstein - fc : int - Number of function call. (Not used in this case) - cost_pi : float - The final cost after the update of the transport plan. - - .. _references-solve-linesearch: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary RĂ©mi, Tavenard Romain and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. - """ - if nx is None: - pi, pi_diff, a_spatial_dist, b_spatial_dist = ot.utils.list_to_array( - pi, pi_diff, a_spatial_dist, b_spatial_dist - ) - - if isinstance(exp_dissim_matrix, (int | float)): - nx = ot.backend.get_backend(pi, pi_diff, a_spatial_dist, b_spatial_dist) - else: - nx = ot.backend.get_backend( - pi, pi_diff, a_spatial_dist, b_spatial_dist, exp_dissim_matrix - ) - - dot = nx.dot(nx.dot(a_spatial_dist, pi_diff), b_spatial_dist.T) - a = -2 * alpha * nx.sum(dot * pi_diff) - b = nx.sum(exp_dissim_matrix * pi_diff) - 2 * alpha * ( - nx.sum(dot * pi) - + nx.sum(nx.dot(nx.dot(a_spatial_dist, pi), b_spatial_dist.T) * pi_diff) - ) - - minimal_cost = ot.optim.solve_1d_linesearch_quad(a, b) - if alpha_min is not None or alpha_max is not None: - minimal_cost = np.clip(minimal_cost, alpha_min, alpha_max) - - # the new cost is deduced from the line search quadratic function - cost_pi = cost_pi + a * (minimal_cost**2) + b * minimal_cost - - return minimal_cost, 1, cost_pi - - def line_search_partial( alpha: float, exp_dissim_matrix: torch.Tensor, diff --git a/tests/test_paste.py b/tests/test_paste.py index c092a27..1ae71e7 100644 --- a/tests/test_paste.py +++ b/tests/test_paste.py @@ -16,7 +16,6 @@ line_search_partial, my_fused_gromov_wasserstein, pairwise_align, - solve_gromov_linesearch, ) test_dir = Path(__file__).parent @@ -223,14 +222,14 @@ def test_gromov_linesearch(spot_distance_matrix): ).double() costG = 6.0935270338235075 - alpha, fc, cost_G = solve_gromov_linesearch( - G, - deltaG, - costG, - spot_distance_matrix[1], - spot_distance_matrix[2], - exp_dissim_matrix=0.0, - alpha=1.0, + alpha, fc, cost_G = ot.gromov.solve_gromov_linesearch( + G=G, + deltaG=deltaG, + cost_G=costG, + C1=spot_distance_matrix[1], + C2=spot_distance_matrix[2], + M=0.0, + reg=2 * 1.0, nx=nx, ) assert alpha == 1.0