Skip to content

Commit 07b8691

Browse files
committed
Patched edge-case in LinearInterpolation
1 parent 737bf39 commit 07b8691

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

diffrax/global_interpolation.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from .custom_types import Array, DenseInfos, Int, PyTree, Scalar
1313
from .local_interpolation import AbstractLocalInterpolation
14-
from .misc import fill_forward, left_broadcast_to
14+
from .misc import fill_forward, left_broadcast_to, linear_rescale
1515
from .path import AbstractPath
1616

1717

@@ -124,10 +124,10 @@ def _index(_ys):
124124
next_ys = (self.ys**ω)[index + 1].ω
125125
prev_t = self.ts[index]
126126
next_t = self.ts[index + 1]
127-
diff_t = next_t - prev_t
128-
129127
return (
130-
prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t)
128+
prev_ys**ω
129+
+ (next_ys**ω - prev_ys**ω)
130+
* (linear_rescale(prev_t, fractional_part, next_t))
131131
).ω
132132

133133
@eqx.filter_jit
@@ -407,7 +407,6 @@ def _linear_interpolation_forward(
407407
Tuple[Array["channels":...], Array["channels":...]], # noqa: F821
408408
Array["channels":...], # noqa: F821
409409
]:
410-
411410
prev_ti, prev_yi = carry
412411
ti, yi, next_ti, next_yi = value
413412
cond = jnp.isnan(yi)
@@ -426,7 +425,6 @@ def _linear_interpolation(
426425
ys: Array["times", "channels":...], # noqa: F821
427426
replace_nans_at_start: Optional[Array["channels":...]] = None, # noqa: F821
428427
) -> Array["times", "channels":...]: # noqa: F821
429-
430428
ts = left_broadcast_to(ts, ys.shape)
431429

432430
if replace_nans_at_start is None:
@@ -599,7 +597,6 @@ def _hermite_forward(
599597
Array["channels":...], # noqa: F821
600598
],
601599
]:
602-
603600
prev_ti, prev_yi, prev_deriv_i = carry
604601
ti, yi, next_ti, next_yi = value
605602
first_deriv_i = (next_yi - yi) / (next_ti - ti)

0 commit comments

Comments
 (0)