11
11
12
12
from .custom_types import Array , DenseInfos , Int , PyTree , Scalar
13
13
from .local_interpolation import AbstractLocalInterpolation
14
- from .misc import fill_forward , left_broadcast_to
14
+ from .misc import fill_forward , left_broadcast_to , linear_rescale
15
15
from .path import AbstractPath
16
16
17
17
@@ -124,10 +124,10 @@ def _index(_ys):
124
124
next_ys = (self .ys ** ω )[index + 1 ].ω
125
125
prev_t = self .ts [index ]
126
126
next_t = self .ts [index + 1 ]
127
- diff_t = next_t - prev_t
128
-
129
127
return (
130
- prev_ys ** ω + (next_ys ** ω - prev_ys ** ω ) * (fractional_part / diff_t )
128
+ prev_ys ** ω
129
+ + (next_ys ** ω - prev_ys ** ω )
130
+ * (linear_rescale (prev_t , fractional_part , next_t ))
131
131
).ω
132
132
133
133
@eqx .filter_jit
@@ -407,7 +407,6 @@ def _linear_interpolation_forward(
407
407
Tuple [Array ["channels" :...], Array ["channels" :...]], # noqa: F821
408
408
Array ["channels" :...], # noqa: F821
409
409
]:
410
-
411
410
prev_ti , prev_yi = carry
412
411
ti , yi , next_ti , next_yi = value
413
412
cond = jnp .isnan (yi )
@@ -426,7 +425,6 @@ def _linear_interpolation(
426
425
ys : Array ["times" , "channels" :...], # noqa: F821
427
426
replace_nans_at_start : Optional [Array ["channels" :...]] = None , # noqa: F821
428
427
) -> Array ["times" , "channels" :...]: # noqa: F821
429
-
430
428
ts = left_broadcast_to (ts , ys .shape )
431
429
432
430
if replace_nans_at_start is None :
@@ -599,7 +597,6 @@ def _hermite_forward(
599
597
Array ["channels" :...], # noqa: F821
600
598
],
601
599
]:
602
-
603
600
prev_ti , prev_yi , prev_deriv_i = carry
604
601
ti , yi , next_ti , next_yi = value
605
602
first_deriv_i = (next_yi - yi ) / (next_ti - ti )
0 commit comments