diff --git a/diffrax/global_interpolation.py b/diffrax/global_interpolation.py index 5900bb45..d38f49e2 100644 --- a/diffrax/global_interpolation.py +++ b/diffrax/global_interpolation.py @@ -11,7 +11,7 @@ from .custom_types import Array, DenseInfos, Int, PyTree, Scalar from .local_interpolation import AbstractLocalInterpolation -from .misc import fill_forward, left_broadcast_to +from .misc import fill_forward, left_broadcast_to, linear_rescale from .path import AbstractPath @@ -124,10 +124,10 @@ def _index(_ys): next_ys = (self.ys**ω)[index + 1].ω prev_t = self.ts[index] next_t = self.ts[index + 1] - diff_t = next_t - prev_t - return ( - prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t) + prev_ys**ω + + (next_ys**ω - prev_ys**ω) + * (linear_rescale(prev_t, fractional_part, next_t)) ).ω @eqx.filter_jit @@ -407,7 +407,6 @@ def _linear_interpolation_forward( Tuple[Array["channels":...], Array["channels":...]], # noqa: F821 Array["channels":...], # noqa: F821 ]: - prev_ti, prev_yi = carry ti, yi, next_ti, next_yi = value cond = jnp.isnan(yi) @@ -426,7 +425,6 @@ def _linear_interpolation( ys: Array["times", "channels":...], # noqa: F821 replace_nans_at_start: Optional[Array["channels":...]] = None, # noqa: F821 ) -> Array["times", "channels":...]: # noqa: F821 - ts = left_broadcast_to(ts, ys.shape) if replace_nans_at_start is None: @@ -599,7 +597,6 @@ def _hermite_forward( Array["channels":...], # noqa: F821 ], ]: - prev_ti, prev_yi, prev_deriv_i = carry ti, yi, next_ti, next_yi = value first_deriv_i = (next_yi - yi) / (next_ti - ti)