Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix low-rank convergence criterion #547

Merged
merged 8 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions src/ott/solvers/linear/sinkhorn_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def compute_error( # noqa: D102
err_r = mu.gen_js(self.r, previous_state.r, c=1.0)
err_g = mu.gen_js(self.g, previous_state.g, c=1.0)

return ((1.0 / self.gamma) ** 2) * (err_q + err_r + err_g)
# don't scale by (1 / gamma ** 2); https://github.com/ott-jax/ott/pull/547
return err_q + err_r + err_g

def reg_ot_cost( # noqa: D102
self,
Expand Down Expand Up @@ -175,6 +176,7 @@ class LRSinkhornOutput(NamedTuple):
ot_prob: linear_problem.LinearProblem
epsilon: float
inner_iterations: int
converged: bool
# TODO(michalk8): Optional is an artifact of the current impl., refactor
reg_ot_cost: Optional[float] = None

Expand Down Expand Up @@ -221,12 +223,6 @@ def b(self) -> jnp.ndarray: # noqa: D102
def n_iters(self) -> int: # noqa: D102
return jnp.sum(self.errors != -1) * self.inner_iterations

@property
def converged(self) -> bool: # noqa: D102
return jnp.logical_and(
jnp.any(self.costs == -1), jnp.all(jnp.isfinite(self.costs))
)

@property
def matrix(self) -> jnp.ndarray:
"""Transport matrix if it can be instantiated."""
Expand Down Expand Up @@ -687,7 +683,10 @@ def one_iteration(
lambda: state.reg_ot_cost(ot_prob, epsilon=self.epsilon),
lambda: jnp.inf
)
error = state.compute_error(previous_state)
error = jax.lax.cond(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch!

iteration >= self.min_iterations,
lambda: state.compute_error(previous_state), lambda: jnp.inf
)
crossed_threshold = jnp.logical_or(
state.crossed_threshold,
jnp.logical_and(
Expand Down Expand Up @@ -761,6 +760,8 @@ def output_from_state(
Returns:
A LRSinkhornOutput.
"""
it = jnp.sum(state.errors != -1.0) * self.inner_iterations
converged = self._converged(state, it)
return LRSinkhornOutput(
q=state.q,
r=state.r,
Expand All @@ -770,6 +771,7 @@ def output_from_state(
errors=state.errors,
epsilon=self.epsilon,
inner_iterations=self.inner_iterations,
converged=converged,
)

def _converged(self, state: LRSinkhornState, iteration: int) -> bool:
Expand Down Expand Up @@ -800,11 +802,13 @@ def conv_not_crossed(prev_err: float, curr_err: float) -> bool:
)

def _diverged(self, state: LRSinkhornState, iteration: int) -> bool:
it = iteration // self.inner_iterations
return jnp.logical_and(
jnp.logical_not(jnp.isfinite(state.errors[it - 1])),
jnp.logical_not(jnp.isfinite(state.costs[it - 1]))
it = iteration // self.inner_iterations - 1
is_not_finite = jnp.logical_and(
jnp.logical_not(jnp.isfinite(state.errors[it])),
jnp.logical_not(jnp.isfinite(state.costs[it]))
)
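# `one_iteration` stores `jnp.inf` as the error until `min_iterations` is reached, so non-finite values before that point are placeholders, not divergence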
return jnp.logical_and(it >= self.min_iterations, is_not_finite)


def run(
Expand Down
26 changes: 15 additions & 11 deletions src/ott/solvers/quadratic/gromov_wasserstein_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def ent(x: jnp.ndarray) -> float:
errors=None,
epsilon=None,
inner_iterations=None,
converged=False,
)

cost = out.primal_cost - epsilon * (ent(q) + ent(r) + ent(g))
Expand All @@ -148,6 +149,7 @@ class LRGWOutput(NamedTuple):
ot_prob: quadratic_problem.QuadraticProblem
epsilon: float
inner_iterations: int
converged: bool
reg_gw_cost: Optional[float] = None

def set(self, **kwargs: Any) -> "LRGWOutput":
Expand Down Expand Up @@ -194,12 +196,6 @@ def b(self) -> jnp.ndarray: # noqa: D102
def n_iters(self) -> int: # noqa: D102
return jnp.sum(self.errors != -1) * self.inner_iterations

@property
def converged(self) -> bool: # noqa: D102
return jnp.logical_and(
jnp.any(self.costs == -1), jnp.all(jnp.isfinite(self.costs))
)

@property
def matrix(self) -> jnp.ndarray:
"""Transport matrix if it can be instantiated."""
Expand Down Expand Up @@ -718,7 +714,10 @@ def one_iteration(
lambda: state.reg_gw_cost(ot_prob, epsilon=self.epsilon),
lambda: jnp.inf
)
error = state.compute_error(previous_state)
error = jax.lax.cond(
iteration >= self.min_iterations,
lambda: state.compute_error(previous_state), lambda: jnp.inf
)
crossed_threshold = jnp.logical_or(
state.crossed_threshold,
jnp.logical_and(
Expand Down Expand Up @@ -794,6 +793,8 @@ def output_from_state(
Returns:
A LRGWOutput.
"""
it = jnp.sum(state.errors != -1.0) * self.inner_iterations
converged = self._converged(state, it)
return LRGWOutput(
q=state.q,
r=state.r,
Expand All @@ -803,6 +804,7 @@ def output_from_state(
errors=state.errors,
epsilon=self.epsilon,
inner_iterations=self.inner_iterations,
converged=converged,
)

def _converged(self, state: LRGWState, iteration: int) -> bool:
Expand Down Expand Up @@ -833,11 +835,13 @@ def conv_not_crossed(prev_err: float, curr_err: float) -> bool:
)

def _diverged(self, state: LRGWState, iteration: int) -> bool:
it = iteration // self.inner_iterations
return jnp.logical_and(
jnp.logical_not(jnp.isfinite(state.errors[it - 1])),
jnp.logical_not(jnp.isfinite(state.costs[it - 1]))
it = iteration // self.inner_iterations - 1
is_not_finite = jnp.logical_and(
jnp.logical_not(jnp.isfinite(state.errors[it])),
jnp.logical_not(jnp.isfinite(state.costs[it]))
)
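# `one_iteration` stores `jnp.inf` as the error until `min_iterations` is reached, so non-finite values before that point are placeholders, not divergence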
return jnp.logical_and(it >= self.min_iterations, is_not_finite)


def run(
Expand Down
31 changes: 1 addition & 30 deletions tests/initializers/linear/sinkhorn_lr_init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import jax.numpy as jnp
import numpy as np

from ott.geometry import geometry, pointcloud
from ott.geometry import pointcloud
from ott.initializers.linear import initializers_lr
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr
Expand Down Expand Up @@ -82,35 +82,6 @@ def test_generalized_k_means_has_correct_rank(
assert jnp.linalg.matrix_rank(q) == rank
assert jnp.linalg.matrix_rank(r) == rank

def test_generalized_k_means_matches_k_means(self, rng: jax.Array):
n, d, rank = 27, 7, 5
eps = 1e-1
rng1, rng2 = jax.random.split(rng, 2)
x = jax.random.normal(rng1, (n, d))
y = jax.random.normal(rng1, (n, d))

pc = pointcloud.PointCloud(x, y, epsilon=eps)
geom = geometry.Geometry(cost_matrix=pc.cost_matrix, epsilon=eps)
pc_problem = linear_problem.LinearProblem(pc)
geom_problem = linear_problem.LinearProblem(geom)

solver = sinkhorn_lr.LRSinkhorn(
rank=rank, initializer="k-means", max_iterations=5000
)
pc_out = solver(pc_problem)

solver = sinkhorn_lr.LRSinkhorn(
rank=rank, initializer="generalized-k-means", max_iterations=5000
)
geom_out = solver(geom_problem)

with pytest.raises(AssertionError):
np.testing.assert_allclose(pc_out.costs, geom_out.costs)

np.testing.assert_allclose(
pc_out.reg_ot_cost, geom_out.reg_ot_cost, atol=0.5, rtol=0.02
)

@pytest.mark.parametrize("epsilon", [0.0, 1e-1])
def test_better_initialization_helps(self, rng: jax.Array, epsilon: float):
n, d, rank = 81, 13, 3
Expand Down
14 changes: 4 additions & 10 deletions tests/solvers/linear/sinkhorn_lr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,15 +200,9 @@ def progress_fn(
lin_prob
)

# check that the function is called on the 10th iteration (iter #9), the
# 20th iteration (iter #19).
assert traced_values["iters"] == [9, 19]

# check that error decreases
np.testing.assert_array_equal(np.diff(traced_values["error"]) < 0, True)

# check that max iterations is provided each time: [30, 30]
assert traced_values["total"] == [num_iterations] * 2
assert traced_values["iters"] == [9, 19, 29, 39]
assert traced_values["total"] == [num_iterations
] * len(traced_values["total"])

@pytest.mark.fast.with_args(eps=[0.0, 1e-1])
def test_lse_matches_kernel_mode(self, eps: float):
Expand Down Expand Up @@ -318,7 +312,7 @@ def test_lr_unbalanced_ti(

assert out.converged
assert out_ti.converged
np.testing.assert_allclose(out.errors, out_ti.errors, rtol=1e-4, atol=1e-4)
np.testing.assert_allclose(out.errors, out_ti.errors, rtol=5e-4, atol=5e-4)
np.testing.assert_allclose(
out.reg_ot_cost, out_ti.reg_ot_cost, rtol=1e-2, atol=1e-2
)
Expand Down
Loading