Skip to content

Commit c7420bd

Browse files
Added strict=True
1 parent 0ee47c9 commit c7420bd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+208
-107
lines changed

diffrax/_adjoint.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def get_ys(_final_state):
107107
return final_state
108108

109109

110-
class AbstractAdjoint(eqx.Module):
110+
class AbstractAdjoint(eqx.Module, strict=True):
111111
"""Abstract base class for all adjoint methods."""
112112

113113
@abc.abstractmethod
@@ -167,7 +167,7 @@ def _uncallable(*args, **kwargs):
167167
assert False
168168

169169

170-
class RecursiveCheckpointAdjoint(AbstractAdjoint):
170+
class RecursiveCheckpointAdjoint(AbstractAdjoint, strict=True):
171171
"""Backpropagate through [`diffrax.diffeqsolve`][] by differentiating the numerical
172172
solution directly. This is sometimes known as "discretise-then-optimise", or
173173
described as "backpropagation through the solver".
@@ -318,7 +318,7 @@ def loop(
318318
"""
319319

320320

321-
class DirectAdjoint(AbstractAdjoint):
321+
class DirectAdjoint(AbstractAdjoint, strict=True):
322322
"""A variant of [`diffrax.RecursiveCheckpointAdjoint`][]. The differences are that
323323
`DirectAdjoint`:
324324
@@ -434,7 +434,7 @@ def _frozenset(x: Union[object, Iterable[object]]) -> frozenset[object]:
434434
return frozenset(iter_x)
435435

436436

437-
class ImplicitAdjoint(AbstractAdjoint):
437+
class ImplicitAdjoint(AbstractAdjoint, strict=True):
438438
r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem).
439439
440440
This is used when solving towards a steady state, typically using
@@ -705,7 +705,7 @@ def __get(__aug):
705705
return a_y1, a_diff_args1, a_diff_terms1
706706

707707

708-
class BacksolveAdjoint(AbstractAdjoint):
708+
class BacksolveAdjoint(AbstractAdjoint, strict=True):
709709
"""Backpropagate through [`diffrax.diffeqsolve`][] by solving the continuous
710710
adjoint equations backwards-in-time. This is also sometimes known as
711711
"optimise-then-discretise", the "continuous adjoint method" or simply the "adjoint

diffrax/_brownian/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .._path import AbstractPath
99

1010

11-
class AbstractBrownianPath(AbstractPath):
11+
class AbstractBrownianPath(AbstractPath, strict=True):
1212
"""Abstract base class for all Brownian paths."""
1313

1414
levy_area: AbstractVar[LevyArea]

diffrax/_brownian/path.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .base import AbstractBrownianPath
2020

2121

