
Fix low-rank convergence criterion #547

Merged: 8 commits from feature/better-converged into main on Jul 3, 2024

Conversation

@michalk8 (Collaborator) commented Jun 5, 2024

closes #495

codecov bot commented Jun 5, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 89.39%. Comparing base (c6fb25c) to head (ef479a8).
Report is 33 commits behind head on main.

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #547      +/-   ##
==========================================
+ Coverage   89.09%   89.39%   +0.30%     
==========================================
  Files          70       71       +1     
  Lines        7427     7556     +129     
  Branches     1051     1080      +29     
==========================================
+ Hits         6617     6755     +138     
+ Misses        659      650       -9     
  Partials      151      151              
| Files with missing lines | Coverage Δ |
|---|---|
| src/ott/solvers/linear/sinkhorn_lr.py | 98.65% <100.00%> (+<0.01%) ⬆️ |
| src/ott/solvers/quadratic/gromov_wasserstein_lr.py | 81.60% <100.00%> (+0.05%) ⬆️ |

... and 2 files with indirect coverage changes

@michalk8 michalk8 added the bug Something isn't working label Jun 5, 2024
@michalk8 michalk8 requested a review from marcocuturi June 5, 2024 16:09
@marcocuturi (Contributor) left a comment

This LGTM!

I am convinced that we need to get rid of the (1 / gamma**2) rescaling in line 60 of sinkhorn_lr.py. This rescaling might make sense for a theoretical analysis, but it does not make sense for a practical convergence check.

Tagging @meyerscetbon, who might have an opinion. For instance, setting a large gamma, e.g. 100 instead of 10, rescales the same criterion "optimistically" (i.e. makes it smaller) by a factor of 1e2! With the default threshold of 1e-3 that we use, this makes no sense, and might explain erratic behavior.

```diff
@@ -687,7 +682,10 @@ def one_iteration(
       lambda: state.reg_ot_cost(ot_prob, epsilon=self.epsilon),
       lambda: jnp.inf
     )
-    error = state.compute_error(previous_state)
+    error = jax.lax.cond(
```
good catch!
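
For readers following the diff above, here is a minimal, hypothetical sketch of the jax.lax.cond guard pattern it introduces; the condition, the helper name and the surrounding signature are assumptions for illustration, not the exact code merged in this PR.

```python
# Hypothetical sketch: compute the convergence error only once the solver is
# allowed to stop (e.g. past min_iterations); otherwise return jnp.inf so the
# `converged` flag cannot be raised too early (cf. issue #495).
import jax
import jax.numpy as jnp


def guarded_error(iteration, min_iterations, state, previous_state):
  return jax.lax.cond(
      iteration >= min_iterations,                  # only test convergence once stopping is allowed
      lambda: state.compute_error(previous_state),  # actual error (sum of the KL terms)
      lambda: jnp.inf,                              # otherwise report "not converged yet"
  )
```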

@marcocuturi (Contributor)

I suggest to:

  • remove this 1/gamma^2 factor from the error. With the default gamma, this was dividing the criterion by 100, and hence effectively applying a very loose threshold of 1e-1 (see the sketch below).
  • keep the default threshold parameter in LRSinkhorn. It currently inherits the 1e-3 setting from Sinkhorn. Once gamma^2 is removed, this sounds like a reasonable value (although we might want to make it size-dependent, as should soon be done for Sinkhorn as well).
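
A small numerical illustration of the first bullet (plain Python, not OTT code; gamma, the threshold and the error value are assumed for the example):

```python
# Illustrative numbers only: with the default gamma = 10, dividing the error by
# gamma**2 = 100 makes the effective tolerance 100x looser than intended.
gamma = 10.0       # assumed default step size
threshold = 1e-3   # default threshold inherited from Sinkhorn
error = 5e-2       # hypothetical sum of the three KL divergences

print(error / gamma**2 <= threshold)  # True: "converged" under the old, rescaled criterion
print(error <= threshold)             # False: not converged once the 1/gamma**2 factor is removed
```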

@meyerscetbon (Collaborator)

The convergence criterion (as currently implemented) only makes “theoretical” sense when the gradient step is constant throughout the iterations. In that case, for each choice of gamma, we could adapt the desired approximation error (by considering gamma^2 * eps), as Marco suggested. This would be equivalent to the current procedure, and we would additionally eliminate an indeterminacy.

However, since we are using a rescaled gradient step, we should, in theory, monitor (1/gamma_k^2) * (err_Q + err_R + err_g). If gamma_k^2 converges towards a constant (close to 1), this becomes equivalent to monitoring only the errors. We could also apply the same idea as in the constant-step case, by comparing against gamma_k^2 * eps instead.

I hope this helps a little.
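
For reference, the two stopping rules discussed in this thread can be written side by side (a sketch of the notation only, with err_Q, err_R, err_g denoting the three KL terms and eps the threshold):

```latex
\frac{1}{\gamma_k^{2}}\left(\mathrm{err}_Q + \mathrm{err}_R + \mathrm{err}_g\right) \le \varepsilon
\quad\Longleftrightarrow\quad
\mathrm{err}_Q + \mathrm{err}_R + \mathrm{err}_g \le \gamma_k^{2}\,\varepsilon
```

For a constant gamma the two forms differ only by the fixed factor gamma^2 applied to the tolerance, which is the equivalence referred to above and the reason the discussion settles on monitoring the plain sum of errors against the 1e-3 threshold.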

@marcocuturi (Contributor)

Thanks @meyerscetbon! I am not sure what you advocate, though, practically speaking :) It seems to me that the sum of the 3 KLs is the more natural quantity to monitor, independently of the chosen step size. Is there any intuitive reason why the termination criterion should be scaled with the step size? For instance, if the first gamma = 1000 (a small step size), we would in principle "converge" immediately, after one iteration. This seems counterintuitive to me.

@meyerscetbon (Collaborator)

I was trying to say that in the proof we look at the convergence of (1/gamma_k^2) * err (where err is the sum of the KL terms), and gamma_k is not assumed to be constant (we only assume it is constant at the end of the proof, in order to obtain a sufficient condition for convergence). There might therefore be cases where the convergence of (1/gamma_k^2) * err does not imply the convergence of err. However, I think we can safely assume that gamma_k converges towards a constant, so we can monitor only the error term, as you suggested.

I was also trying to say that when gamma is constant across the iterations, rescaling the error by 1/gamma^2 or not is equivalent in terms of convergence, so we can, as you suggested, remove 1/gamma^2 from the criterion and monitor only the error term.

Sorry for the confusion.

@giovp (Contributor) commented Jun 17, 2024

Thanks for the discussion, all, very interesting, and thanks @michalk8 for the fix. I've just run some of the failing tests we have in moscot, and the convergence flag is now set correctly (returning True when it should).

@michalk8 michalk8 merged commit 7cfd393 into main Jul 3, 2024
12 checks passed
@michalk8 michalk8 deleted the feature/better-converged branch July 3, 2024 15:43
Labels: bug (Something isn't working)
Projects: None yet
Development: Successfully merging this pull request may close these issues: converged flag compatibility with min_iterations logic
5 participants