diff --git a/diffrax/step_size_controller/adaptive.py b/diffrax/step_size_controller/adaptive.py
index 5297050d..0cc0afc0 100644
--- a/diffrax/step_size_controller/adaptive.py
+++ b/diffrax/step_size_controller/adaptive.py
@@ -496,11 +496,7 @@ def _scale(_y0, _y1_candidate, _y_error):
         factor2 = 1 if _zero_coeff(coeff2) else prev_inv_scaled_error ** coeff2
         factor3 = 1 if _zero_coeff(coeff3) else prev_prev_inv_scaled_error ** coeff3
         factormin = jnp.where(keep_step, 1, self.factormin)
-        factor = jnp.clip(
-            self.safety * factor1 * factor2 * factor3,
-            a_min=factormin,
-            a_max=self.factormax,
-        )
+        factor = jnp.nanmin(jnp.array([self.factormax, jnp.nanmax(jnp.array([self.safety * factor1 * factor2 * factor3, factormin]))]))
         # Once again, see above. In case we have gradients on {i,p,d}coeff.
         # (Probably quite common for them to have zero tangents if passed across
         # a grad API boundary as part of a larger model.)
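
The observable difference between the two formulations is how NaN is handled: `jnp.clip` propagates a NaN input, whereas `jnp.nanmax`/`jnp.nanmin` ignore NaN entries, so a NaN scaled error presumably collapses to `factormin` (shrinking the step) instead of poisoning the step size. Below is a minimal standalone sketch of that behavioural difference, not part of the patch; the `factormin`/`factormax` values and the helper names are placeholders chosen for illustration, and it assumes `factormin <= factormax` as in the controller's usual configuration.

```python
import jax.numpy as jnp

factormin = 0.2   # hypothetical values; in diffrax these come from the controller
factormax = 10.0

def clip_factor(raw):
    # Old behaviour: jnp.clip is min(max(raw, lo), hi), and both jnp.maximum
    # and jnp.minimum propagate NaN, so a NaN raw factor stays NaN.
    return jnp.clip(raw, factormin, factormax)

def nan_safe_factor(raw):
    # New behaviour: jnp.nanmax/jnp.nanmin skip NaN entries, so a NaN raw
    # factor reduces to factormin -- the step is shrunk rather than poisoned.
    return jnp.nanmin(jnp.array([factormax, jnp.nanmax(jnp.array([raw, factormin]))]))

print(clip_factor(jnp.nan), nan_safe_factor(jnp.nan))  # nan 0.2
print(clip_factor(3.0), nan_safe_factor(3.0))          # 3.0 3.0  (finite inputs agree)
print(clip_factor(100.0), nan_safe_factor(100.0))      # 10.0 10.0 (upper clamp unchanged)
```

On finite inputs the two versions agree, so the change only affects the NaN path.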