Skip to content

Commit bbad868

Browse files
committed
In progress commit on branch delay.
1 parent f9dae13 commit bbad868

24 files changed

+12391
-106
lines changed

diffrax/__init__.py

Lines changed: 93 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,106 +1,113 @@
11
import importlib.metadata
22

3-
from ._adjoint import (
4-
AbstractAdjoint as AbstractAdjoint,
5-
BacksolveAdjoint as BacksolveAdjoint,
6-
DirectAdjoint as DirectAdjoint,
7-
ImplicitAdjoint as ImplicitAdjoint,
8-
RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
9-
)
10-
from ._autocitation import citation as citation, citation_rules as citation_rules
11-
from ._brownian import (
12-
AbstractBrownianPath as AbstractBrownianPath,
13-
UnsafeBrownianPath as UnsafeBrownianPath,
14-
VirtualBrownianTree as VirtualBrownianTree,
15-
)
16-
from ._event import (
17-
AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent,
18-
DiscreteTerminatingEvent as DiscreteTerminatingEvent,
19-
SteadyStateEvent as SteadyStateEvent,
20-
)
3+
from ._adjoint import AbstractAdjoint as AbstractAdjoint
4+
from ._adjoint import BacksolveAdjoint as BacksolveAdjoint
5+
from ._adjoint import DirectAdjoint as DirectAdjoint
6+
from ._adjoint import ImplicitAdjoint as ImplicitAdjoint
7+
from ._adjoint import RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint
8+
from ._autocitation import citation as citation
9+
from ._autocitation import citation_rules as citation_rules
10+
from ._brownian import AbstractBrownianPath
11+
from ._brownian import AbstractBrownianPath as AbstractBrownianPath
12+
from ._brownian import UnsafeBrownianPath
13+
from ._brownian import UnsafeBrownianPath as UnsafeBrownianPath
14+
from ._brownian import VirtualBrownianTree
15+
from ._brownian import VirtualBrownianTree as VirtualBrownianTree
16+
from ._delays import Delays as Delays
17+
from ._delays import bind_history as bind_history
18+
from ._delays import history_extrapolation_implicit as history_extrapolation_implicit
19+
from ._delays import maybe_find_discontinuity as maybe_find_discontinuity
20+
from ._event import AbstractDiscreteTerminatingEvent
21+
from ._event import AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent
22+
from ._event import DiscreteTerminatingEvent
23+
from ._event import DiscreteTerminatingEvent as DiscreteTerminatingEvent
24+
from ._event import SteadyStateEvent
25+
from ._event import SteadyStateEvent as SteadyStateEvent
2126
from ._global_interpolation import (
2227
AbstractGlobalInterpolation as AbstractGlobalInterpolation,
28+
)
29+
from ._global_interpolation import CubicInterpolation as CubicInterpolation
30+
from ._global_interpolation import DenseInterpolation as DenseInterpolation
31+
from ._global_interpolation import LinearInterpolation as LinearInterpolation
32+
from ._global_interpolation import (
2333
backward_hermite_coefficients as backward_hermite_coefficients,
24-
CubicInterpolation as CubicInterpolation,
25-
DenseInterpolation as DenseInterpolation,
26-
linear_interpolation as linear_interpolation,
27-
LinearInterpolation as LinearInterpolation,
34+
)
35+
from ._global_interpolation import linear_interpolation as linear_interpolation
36+
from ._global_interpolation import (
2837
rectilinear_interpolation as rectilinear_interpolation,
2938
)
3039
from ._integrate import diffeqsolve as diffeqsolve
3140
from ._local_interpolation import (
3241
AbstractLocalInterpolation as AbstractLocalInterpolation,
42+
)
43+
from ._local_interpolation import (
3344
FourthOrderPolynomialInterpolation as FourthOrderPolynomialInterpolation,
34-
LocalLinearInterpolation as LocalLinearInterpolation,
35-
ThirdOrderHermitePolynomialInterpolation as ThirdOrderHermitePolynomialInterpolation, # noqa: E501
3645
)
46+
from ._local_interpolation import LocalLinearInterpolation as LocalLinearInterpolation
47+
from ._local_interpolation import (
48+
ThirdOrderHermitePolynomialInterpolation as ThirdOrderHermitePolynomialInterpolation,
49+
) # noqa: E501
3750
from ._misc import adjoint_rms_seminorm as adjoint_rms_seminorm
3851
from ._path import AbstractPath as AbstractPath
39-
from ._root_finder import (
40-
VeryChord as VeryChord,
41-
with_stepsize_controller_tols as with_stepsize_controller_tols,
42-
)
43-
from ._saveat import SaveAt as SaveAt, SubSaveAt as SubSaveAt
44-
from ._solution import (
45-
is_event as is_event,
46-
is_okay as is_okay,
47-
is_successful as is_successful,
48-
RESULTS as RESULTS,
49-
Solution as Solution,
50-
)
51-
from ._solver import (
52-
AbstractAdaptiveSolver as AbstractAdaptiveSolver,
53-
AbstractDIRK as AbstractDIRK,
54-
AbstractERK as AbstractERK,
55-
AbstractESDIRK as AbstractESDIRK,
56-
AbstractImplicitSolver as AbstractImplicitSolver,
57-
AbstractItoSolver as AbstractItoSolver,
58-
AbstractRungeKutta as AbstractRungeKutta,
59-
AbstractSDIRK as AbstractSDIRK,
60-
AbstractSolver as AbstractSolver,
61-
AbstractStratonovichSolver as AbstractStratonovichSolver,
62-
AbstractWrappedSolver as AbstractWrappedSolver,
63-
Bosh3 as Bosh3,
64-
ButcherTableau as ButcherTableau,
65-
CalculateJacobian as CalculateJacobian,
66-
Dopri5 as Dopri5,
67-
Dopri8 as Dopri8,
68-
Euler as Euler,
69-
EulerHeun as EulerHeun,
70-
HalfSolver as HalfSolver,
71-
Heun as Heun,
72-
ImplicitEuler as ImplicitEuler,
73-
ItoMilstein as ItoMilstein,
74-
KenCarp3 as KenCarp3,
75-
KenCarp4 as KenCarp4,
76-
KenCarp5 as KenCarp5,
77-
Kvaerno3 as Kvaerno3,
78-
Kvaerno4 as Kvaerno4,
79-
Kvaerno5 as Kvaerno5,
80-
LeapfrogMidpoint as LeapfrogMidpoint,
81-
Midpoint as Midpoint,
82-
MultiButcherTableau as MultiButcherTableau,
83-
Ralston as Ralston,
84-
ReversibleHeun as ReversibleHeun,
85-
SemiImplicitEuler as SemiImplicitEuler,
86-
Sil3 as Sil3,
87-
StratonovichMilstein as StratonovichMilstein,
88-
Tsit5 as Tsit5,
89-
)
52+
from ._root_finder import VeryChord as VeryChord
53+
from ._root_finder import with_stepsize_controller_tols as with_stepsize_controller_tols
54+
from ._saveat import SaveAt as SaveAt
55+
from ._saveat import SubSaveAt as SubSaveAt
56+
from ._solution import RESULTS as RESULTS
57+
from ._solution import Solution as Solution
58+
from ._solution import is_event as is_event
59+
from ._solution import is_okay as is_okay
60+
from ._solution import is_successful as is_successful
61+
from ._solver import AbstractAdaptiveSolver as AbstractAdaptiveSolver
62+
from ._solver import AbstractDIRK as AbstractDIRK
63+
from ._solver import AbstractERK as AbstractERK
64+
from ._solver import AbstractESDIRK as AbstractESDIRK
65+
from ._solver import AbstractImplicitSolver as AbstractImplicitSolver
66+
from ._solver import AbstractItoSolver as AbstractItoSolver
67+
from ._solver import AbstractRungeKutta as AbstractRungeKutta
68+
from ._solver import AbstractSDIRK as AbstractSDIRK
69+
from ._solver import AbstractSolver as AbstractSolver
70+
from ._solver import AbstractStratonovichSolver as AbstractStratonovichSolver
71+
from ._solver import AbstractWrappedSolver as AbstractWrappedSolver
72+
from ._solver import Bosh3 as Bosh3
73+
from ._solver import ButcherTableau as ButcherTableau
74+
from ._solver import CalculateJacobian as CalculateJacobian
75+
from ._solver import Dopri5 as Dopri5
76+
from ._solver import Dopri8 as Dopri8
77+
from ._solver import Euler as Euler
78+
from ._solver import EulerHeun as EulerHeun
79+
from ._solver import HalfSolver as HalfSolver
80+
from ._solver import Heun as Heun
81+
from ._solver import ImplicitEuler as ImplicitEuler
82+
from ._solver import ItoMilstein as ItoMilstein
83+
from ._solver import KenCarp3 as KenCarp3
84+
from ._solver import KenCarp4 as KenCarp4
85+
from ._solver import KenCarp5 as KenCarp5
86+
from ._solver import Kvaerno3 as Kvaerno3
87+
from ._solver import Kvaerno4 as Kvaerno4
88+
from ._solver import Kvaerno5 as Kvaerno5
89+
from ._solver import LeapfrogMidpoint as LeapfrogMidpoint
90+
from ._solver import Midpoint as Midpoint
91+
from ._solver import MultiButcherTableau as MultiButcherTableau
92+
from ._solver import Ralston as Ralston
93+
from ._solver import ReversibleHeun as ReversibleHeun
94+
from ._solver import SemiImplicitEuler as SemiImplicitEuler
95+
from ._solver import Sil3 as Sil3
96+
from ._solver import StratonovichMilstein as StratonovichMilstein
97+
from ._solver import Tsit5 as Tsit5
9098
from ._step_size_controller import (
9199
AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController,
92-
AbstractStepSizeController as AbstractStepSizeController,
93-
ConstantStepSize as ConstantStepSize,
94-
PIDController as PIDController,
95-
StepTo as StepTo,
96100
)
97-
from ._term import (
98-
AbstractTerm as AbstractTerm,
99-
ControlTerm as ControlTerm,
100-
MultiTerm as MultiTerm,
101-
ODETerm as ODETerm,
102-
WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm,
101+
from ._step_size_controller import (
102+
AbstractStepSizeController as AbstractStepSizeController,
103103
)
104-
104+
from ._step_size_controller import ConstantStepSize as ConstantStepSize
105+
from ._step_size_controller import PIDController as PIDController
106+
from ._step_size_controller import StepTo as StepTo
107+
from ._term import AbstractTerm as AbstractTerm
108+
from ._term import ControlTerm as ControlTerm
109+
from ._term import MultiTerm as MultiTerm
110+
from ._term import ODETerm as ODETerm
111+
from ._term import WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm
105112

106113
__version__ = importlib.metadata.version("diffrax")

diffrax/_adjoint.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def loop(
119119
solver,
120120
stepsize_controller,
121121
discrete_terminating_event,
122+
delays,
122123
saveat,
123124
t0,
124125
t1,
@@ -128,6 +129,7 @@ def loop(
128129
init_state,
129130
passed_solver_state,
130131
passed_controller_state,
132+
y0_history,
131133
) -> Any:
132134
"""Runs the main solve loop. Subclasses can override this to provide custom
133135
backpropagation behaviour; see for example the implementation of
@@ -550,13 +552,15 @@ def _loop_backsolve_bwd(
550552
solver,
551553
stepsize_controller,
552554
discrete_terminating_event,
555+
delays,
553556
saveat,
554557
t0,
555558
t1,
556559
dt0,
557560
max_steps,
558561
throw,
559562
init_state,
563+
y0_history,
560564
):
561565
assert discrete_terminating_event is None
562566

@@ -594,6 +598,8 @@ def _loop_backsolve_bwd(
594598
adjoint=self,
595599
solver=solver,
596600
stepsize_controller=stepsize_controller,
601+
discrete_terminating_event=discrete_terminating_event,
602+
delays=delays,
597603
terms=adjoint_terms,
598604
dt0=None if dt0 is None else -dt0,
599605
max_steps=max_steps,
@@ -773,6 +779,7 @@ def loop(
773779
passed_solver_state,
774780
passed_controller_state,
775781
discrete_terminating_event,
782+
delays,
776783
**kwargs,
777784
):
778785
if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
@@ -818,6 +825,10 @@ def loop(
818825
raise NotImplementedError(
819826
"`diffrax.BacksolveAdjoint` is not compatible with events."
820827
)
828+
if delays is not None:
829+
raise NotImplementedError(
830+
"Cannot use `delays` with `adjoint=BacksolveAdjoint()`"
831+
)
821832

822833
y = init_state.y
823834
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
@@ -832,6 +843,7 @@ def loop(
832843
init_state=init_state,
833844
solver=solver,
834845
discrete_terminating_event=discrete_terminating_event,
846+
delays=delays,
835847
**kwargs,
836848
)
837849
final_state = _only_transpose_ys(final_state)

0 commit comments

Comments
 (0)