Skip to content

First DDE version #169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,28 @@
)
from ._global_interpolation import (
AbstractGlobalInterpolation as AbstractGlobalInterpolation,
)
from ._global_interpolation import CubicInterpolation as CubicInterpolation
from ._global_interpolation import DenseInterpolation as DenseInterpolation
from ._global_interpolation import LinearInterpolation as LinearInterpolation
from ._global_interpolation import (
backward_hermite_coefficients as backward_hermite_coefficients,
CubicInterpolation as CubicInterpolation,
DenseInterpolation as DenseInterpolation,
linear_interpolation as linear_interpolation,
LinearInterpolation as LinearInterpolation,
)
from ._global_interpolation import linear_interpolation as linear_interpolation
from ._global_interpolation import (
rectilinear_interpolation as rectilinear_interpolation,
)
from ._integrate import diffeqsolve as diffeqsolve
from ._local_interpolation import (
AbstractLocalInterpolation as AbstractLocalInterpolation,
)
from ._local_interpolation import (
FourthOrderPolynomialInterpolation as FourthOrderPolynomialInterpolation,
LocalLinearInterpolation as LocalLinearInterpolation,
ThirdOrderHermitePolynomialInterpolation as ThirdOrderHermitePolynomialInterpolation, # noqa: E501
)
from ._local_interpolation import LocalLinearInterpolation as LocalLinearInterpolation
from ._local_interpolation import (
ThirdOrderHermitePolynomialInterpolation as ThirdOrderHermitePolynomialInterpolation,
) # noqa: E501
from ._misc import adjoint_rms_seminorm as adjoint_rms_seminorm
from ._path import AbstractPath as AbstractPath
from ._progress_meter import (
Expand Down Expand Up @@ -120,6 +128,8 @@
)
from ._step_size_controller import (
AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController,
)
from ._step_size_controller import (
AbstractStepSizeController as AbstractStepSizeController,
ClipStepSizeController as ClipStepSizeController,
ConstantStepSize as ConstantStepSize,
Expand All @@ -135,6 +145,18 @@
UnderdampedLangevinDriftTerm as UnderdampedLangevinDriftTerm,
WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm,
)
from ._step_size_controller import ConstantStepSize as ConstantStepSize
from ._step_size_controller import PIDController as PIDController
from ._step_size_controller import StepTo as StepTo
from ._term import AbstractTerm as AbstractTerm
from ._term import ControlTerm as ControlTerm
from ._term import MultiTerm as MultiTerm
from ._term import ODETerm as ODETerm
from ._term import WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm

from ._delays import Delays as Delays
from ._delays import bind_history as bind_history
from ._delays import history_extrapolation_implicit as history_extrapolation_implicit
from ._delays import maybe_find_discontinuity as maybe_find_discontinuity

__version__ = importlib.metadata.version("diffrax")
11 changes: 11 additions & 0 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def loop(
solver,
stepsize_controller,
event,
delays,
saveat,
t0,
t1,
Expand All @@ -137,6 +138,7 @@ def loop(
passed_solver_state,
passed_controller_state,
progress_meter,
y0_history,
) -> Any:
"""Runs the main solve loop. Subclasses can override this to provide custom
backpropagation behaviour; see for example the implementation of
Expand Down Expand Up @@ -578,6 +580,7 @@ def _loop_backsolve_bwd(
solver,
stepsize_controller,
event,
delays,
saveat,
t0,
t1,
Expand All @@ -586,6 +589,7 @@ def _loop_backsolve_bwd(
throw,
init_state,
progress_meter,
y0_history,
):
assert event is None

Expand Down Expand Up @@ -623,6 +627,7 @@ def _loop_backsolve_bwd(
adjoint=self,
solver=solver,
stepsize_controller=stepsize_controller,
delays=delays,
terms=adjoint_terms,
dt0=None if dt0 is None else -dt0,
max_steps=max_steps,
Expand Down Expand Up @@ -804,6 +809,7 @@ def loop(
passed_solver_state,
passed_controller_state,
event,
delays,
**kwargs,
):
if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
Expand Down Expand Up @@ -849,6 +855,10 @@ def loop(
raise NotImplementedError(
"`diffrax.BacksolveAdjoint` is not compatible with events."
)
if delays is not None:
raise NotImplementedError(
"Cannot use `delays` with `adjoint=BacksolveAdjoint()`"
)

y = init_state.y
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
Expand All @@ -863,6 +873,7 @@ def loop(
init_state=init_state,
solver=solver,
event=event,
delays=delays,
**kwargs,
)
final_state = _only_transpose_ys(final_state)
Expand Down
Loading
Loading