diff --git a/scitbx/lstbx/normal_eqns_solving.py b/scitbx/lstbx/normal_eqns_solving.py index 632d1f1b1a..86e665451b 100644 --- a/scitbx/lstbx/normal_eqns_solving.py +++ b/scitbx/lstbx/normal_eqns_solving.py @@ -256,7 +256,7 @@ def do(self): objective_new = self.non_linear_ls.objective() actual_decrease = objective - objective_new rho = actual_decrease/expected_decrease - if rho > 0: + if rho >= 0: if self.has_gradient_converged_to_zero(): break self.mu *= max(1/3, 1 - (2*rho - 1)**3) nu = 2 @@ -325,7 +325,7 @@ def do(self): rho = actual_decrease/expected_decrease if self.objective_decrease_threshold is not None: if actual_decrease/objective < self.objective_decrease_threshold: break - if rho > 0: + if rho >= 0: if self.has_gradient_converged_to_zero(): break self.mu *= max(1/3, 1 - (2*rho - 1)**3) nu = 2 diff --git a/scitbx/lstbx/tests/tst_normal_equations.py b/scitbx/lstbx/tests/tst_normal_equations.py index fb264f678e..bafcdb7ed0 100644 --- a/scitbx/lstbx/tests/tst_normal_equations.py +++ b/scitbx/lstbx/tests/tst_normal_equations.py @@ -179,6 +179,14 @@ def exercise_levenberg_marquardt(non_linear_ls, plot=False): print("\[Mu]=%s;" % iterations.mu_history.mathematica_form(), file=f) print("ListLogPlot[{g,\[Mu]},Joined->True]", file=f) f.close() + non_linear_ls.restart() + iterations = normal_eqns_solving.levenberg_marquardt_iterations( + non_linear_ls, + track_all=True, + gradient_threshold=0, + step_threshold=0, + n_max_iterations=200) + assert iterations.n_iterations == 200 def run(): import sys