|
2 | 2 | import typing
|
3 | 3 | import warnings
|
4 | 4 | from collections.abc import Callable
|
5 |
| -from typing import Any, cast, get_args, get_origin, Optional, Tuple, TYPE_CHECKING |
| 5 | +from typing import Any, get_args, get_origin, Optional, Tuple, TYPE_CHECKING |
6 | 6 |
|
7 | 7 | import equinox as eqx
|
8 | 8 | import equinox.internal as eqxi
|
@@ -736,27 +736,44 @@ def _wrap(term):
|
736 | 736 | is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm),
|
737 | 737 | )
|
738 | 738 |
|
739 |
| - if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): |
740 |
| - if isinstance(solver, AbstractImplicitSolver): |
741 |
| - if solver.root_finder.rtol is use_stepsize_tol: |
742 |
| - solver = eqx.tree_at( |
743 |
| - lambda s: s.root_finder.rtol, |
744 |
| - solver, |
745 |
| - stepsize_controller.rtol, |
746 |
| - ) |
747 |
| - solver = cast(AbstractImplicitSolver, solver) |
748 |
| - if solver.root_finder.atol is use_stepsize_tol: |
749 |
| - solver = eqx.tree_at( |
750 |
| - lambda s: s.root_finder.atol, |
751 |
| - solver, |
752 |
| - stepsize_controller.atol, |
753 |
| - ) |
754 |
| - solver = cast(AbstractImplicitSolver, solver) |
755 |
| - if solver.root_finder.norm is use_stepsize_tol: |
756 |
| - solver = eqx.tree_at( |
757 |
| - lambda s: s.root_finder.norm, |
758 |
| - solver, |
759 |
| - stepsize_controller.norm, |
| 739 | + if isinstance(solver, AbstractImplicitSolver): |
| 740 | + |
| 741 | + def _get_tols(x): |
| 742 | + outs = [] |
| 743 | + for attr in ("rtol", "atol", "norm"): |
| 744 | + if getattr(solver.root_finder, attr) is use_stepsize_tol: # pyright: ignore |
| 745 | + outs.append(getattr(x, attr)) |
| 746 | + return tuple(outs) |
| 747 | + |
| 748 | + if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): |
| 749 | + solver = eqx.tree_at( |
| 750 | + lambda s: _get_tols(s.root_finder), |
| 751 | + solver, |
| 752 | + _get_tols(stepsize_controller), |
| 753 | + ) |
| 754 | + else: |
| 755 | + if len(_get_tols(solver.root_finder)) > 0: |
| 756 | + raise ValueError( |
| 757 | + "A fixed step size controller is being used alongside an implicit " |
| 758 | + "solver, but the tolerances for the implicit solver have not been " |
| 759 | + "specified. (Being unspecified is the default in Diffrax.)\n" |
| 760 | + "The correct fix is almost always to use an adaptive step size " |
| 761 | + "controller. For example " |
| 762 | + "`diffrax.diffeqsolve(..., " |
| 763 | + "stepsize_controller=diffrax.PIDController(rtol=..., atol=...))`. " |
| 764 | + "In this case the same tolerances are used for the implicit " |
| 765 | + "solver as are used to control the adaptive stepping.\n" |
| 766 | + "(Note for advanced users: the tolerances for the implicit " |
| 767 | + "solver can also be explicitly set instead. For example " |
| 768 | + "`diffrax.diffeqsolve(..., solver=diffrax.Kvaerno5(root_finder=" |
| 769 | + "diffrax.VeryChord(rtol=..., atol=..., " |
| 770 | + "norm=optimistix.max_norm)))`. In this case the norm must also be " |
| 771 | + "explicitly specified.)\n" |
| 772 | + "Adaptive step size controllers are the preferred solution, as " |
| 773 | + "sometimes the implicit solver may fail to converge, and in this " |
| 774 | + "case an adaptive step size controller can reject the step and try " |
| 775 | + "a smaller one, whilst with a fixed step size controller the " |
| 776 | + "overall differential equation solve will simply fail." |
760 | 777 | )
|
761 | 778 |
|
762 | 779 | # Error checking
|
|
0 commit comments