Skip to content

Commit 7965e89

Browse files
Fix a crash and an unclear error message.
1. Fixes a spurious crash when using an implicit solver with DirectAdjoint. 2. Fixes the unclear error message when using an implicit solver without an adaptive step size controller.
1 parent e240aea commit 7965e89

File tree

5 files changed

+77
-24
lines changed

5 files changed

+77
-24
lines changed

diffrax/_adjoint.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,9 @@ def loop(
368368
# Support forward-mode autodiff.
369369
# TODO: remove this hack once we can JVP through custom_vjps.
370370
if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
371-
solver = eqx.tree_at(lambda s: s.scan_kind, solver, "bounded")
371+
solver = eqx.tree_at(
372+
lambda s: s.scan_kind, solver, "bounded", is_leaf=_is_none
373+
)
372374
inner_while_loop = ft.partial(_inner_loop, kind=kind)
373375
outer_while_loop = ft.partial(_outer_loop, kind=kind)
374376
final_state = self._loop(

diffrax/_integrate.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import typing
33
import warnings
44
from collections.abc import Callable
5-
from typing import Any, cast, get_args, get_origin, Optional, Tuple, TYPE_CHECKING
5+
from typing import Any, get_args, get_origin, Optional, Tuple, TYPE_CHECKING
66

77
import equinox as eqx
88
import equinox.internal as eqxi
@@ -736,27 +736,44 @@ def _wrap(term):
736736
is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm),
737737
)
738738

739-
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
740-
if isinstance(solver, AbstractImplicitSolver):
741-
if solver.root_finder.rtol is use_stepsize_tol:
742-
solver = eqx.tree_at(
743-
lambda s: s.root_finder.rtol,
744-
solver,
745-
stepsize_controller.rtol,
746-
)
747-
solver = cast(AbstractImplicitSolver, solver)
748-
if solver.root_finder.atol is use_stepsize_tol:
749-
solver = eqx.tree_at(
750-
lambda s: s.root_finder.atol,
751-
solver,
752-
stepsize_controller.atol,
753-
)
754-
solver = cast(AbstractImplicitSolver, solver)
755-
if solver.root_finder.norm is use_stepsize_tol:
756-
solver = eqx.tree_at(
757-
lambda s: s.root_finder.norm,
758-
solver,
759-
stepsize_controller.norm,
739+
if isinstance(solver, AbstractImplicitSolver):
740+
741+
def _get_tols(x):
742+
outs = []
743+
for attr in ("rtol", "atol", "norm"):
744+
if getattr(solver.root_finder, attr) is use_stepsize_tol: # pyright: ignore
745+
outs.append(getattr(x, attr))
746+
return tuple(outs)
747+
748+
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
749+
solver = eqx.tree_at(
750+
lambda s: _get_tols(s.root_finder),
751+
solver,
752+
_get_tols(stepsize_controller),
753+
)
754+
else:
755+
if len(_get_tols(solver.root_finder)) > 0:
756+
raise ValueError(
757+
"A fixed step size controller is being used alongside an implicit "
758+
"solver, but the tolerances for the implicit solver have not been "
759+
"specified. (Being unspecified is the default in Diffrax.)\n"
760+
"The correct fix is almost always to use an adaptive step size "
761+
"controller. For example "
762+
"`diffrax.diffeqsolve(..., "
763+
"stepsize_controller=diffrax.PIDController(rtol=..., atol=...))`. "
764+
"In this case the same tolerances are used for the implicit "
765+
"solver as are used to control the adaptive stepping.\n"
766+
"(Note for advanced users: the tolerances for the implicit "
767+
"solver can also be explicitly set instead. For example "
768+
"`diffrax.diffeqsolve(..., solver=diffrax.Kvaerno5(root_finder="
769+
"diffrax.VeryChord(rtol=..., atol=..., "
770+
"norm=optimistix.max_norm)))`. In this case the norm must also be "
771+
"explicitly specified.)\n"
772+
"Adaptive step size controllers are the preferred solution, as "
773+
"sometimes the implicit solver may fail to converge, and in this "
774+
"case an adaptive step size controller can reject the step and try "
775+
"a smaller one, whilst with a fixed step size controller the "
776+
"overall differential equation solve will simply fail."
760777
)
761778

762779
# Error checking

diffrax/_root_finder/_with_tols.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,15 @@
33
import optimistix as optx
44

55

6-
use_stepsize_tol = object()
6+
class _UseStepSizeTol:
7+
def __repr__(self):
8+
return (
9+
"<tolerance taken from `diffeqsolve(..., stepsize_controller=...)` "
10+
"argument>"
11+
)
12+
13+
14+
use_stepsize_tol = _UseStepSizeTol()
715

816

917
def with_stepsize_controller_tols(cls: type[optx.AbstractRootFinder]):

test/test_adjoint.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,3 +363,16 @@ def run(y0__args, adjoint):
363363
grads3 = run((y0, args), diffrax.RecursiveCheckpointAdjoint())
364364
assert tree_allclose(grads1, grads2, rtol=1e-3, atol=1e-3)
365365
assert tree_allclose(grads1, grads3, rtol=1e-3, atol=1e-3)
366+
367+
368+
def test_implicit_runge_kutta_direct_adjoint():
369+
diffrax.diffeqsolve(
370+
diffrax.ODETerm(lambda t, y, args: -y),
371+
diffrax.Kvaerno5(),
372+
0,
373+
1,
374+
0.01,
375+
1.0,
376+
adjoint=diffrax.DirectAdjoint(),
377+
stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
378+
)

test/test_integrate.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,3 +504,16 @@ def vector_field(t, y, args):
504504
assert text == "static_made_jump=False static_result=None\n"
505505
finally:
506506
diffrax._integrate._PRINT_STATIC = False
507+
508+
509+
def test_implicit_tol_error():
510+
msg = "the tolerances for the implicit solver have not been specified"
511+
with pytest.raises(ValueError, match=msg):
512+
diffrax.diffeqsolve(
513+
diffrax.ODETerm(lambda t, y, args: -y),
514+
diffrax.Kvaerno5(),
515+
0,
516+
1,
517+
0.01,
518+
1.0,
519+
)

0 commit comments

Comments
 (0)