Skip to content

Commit

Permalink
shortened the line search function and updated the test function acco…
Browse files Browse the repository at this point in the history
…rdingly
  • Loading branch information
anushka255 committed Nov 15, 2024
1 parent 5eeb82d commit 6e99670
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 38 deletions.
52 changes: 17 additions & 35 deletions src/paste3/paste.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import logging
from collections.abc import Callable
from typing import Any

import numpy as np
Expand Down Expand Up @@ -654,26 +655,24 @@ def my_fused_gromov_wasserstein(
]
)

def f_loss(pi):
"""Compute the Gromov-Wasserstein loss for a given transport plan."""
combined_spatial_cost, a_gradient, b_gradient = ot.gromov.init_matrix(
def transform_matrix(pi):
p, q = torch.sum(pi, axis=1), torch.sum(pi, axis=0)
return ot.gromov.init_matrix(
a_spatial_dist,
b_spatial_dist,
torch.sum(pi, axis=1).reshape(-1, 1).to(a_spatial_dist.dtype),
torch.sum(pi, axis=0).reshape(1, -1).to(b_spatial_dist.dtype),
p,
q,
loss_fun,
)

def f_loss(pi):
"""Compute the Gromov-Wasserstein loss for a given transport plan."""
combined_spatial_cost, a_gradient, b_gradient = transform_matrix(pi)
return ot.gromov.gwloss(combined_spatial_cost, a_gradient, b_gradient, pi)

def f_gradient(pi):
"""Compute the gradient of the Gromov-Wasserstein loss for a given transport plan."""
combined_spatial_cost, a_gradient, b_gradient = ot.gromov.init_matrix(
a_spatial_dist,
b_spatial_dist,
torch.sum(pi, axis=1).reshape(-1, 1),
torch.sum(pi, axis=0).reshape(1, -1),
loss_fun,
)
combined_spatial_cost, a_gradient, b_gradient = transform_matrix(pi)
return ot.gromov.gwggrad(combined_spatial_cost, a_gradient, b_gradient, pi)

def line_search(f_cost, pi, pi_diff, linearized_matrix, cost_pi, _, **kwargs):
Expand All @@ -697,7 +696,8 @@ def line_search(f_cost, pi, pi_diff, linearized_matrix, cost_pi, _, **kwargs):
a_spatial_dist,
b_spatial_dist,
pi_diff,
loss_fun=loss_fun,
f_cost,
f_gradient,
)
return ot.gromov.solve_gromov_linesearch(
G=pi,
Expand Down Expand Up @@ -767,7 +767,8 @@ def line_search_partial(
a_spatial_dist: torch.Tensor,
b_spatial_dist: torch.Tensor,
pi_diff: torch.Tensor,
loss_fun: str = "square_loss",
f_cost: Callable,
f_gradient: Callable,
):
"""
Solve the linesearch in the fused wasserstein iterations for partially overlapping slices
Expand Down Expand Up @@ -799,31 +800,12 @@ def line_search_partial(
cost_G : float
The final cost after the update of the transport plan.
"""
combined_spatial_cost, a_gradient, b_gradient = ot.gromov.init_matrix(
a_spatial_dist,
b_spatial_dist,
torch.sum(pi_diff, axis=1).reshape(-1, 1),
torch.sum(pi_diff, axis=0).reshape(1, -1),
loss_fun,
)

dot = torch.matmul(torch.matmul(a_spatial_dist, pi_diff), b_spatial_dist.T)
a = alpha * torch.sum(dot * pi_diff)
b = (1 - alpha) * torch.sum(exp_dissim_matrix * pi_diff) + 2 * alpha * torch.sum(
ot.gromov.gwggrad(combined_spatial_cost, a_gradient, b_gradient, pi_diff)
* 0.5
* pi
f_gradient(pi_diff) * 0.5 * pi
)
minimal_cost = ot.optim.solve_1d_linesearch_quad(a, b)
pi = pi + minimal_cost * pi_diff
combined_spatial_cost, a_gradient, b_gradient = ot.gromov.init_matrix(
a_spatial_dist,
b_spatial_dist,
torch.sum(pi, axis=1).reshape(-1, 1),
torch.sum(pi, axis=0).reshape(1, -1),
loss_fun,
)
cost_G = (1 - alpha) * torch.sum(exp_dissim_matrix * pi) + alpha * ot.gromov.gwloss(
combined_spatial_cost, a_gradient, b_gradient, pi
)
cost_G = f_cost(pi + minimal_cost * pi_diff)
return minimal_cost, a, cost_G
22 changes: 19 additions & 3 deletions tests/test_paste.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,22 +235,38 @@ def test_gromov_linesearch(spot_distance_matrix):


def test_line_search_partial(spot_distance_matrix):
d1, d2 = spot_distance_matrix[1], spot_distance_matrix[2]
G = 1.509115054931788e-05 * torch.ones((251, 264)).double()
deltaG = torch.Tensor(
np.genfromtxt(input_dir / "deltaG.csv", delimiter=",")
).double()
M = torch.Tensor(
np.genfromtxt(input_dir / "euc_dissimilarity.csv", delimiter=",")
).double()
alpha = 0.1

alpha, a, cost_G = line_search_partial(
alpha=0.1,
def f_cost(pi):
p, q = torch.sum(pi, axis=1), torch.sum(pi, axis=0)
constC, hC1, hC2 = ot.gromov.init_matrix(d1, d2, p, q)
return (1 - alpha) * torch.sum(M * pi) + alpha * ot.gromov.gwloss(
constC, hC1, hC2, pi
)

def f_gradient(pi):
p, q = torch.sum(pi, axis=1), torch.sum(pi, axis=0)
constC, hC1, hC2 = ot.gromov.init_matrix(d1, d2, p, q)
return ot.gromov.gwggrad(constC, hC1, hC2, pi)

minimal_cost, a, cost_G = line_search_partial(
alpha=alpha,
exp_dissim_matrix=M,
pi=G,
a_spatial_dist=spot_distance_matrix[1],
b_spatial_dist=spot_distance_matrix[2],
pi_diff=deltaG,
f_cost=f_cost,
f_gradient=f_gradient,
)
assert alpha == 1.0
assert minimal_cost == 1.0
assert pytest.approx(a) == 0.4858849047237918
assert pytest.approx(cost_G) == 102.6333512778727

0 comments on commit 6e99670

Please sign in to comment.