22-
class UnsafeBrownianPath(AbstractBrownianPath):
22+
class UnsafeBrownianPath(AbstractBrownianPath, strict=True):
2323
"""Brownian simulation that is only suitable for certain cases.
2424
2525
This is a very quick way to simulate Brownian motion, but can only be used when all

diffrax/_brownian/tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
_Spline: TypeAlias = Literal["sqrt", "quad", "zero"]
6161

6262

63-
class _State(eqx.Module):
63+
class _State(eqx.Module, strict=True):
6464
level: IntScalarLike # level of the tree
6565
s: RealScalarLike # starting time of the interval
6666
w_s_u_su: FloatTriple # W_s, W_u, W_{s,u}
@@ -109,7 +109,7 @@ def _split_interval(
109109
return x_s, x_u, x_su
110110

111111

112-
class VirtualBrownianTree(AbstractBrownianPath):
112+
class VirtualBrownianTree(AbstractBrownianPath, strict=True):
113113
"""Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`.
114114
115115
Can be initialised with `levy_area` set to `""`, or `"space-time"`.

diffrax/_custom_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
LevyArea: TypeAlias = Literal["", "space-time"]
5757

5858

59-
class LevyVal(eqx.Module):
59+
class LevyVal(eqx.Module, strict=True):
6060
dt: PyTree
6161
W: PyTree
6262
H: Optional[PyTree]

diffrax/_event.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ._step_size_controller import AbstractAdaptiveStepSizeController
1111

1212

13-
class AbstractDiscreteTerminatingEvent(eqx.Module):
13+
class AbstractDiscreteTerminatingEvent(eqx.Module, strict=True):
1414
"""Evaluated at the end of each integration step. If true then the solve is stopped
1515
at that time.
1616
"""
@@ -30,7 +30,7 @@ def __call__(self, state, **kwargs) -> BoolScalarLike:
3030
"""
3131

3232

33-
class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent):
33+
class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent, strict=True):
3434
"""Terminates the solve if its condition is ever active."""
3535

3636
cond_fn: Callable[..., BoolScalarLike]
@@ -50,7 +50,7 @@ def __call__(self, state, **kwargs):
5050
"""
5151

5252

53-
class SteadyStateEvent(AbstractDiscreteTerminatingEvent):
53+
class SteadyStateEvent(AbstractDiscreteTerminatingEvent, strict=True):
5454
"""Terminates the solve once it reaches a steady state."""
5555

5656
rtol: Optional[float] = None

diffrax/_global_interpolation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from ._path import AbstractPath
2424

2525

26-
class AbstractGlobalInterpolation(AbstractPath):
26+
class AbstractGlobalInterpolation(AbstractPath, strict=True):
2727
ts: AbstractVar[Real[Array, " times"]]
2828
ts_size: AbstractVar[IntScalarLike]
2929

@@ -52,7 +52,7 @@ def t1(self):
5252
return self.ts[-1]
5353

5454

55-
class LinearInterpolation(AbstractGlobalInterpolation):
55+
class LinearInterpolation(AbstractGlobalInterpolation, strict=True):
5656
"""Linearly interpolates some data `ys` over the interval $[t_0, t_1]$ with knots
5757
at `ts`.
5858
@@ -178,7 +178,7 @@ def derivative(self, t: RealScalarLike, left: bool = True) -> PyTree[Array]:
178178
"""
179179

180180

181-
class CubicInterpolation(AbstractGlobalInterpolation):
181+
class CubicInterpolation(AbstractGlobalInterpolation, strict=True):
182182
"""Piecewise cubic spline interpolation over the interval $[t_0, t_1]$."""
183183

184184
ts: Real[Array, " times"]
@@ -302,7 +302,7 @@ def derivative(
302302
"""
303303

304304

305-
class DenseInterpolation(AbstractGlobalInterpolation):
305+
class DenseInterpolation(AbstractGlobalInterpolation, strict=True):
306306
ts: Real[Array, " times"]
307307
# DenseInterpolations typically get `ts` and `infos` that are way longer than they
308308
# need to be, and padded with `nan`s. This means the normal way of measuring how

diffrax/_integrate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,14 @@
4747
from ._term import AbstractTerm, MultiTerm, ODETerm, WrapTerm
4848

4949

50-
class SaveState(eqx.Module):
50+
class SaveState(eqx.Module, strict=True):
5151
saveat_ts_index: IntScalarLike
5252
ts: eqxi.MaybeBuffer[Real[Array, " times"]]
5353
ys: PyTree[eqxi.MaybeBuffer[Inexact[Array, "times ..."]]]
5454
save_index: IntScalarLike
5555

5656

57-
class State(eqx.Module):
57+
class State(eqx.Module, strict=True):
5858
# Evolving state during the solve
5959
y: PyTree[Array]
6060
tprev: FloatScalarLike

diffrax/_local_interpolation.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional, TYPE_CHECKING
22

3+
import equinox as eqx
34
import jax.numpy as jnp
45
import jax.tree_util as jtu
56
import numpy as np
@@ -17,11 +18,11 @@
1718
from ._path import AbstractPath
1819

1920

20-
class AbstractLocalInterpolation(AbstractPath):
21+
class AbstractLocalInterpolation(AbstractPath, strict=True):
2122
pass
2223

2324

24-
class LocalLinearInterpolation(AbstractLocalInterpolation):
25+
class LocalLinearInterpolation(AbstractLocalInterpolation, strict=True):
2526
t0: RealScalarLike
2627
t1: RealScalarLike
2728
y0: Y
@@ -39,7 +40,7 @@ def evaluate(
3940
return (coeff * (self.y1**ω - self.y0**ω)).call(jnp.asarray).ω
4041

4142

42-
class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation):
43+
class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation, strict=True):
4344
t0: RealScalarLike
4445
t1: RealScalarLike
4546
coeffs: PyTree[Shaped[Array, "4 ?*dims"], "Y"]
@@ -83,7 +84,9 @@ def _eval(_coeffs):
8384
return jtu.tree_map(_eval, self.coeffs)
8485

8586

86-
class FourthOrderPolynomialInterpolation(AbstractLocalInterpolation):
87+
class FourthOrderPolynomialInterpolation(
88+
AbstractLocalInterpolation, strict=eqx.StrictConfig(allow_abstract_name=True)
89+
):
8790
t0: RealScalarLike
8891
t1: RealScalarLike
8992
coeffs: PyTree[Shaped[Array, "5 ?*y"], "Y"]

diffrax/_path.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ._custom_types import RealScalarLike
1616

1717

18-
class AbstractPath(eqx.Module):
18+
class AbstractPath(eqx.Module, strict=True):
1919
"""Abstract base class for all paths.
2020
2121
Every path has a start point `t0` and an end point `t1`. In between these values

diffrax/_root_finder/_verychord.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _converged(factor: Scalar, tol: float) -> Bool[Array, ""]:
3030
return (factor > 0) & (factor < tol)
3131

3232

33-
class _VeryChordState(eqx.Module):
33+
class _VeryChordState(eqx.Module, strict=True):
3434
linear_state: tuple[lx.AbstractLinearOperator, PyTree[Any]]
3535
diff: Y
3636
diffsize: Scalar
@@ -39,7 +39,7 @@ class _VeryChordState(eqx.Module):
3939
step: Scalar
4040

4141

42-
class _NoAux(eqx.Module):
42+
class _NoAux(eqx.Module, strict=True):
4343
fn: Callable
4444

4545
def __call__(self, y, args):
@@ -48,7 +48,7 @@ def __call__(self, y, args):
4848
return out
4949

5050

51-
class VeryChord(optx.AbstractRootFinder):
51+
class VeryChord(optx.AbstractRootFinder, strict=True):
5252
"""The Chord method of root finding.
5353
5454
As `optimistix.Chord`, except that in Runge--Kutta methods, the linearisation point

diffrax/_saveat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def _convert_ts(
2121
return jnp.asarray(ts)
2222

2323

24-
class SubSaveAt(eqx.Module):
24+
class SubSaveAt(eqx.Module, strict=True):
2525
"""Used for finer-grained control over what is saved. A PyTree of these should be
2626
passed to `SaveAt(subs=...)`.
2727
@@ -53,7 +53,7 @@ def __check_init__(self):
5353
"""
5454

5555

56-
class SaveAt(eqx.Module):
56+
class SaveAt(eqx.Module, strict=True):
5757
"""Determines what to save as output from the differential equation solve.
5858
5959
Instances of this class should be passed as the `saveat` argument of

diffrax/_solution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Optional
22

3+
import equinox as eqx
34
import jax
45
import optimistix as optx
56
from jaxtyping import Array, Bool, PyTree, Real, Shaped
@@ -55,7 +56,7 @@ def update_result(old_result: RESULTS, new_result: RESULTS) -> RESULTS:
5556
return RESULTS.where(pred, old_result, out_result)
5657

5758

58-
class Solution(AbstractPath):
59+
class Solution(AbstractPath, strict=eqx.StrictConfig(allow_method_override=True)):
5960
"""The solution to a differential equation.
6061
6162
**Attributes:**

diffrax/_solver/base.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __instancecheck__(cls, obj):
4949
_set_metaclass = dict(metaclass=_MetaAbstractSolver)
5050

5151

52-
class AbstractSolver(eqx.Module, Generic[_SolverState], **_set_metaclass):
52+
class AbstractSolver(eqx.Module, Generic[_SolverState], strict=True, **_set_metaclass):
5353
"""Abstract base class for all differential equation solvers.
5454
5555
Subclasses should have a class-level attribute `terms`, specifying the PyTree
@@ -179,7 +179,7 @@ def func(
179179
"""
180180

181181

182-
class AbstractImplicitSolver(AbstractSolver[_SolverState]):
182+
class AbstractImplicitSolver(AbstractSolver[_SolverState], strict=True):
183183
"""Indicates that this is an implicit differential equation solver, and as such
184184
that it should take a root finder as an argument.
185185
"""
@@ -188,25 +188,25 @@ class AbstractImplicitSolver(AbstractSolver[_SolverState]):
188188
root_find_max_steps: AbstractVar[int]
189189

190190

191-
class AbstractItoSolver(AbstractSolver[_SolverState]):
191+
class AbstractItoSolver(AbstractSolver[_SolverState], strict=True):
192192
"""Indicates that when used as an SDE solver that this solver will converge to the
193193
Itô solution.
194194
"""
195195

196196

197-
class AbstractStratonovichSolver(AbstractSolver[_SolverState]):
197+
class AbstractStratonovichSolver(AbstractSolver[_SolverState], strict=True):
198198
"""Indicates that when used as an SDE solver that this solver will converge to the
199199
Stratonovich solution.
200200
"""
201201

202202

203-
class AbstractAdaptiveSolver(AbstractSolver[_SolverState]):
203+
class AbstractAdaptiveSolver(AbstractSolver[_SolverState], strict=True):
204204
"""Indicates that this solver provides error estimates, and that as such it may be
205205
used with an adaptive step size controller.
206206
"""
207207

208208

209-
class AbstractWrappedSolver(AbstractSolver[_SolverState]):
209+
class AbstractWrappedSolver(AbstractSolver[_SolverState], strict=True):
210210
"""Wraps another solver "transparently", in the sense that all `isinstance` checks
211211
will be forwarded on to the wrapped solver, e.g. when testing whether the solver is
212212
implicit/adaptive/SDE-compatible/etc.
@@ -219,7 +219,9 @@ class if that is not desired behaviour.)
219219

220220

221221
class HalfSolver(
222-
AbstractAdaptiveSolver[_SolverState], AbstractWrappedSolver[_SolverState]
222+
AbstractAdaptiveSolver[_SolverState],
223+
AbstractWrappedSolver[_SolverState],
224+
strict=eqx.StrictConfig(allow_method_override=True),
223225
):
224226
"""Wraps another solver, trading cost in order to provide error estimates. (That
225227
is, it means the solver can be used with an adaptive step size controller,

diffrax/_solver/bosh3.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections.abc import Callable
2-
from typing import ClassVar
2+
from typing import ClassVar, Literal, Union
33

4+
import equinox as eqx
45
import numpy as np
56

67
from .._local_interpolation import ThirdOrderHermitePolynomialInterpolation
@@ -19,7 +20,7 @@
1920
)
2021

2122

22-
class Bosh3(AbstractERK):
23+
class Bosh3(AbstractERK, strict=eqx.StrictConfig(allow_method_override=True)):
2324
"""Bogacki--Shampine's 3/2 method.
2425
2526
3rd order explicit Runge--Kutta method. Has an embedded 2nd order method for
@@ -29,6 +30,8 @@ class Bosh3(AbstractERK):
2930
Also sometimes known as "Ralston's third order method".
3031
"""
3132

33+
scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None
34+
3235
tableau: ClassVar[ButcherTableau] = _bosh3_tableau
3336
interpolation_cls: ClassVar[
3437
Callable[..., ThirdOrderHermitePolynomialInterpolation]

0 commit comments

Comments
 (0)