Fix low-rank convergence criterion #547
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

| Coverage Diff | main | #547 | +/- |
|---|---|---|---|
| Coverage | 89.09% | 89.39% | +0.30% |
| Files | 70 | 71 | +1 |
| Lines | 7427 | 7556 | +129 |
| Branches | 1051 | 1080 | +29 |
| Hits | 6617 | 6755 | +138 |
| Misses | 659 | 650 | -9 |
| Partials | 151 | 151 | |
This LGTM!

I am convinced that we need to get rid of the (1 / gamma**2) rescaling in line 60 of sinkhorn_lr.py. This rescaling might make sense for a theoretical analysis, but it does not make sense for a practical convergence check.

Tagging @meyerscetbon, who might have an opinion. For instance, setting a large gamma, e.g. 100 vs. 10, rescales the same criterion "optimistically" (smaller) by a factor of 1e2! With the threshold of 1e-3 that we use by default, this makes absolutely no sense, and it might explain the erratic behavior.
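To make the scale issue concrete, here is a minimal numeric sketch (not the library's code; the error values, gamma values, and threshold are made up for illustration): with the 1/gamma**2 factor, a larger step size makes the same raw error look smaller and trips the default tolerance even though the KL errors themselves are far from converged.

```python
# Minimal sketch (illustrative values only, not the library's code).
err_q, err_r, err_g = 1e-2, 1e-2, 1e-2   # hypothetical values of the three KL errors
threshold = 1e-3                          # the default tolerance mentioned above

for gamma in (10.0, 100.0):
    raw = err_q + err_r + err_g           # criterion without the 1/gamma**2 factor
    rescaled = raw / gamma**2             # criterion with the 1/gamma**2 factor
    print(f"gamma={gamma}: raw converged? {raw < threshold}, "
          f"rescaled converged? {rescaled < threshold}")

# gamma=10.0:  raw converged? False, rescaled converged? True
# gamma=100.0: raw converged? False, rescaled converged? True (another 1e2 smaller)
```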
@@ -687,7 +682,10 @@ def one_iteration(
         lambda: state.reg_ot_cost(ot_prob, epsilon=self.epsilon),
         lambda: jnp.inf
     )
-    error = state.compute_error(previous_state)
+    error = jax.lax.cond(
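The hunk above is truncated in this view. As a rough, speculative illustration of the pattern being introduced (the function name, arguments, and values below are hypothetical, not the PR's actual code): guarding the error computation with jax.lax.cond so that a skipped check can never be mistaken for convergence.

```python
# Speculative sketch, NOT the actual diff: compute the error only when asked,
# otherwise return inf so convergence is never declared on a skipped check.
import jax
import jax.numpy as jnp


def guarded_error(compute_now: bool, err_q, err_r, err_g):
    """Sum the three KL errors when requested, otherwise return inf."""
    return jax.lax.cond(
        compute_now,
        lambda: err_q + err_r + err_g,  # un-rescaled sum of the KL errors
        lambda: jnp.inf,                # sentinel: cannot converge this step
    )


print(guarded_error(True, 1e-4, 2e-4, 3e-4))   # ~6e-4
print(guarded_error(False, 1e-4, 2e-4, 3e-4))  # inf
```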
good catch!
I suggest to:
The convergence criterion (as currently implemented) only makes “theoretical” sense when the gradient step is constant throughout the iterations. In that case, for each choice of gamma, we could adapt the desired approximation error (by considering gamma^2 * eps) as suggested by Marco. This would be equivalent to the current procedure, and we would additionally eliminate an indeterminacy. However, as we are using a rescaled gradient step, we should, in theory, monitor (1 / gamma_k^2) * (err_Q + err_R + err_g). If gamma_k^2 converges towards a constant (close to 1), then it becomes equivalent to monitoring only the errors. We could also apply the same idea as in the constant gradient-step case, by monitoring against gamma_k^2 * eps instead. I hope this helps a little.
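Put differently, with a constant step size the two rules are the same test with the tolerance moved across the inequality. A tiny sketch (illustrative values, not library code):

```python
# With a constant gamma > 0, dividing the error by gamma**2 or multiplying the
# tolerance by gamma**2 yields the exact same stopping decision.
gamma, eps = 10.0, 1e-3
err = 5e-2  # hypothetical err_Q + err_R + err_g

rescaled_rule = (err / gamma**2) < eps   # criterion as currently implemented
adapted_rule = err < gamma**2 * eps      # un-rescaled error vs. adapted tolerance
assert rescaled_rule == adapted_rule      # identical decision for any err
```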
thanks @meyerscetbon! I am not sure what you advocate, though, practically speaking :) It seems to me that the sum of the 3 KLs is more natural as a quantity to monitor, independently of the stepsize that is chosen. Is there any reason why, intuitively, the termination criterion should be scaled with the stepsize? For instance, if the first …
I was trying to say that in the proof, we look at the convergence of (1/\gamma_k^2) * err (where err is the sum of the KL terms), and gamma_k is not assumed to be constant (indeed, we only assume that it is constant at the end of the proof, in order to obtain a sufficient condition for convergence). Therefore there might be some cases where the convergence of (1/\gamma_k^2) * err does not imply the convergence of err. However, I think we can safely assume that gamma_k will converge towards a constant, and therefore we can monitor only the error term, as you suggested. I was also trying to say that when gamma is constant along the iterations, rescaling the error by (1/\gamma^2) or not is the same in terms of convergence, and we can, as you suggested, remove (1/gamma^2) from the criterion and monitor only the error term. Sorry for the confusion.
thanks for the discussion, all, very interesting, and thanks @michalk8 for the fix. I've just tested some of the failing tests we have in moscot, and the convergence flag now seems to be set correctly (returning True when it should be).
closes #495