Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated paste to be in line with POT; Using line search function from ot library #88

Merged
merged 6 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ classifiers = [
]

dependencies = [
"anndata",
"anndata==0.10.9",
"scanpy",
"POT",
"POT>=0.9.5",
"numpy<2",
"scipy",
"scikit-learn",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pexpect==4.9.0
pillow==10.4.0
platformdirs==4.3.6
pluggy==1.5.0
POT==0.9.4
POT==0.9.5
pre-commit==3.8.0
prompt_toolkit==3.0.48
ptyprocess==0.7.0
Expand Down
102 changes: 9 additions & 93 deletions src/paste3/paste.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ def f_gradient(pi):
if loss_fun == "kl_loss":
armijo = True # there is no closed form line-search with KL

def line_search(f_cost, pi, pi_diff, linearized_matrix, cost_pi, **kwargs):
def line_search(f_cost, pi, pi_diff, linearized_matrix, cost_pi, _, **kwargs):
"""Solve the linesearch in the fused wasserstein iterations"""
if overlap_fraction:
nonlocal count
Expand All @@ -809,14 +809,14 @@ def line_search(f_cost, pi, pi_diff, linearized_matrix, cost_pi, **kwargs):
pi_diff,
loss_fun=loss_fun,
)
return solve_gromov_linesearch(
pi,
pi_diff,
cost_pi,
a_spatial_dist,
b_spatial_dist,
exp_dissim_matrix=0.0,
alpha=1.0,
return ot.gromov.solve_gromov_linesearch(
G=pi,
deltaG=pi_diff,
cost_G=cost_pi,
C1=a_spatial_dist,
C2=b_spatial_dist,
M=0.0,
reg=2 * 1.0,
nx=nx,
**kwargs,
)
Expand Down Expand Up @@ -878,90 +878,6 @@ def lp_solver(
return pi, info


def solve_gromov_linesearch(
    pi: torch.Tensor,
    pi_diff: torch.Tensor,
    cost_pi: float,
    a_spatial_dist: torch.Tensor,
    b_spatial_dist: torch.Tensor,
    exp_dissim_matrix: float | torch.Tensor,
    alpha: float,
    alpha_min: float | None = None,
    alpha_max: float | None = None,
    nx: "ot.backend.Backend | None" = None,
):
    r"""
    Solve the exact line search of a (fused) Gromov-Wasserstein Frank-Wolfe step.

    Along the direction ``pi + step * pi_diff`` the objective is modelled as
    the quadratic ``a * step**2 + b * step + cost_pi``; the step minimising
    that quadratic is returned together with the updated cost.

    Parameters
    ----------
    pi : torch.Tensor
        The transport map at the current Frank-Wolfe iteration.
    pi_diff : torch.Tensor
        Difference between the map found by linearization and ``pi``
        (the search direction).
    cost_pi : float
        Value of the cost at ``pi``.
    a_spatial_dist : torch.Tensor
        Spot distance matrix in the first slice.
    b_spatial_dist : torch.Tensor
        Spot distance matrix in the second slice.
    exp_dissim_matrix : float or torch.Tensor
        Expression dissimilarity matrix between two slices. A plain number
        (e.g. ``0.0``) is accepted and is then ignored for backend detection.
    alpha : float
        Weight multiplying the quadratic (spatial distance) term of the
        objective in the coefficients ``a`` and ``b``.
    alpha_min : float, Optional
        Lower bound applied to the step size.
    alpha_max : float, Optional
        Upper bound applied to the step size.
    nx : Backend, Optional
        POT backend used for the array operations. If left to its default
        value None, the backend is inferred from the array arguments.

    Returns
    -------
    step : float
        The optimal step size along ``pi_diff`` (clipped to
        ``[alpha_min, alpha_max]`` when either bound is given).
    fc : int
        Number of function calls. Always 1; kept for interface compatibility
        with other line-search routines.
    cost_pi : float
        The cost after the update, deduced from the quadratic model.

    .. _references-solve-linesearch:
    References
    ----------
    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
        "Optimal Transport for structured data with application on graphs"
        International Conference on Machine Learning (ICML). 2019.
    """
    if nx is None:
        pi, pi_diff, a_spatial_dist, b_spatial_dist = ot.utils.list_to_array(
            pi, pi_diff, a_spatial_dist, b_spatial_dist
        )

        # A scalar dissimilarity carries no backend information, so it is
        # only included in the backend detection when it is an array.
        backend_inputs = [pi, pi_diff, a_spatial_dist, b_spatial_dist]
        if not isinstance(exp_dissim_matrix, (int, float)):
            backend_inputs.append(exp_dissim_matrix)
        nx = ot.backend.get_backend(*backend_inputs)

    # A @ X @ B.T for the search direction and for the current plan.
    dot_diff = nx.dot(nx.dot(a_spatial_dist, pi_diff), b_spatial_dist.T)
    dot_pi = nx.dot(nx.dot(a_spatial_dist, pi), b_spatial_dist.T)

    # Coefficients of the quadratic a * step**2 + b * step along pi_diff.
    a = -2 * alpha * nx.sum(dot_diff * pi_diff)
    b = nx.sum(exp_dissim_matrix * pi_diff) - 2 * alpha * (
        nx.sum(dot_diff * pi) + nx.sum(dot_pi * pi_diff)
    )

    step = ot.optim.solve_1d_linesearch_quad(a, b)
    if alpha_min is not None or alpha_max is not None:
        step = np.clip(step, alpha_min, alpha_max)

    # the new cost is deduced from the line search quadratic function
    cost_pi = cost_pi + a * (step**2) + b * step

    return step, 1, cost_pi


def line_search_partial(
alpha: float,
exp_dissim_matrix: torch.Tensor,
Expand Down
17 changes: 8 additions & 9 deletions tests/test_paste.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
line_search_partial,
my_fused_gromov_wasserstein,
pairwise_align,
solve_gromov_linesearch,
)

test_dir = Path(__file__).parent
Expand Down Expand Up @@ -223,14 +222,14 @@ def test_gromov_linesearch(spot_distance_matrix):
).double()
costG = 6.0935270338235075

alpha, fc, cost_G = solve_gromov_linesearch(
G,
deltaG,
costG,
spot_distance_matrix[1],
spot_distance_matrix[2],
exp_dissim_matrix=0.0,
alpha=1.0,
alpha, fc, cost_G = ot.gromov.solve_gromov_linesearch(
G=G,
deltaG=deltaG,
cost_G=costG,
C1=spot_distance_matrix[1],
C2=spot_distance_matrix[2],
M=0.0,
reg=2 * 1.0,
nx=nx,
)
assert alpha == 1.0
Expand Down
Loading