From 6839ddc82129b92b34273ff5d8e924ea99147e1e Mon Sep 17 00:00:00 2001 From: Viraj Pandya Date: Wed, 22 Feb 2023 12:20:56 -0500 Subject: [PATCH] use jnp.nanmin and jnp.nanmax to compute new stepsize factor instead of jnp.clip Care was taken to make sure that the order of jnp.nanmin and jnp.nanmax reflects the actual behavior of jnp.clip. According to https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.clip.html, using jnp.(nan)min and jnp.(nan)max may be a bit slower than jnp.clip, but at least it will be robust against NaNs in y_error. --- diffrax/step_size_controller/adaptive.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/diffrax/step_size_controller/adaptive.py b/diffrax/step_size_controller/adaptive.py index 5297050d..0cc0afc0 100644 --- a/diffrax/step_size_controller/adaptive.py +++ b/diffrax/step_size_controller/adaptive.py @@ -496,11 +496,7 @@ def _scale(_y0, _y1_candidate, _y_error): factor2 = 1 if _zero_coeff(coeff2) else prev_inv_scaled_error ** coeff2 factor3 = 1 if _zero_coeff(coeff3) else prev_prev_inv_scaled_error ** coeff3 factormin = jnp.where(keep_step, 1, self.factormin) - factor = jnp.clip( - self.safety * factor1 * factor2 * factor3, - a_min=factormin, - a_max=self.factormax, - ) + factor = jnp.nanmin(jnp.array([self.factormax, jnp.nanmax(jnp.array([self.safety * factor1 * factor2 * factor3, factormin]))])) # Once again, see above. In case we have gradients on {i,p,d}coeff. # (Probably quite common for them to have zero tangents if passed across # a grad API boundary as part of a larger model.)