diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index 69e3eeb43..e4bce2e2b 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -57,7 +57,8 @@ def compute_error( # noqa: D102 err_r = mu.gen_js(self.r, previous_state.r, c=1.0) err_g = mu.gen_js(self.g, previous_state.g, c=1.0) - return ((1.0 / self.gamma) ** 2) * (err_q + err_r + err_g) + # don't scale by (1 / gamma ** 2); https://github.com/ott-jax/ott/pull/547 + return err_q + err_r + err_g def reg_ot_cost( # noqa: D102 self, @@ -175,6 +176,7 @@ class LRSinkhornOutput(NamedTuple): ot_prob: linear_problem.LinearProblem epsilon: float inner_iterations: int + converged: bool # TODO(michalk8): Optional is an artifact of the current impl., refactor reg_ot_cost: Optional[float] = None @@ -221,12 +223,6 @@ def b(self) -> jnp.ndarray: # noqa: D102 def n_iters(self) -> int: # noqa: D102 return jnp.sum(self.errors != -1) * self.inner_iterations - @property - def converged(self) -> bool: # noqa: D102 - return jnp.logical_and( - jnp.any(self.costs == -1), jnp.all(jnp.isfinite(self.costs)) - ) - @property def matrix(self) -> jnp.ndarray: """Transport matrix if it can be instantiated.""" @@ -687,7 +683,10 @@ def one_iteration( lambda: state.reg_ot_cost(ot_prob, epsilon=self.epsilon), lambda: jnp.inf ) - error = state.compute_error(previous_state) + error = jax.lax.cond( + iteration >= self.min_iterations, + lambda: state.compute_error(previous_state), lambda: jnp.inf + ) crossed_threshold = jnp.logical_or( state.crossed_threshold, jnp.logical_and( @@ -761,6 +760,8 @@ def output_from_state( Returns: A LRSinkhornOutput. """ + it = jnp.sum(state.errors != -1.0) * self.inner_iterations + converged = self._converged(state, it) return LRSinkhornOutput( q=state.q, r=state.r, @@ -770,6 +771,7 @@ def output_from_state( errors=state.errors, epsilon=self.epsilon, inner_iterations=self.inner_iterations, + converged=converged, ) def _converged(self, state: LRSinkhornState, iteration: int) -> bool: @@ -800,11 +802,13 @@ def conv_not_crossed(prev_err: float, curr_err: float) -> bool: ) def _diverged(self, state: LRSinkhornState, iteration: int) -> bool: - it = iteration // self.inner_iterations - return jnp.logical_and( - jnp.logical_not(jnp.isfinite(state.errors[it - 1])), - jnp.logical_not(jnp.isfinite(state.costs[it - 1])) + it = iteration // self.inner_iterations - 1 + is_not_finite = jnp.logical_and( + jnp.logical_not(jnp.isfinite(state.errors[it])), + jnp.logical_not(jnp.isfinite(state.costs[it])) ) + # `jnp.inf` is used if `it < self.min_iterations` + return jnp.logical_and(it >= self.min_iterations, is_not_finite) def run( diff --git a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py index 15abdbf20..0237cb2af 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein_lr.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein_lr.py @@ -123,6 +123,7 @@ def ent(x: jnp.ndarray) -> float: errors=None, epsilon=None, inner_iterations=None, + converged=False, ) cost = out.primal_cost - epsilon * (ent(q) + ent(r) + ent(g)) @@ -148,6 +149,7 @@ class LRGWOutput(NamedTuple): ot_prob: quadratic_problem.QuadraticProblem epsilon: float inner_iterations: int + converged: bool reg_gw_cost: Optional[float] = None def set(self, **kwargs: Any) -> "LRGWOutput": @@ -194,12 +196,6 @@ def b(self) -> jnp.ndarray: # noqa: D102 def n_iters(self) -> int: # noqa: D102 return jnp.sum(self.errors != -1) * self.inner_iterations - @property - def converged(self) -> bool: # noqa: D102 - return jnp.logical_and( - jnp.any(self.costs == -1), jnp.all(jnp.isfinite(self.costs)) - ) - @property def matrix(self) -> jnp.ndarray: """Transport matrix if it can be instantiated.""" @@ -718,7 +714,10 @@ def one_iteration( lambda: state.reg_gw_cost(ot_prob, epsilon=self.epsilon), lambda: jnp.inf ) - error = state.compute_error(previous_state) + error = jax.lax.cond( + iteration >= self.min_iterations, + lambda: state.compute_error(previous_state), lambda: jnp.inf + ) crossed_threshold = jnp.logical_or( state.crossed_threshold, jnp.logical_and( @@ -794,6 +793,8 @@ def output_from_state( Returns: A LRGWOutput. """ + it = jnp.sum(state.errors != -1.0) * self.inner_iterations + converged = self._converged(state, it) return LRGWOutput( q=state.q, r=state.r, @@ -803,6 +804,7 @@ def output_from_state( errors=state.errors, epsilon=self.epsilon, inner_iterations=self.inner_iterations, + converged=converged, ) def _converged(self, state: LRGWState, iteration: int) -> bool: @@ -833,11 +835,13 @@ def conv_not_crossed(prev_err: float, curr_err: float) -> bool: ) def _diverged(self, state: LRGWState, iteration: int) -> bool: - it = iteration // self.inner_iterations - return jnp.logical_and( - jnp.logical_not(jnp.isfinite(state.errors[it - 1])), - jnp.logical_not(jnp.isfinite(state.costs[it - 1])) + it = iteration // self.inner_iterations - 1 + is_not_finite = jnp.logical_and( + jnp.logical_not(jnp.isfinite(state.errors[it])), + jnp.logical_not(jnp.isfinite(state.costs[it])) ) + # `jnp.inf` is used if `it < self.min_iterations` + return jnp.logical_and(it >= self.min_iterations, is_not_finite) def run( diff --git a/tests/initializers/linear/sinkhorn_lr_init_test.py b/tests/initializers/linear/sinkhorn_lr_init_test.py index 3c6e50c86..f4b7ff2b3 100644 --- a/tests/initializers/linear/sinkhorn_lr_init_test.py +++ b/tests/initializers/linear/sinkhorn_lr_init_test.py @@ -17,7 +17,7 @@ import jax.numpy as jnp import numpy as np -from ott.geometry import geometry, pointcloud +from ott.geometry import pointcloud from ott.initializers.linear import initializers_lr from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn_lr @@ -82,35 +82,6 @@ def test_generalized_k_means_has_correct_rank( assert jnp.linalg.matrix_rank(q) == rank assert jnp.linalg.matrix_rank(r) == rank - def test_generalized_k_means_matches_k_means(self, rng: jax.Array): - n, d, rank = 27, 7, 5 - eps = 1e-1 - rng1, rng2 = jax.random.split(rng, 2) - x = jax.random.normal(rng1, (n, d)) - y = jax.random.normal(rng1, (n, d)) - - pc = pointcloud.PointCloud(x, y, epsilon=eps) - geom = geometry.Geometry(cost_matrix=pc.cost_matrix, epsilon=eps) - pc_problem = linear_problem.LinearProblem(pc) - geom_problem = linear_problem.LinearProblem(geom) - - solver = sinkhorn_lr.LRSinkhorn( - rank=rank, initializer="k-means", max_iterations=5000 - ) - pc_out = solver(pc_problem) - - solver = sinkhorn_lr.LRSinkhorn( - rank=rank, initializer="generalized-k-means", max_iterations=5000 - ) - geom_out = solver(geom_problem) - - with pytest.raises(AssertionError): - np.testing.assert_allclose(pc_out.costs, geom_out.costs) - - np.testing.assert_allclose( - pc_out.reg_ot_cost, geom_out.reg_ot_cost, atol=0.5, rtol=0.02 - ) - @pytest.mark.parametrize("epsilon", [0.0, 1e-1]) def test_better_initialization_helps(self, rng: jax.Array, epsilon: float): n, d, rank = 81, 13, 3 diff --git a/tests/solvers/linear/sinkhorn_lr_test.py b/tests/solvers/linear/sinkhorn_lr_test.py index e5fc121d7..f91d21a02 100644 --- a/tests/solvers/linear/sinkhorn_lr_test.py +++ b/tests/solvers/linear/sinkhorn_lr_test.py @@ -200,15 +200,9 @@ def progress_fn( lin_prob ) - # check that the function is called on the 10th iteration (iter #9), the - # 20th iteration (iter #19). - assert traced_values["iters"] == [9, 19] - - # check that error decreases - np.testing.assert_array_equal(np.diff(traced_values["error"]) < 0, True) - - # check that max iterations is provided each time: [30, 30] - assert traced_values["total"] == [num_iterations] * 2 + assert traced_values["iters"] == [9, 19, 29, 39] + assert traced_values["total"] == [num_iterations + ] * len(traced_values["total"]) @pytest.mark.fast.with_args(eps=[0.0, 1e-1]) def test_lse_matches_kernel_mode(self, eps: float): @@ -318,7 +312,7 @@ def test_lr_unbalanced_ti( assert out.converged assert out_ti.converged - np.testing.assert_allclose(out.errors, out_ti.errors, rtol=1e-4, atol=1e-4) + np.testing.assert_allclose(out.errors, out_ti.errors, rtol=5e-4, atol=5e-4) np.testing.assert_allclose( out.reg_ot_cost, out_ti.reg_ot_cost, rtol=1e-2, atol=1e-2 )