Skip to content

Commit ce8c1d4

Browse files
Added strict=True
1 parent 0ee47c9 commit ce8c1d4

26 files changed

+79
-69
lines changed

diffrax/_adjoint.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def get_ys(_final_state):
107107
return final_state
108108

109109

110-
class AbstractAdjoint(eqx.Module):
110+
class AbstractAdjoint(eqx.Module, strict=True):
111111
"""Abstract base class for all adjoint methods."""
112112

113113
@abc.abstractmethod
@@ -167,7 +167,7 @@ def _uncallable(*args, **kwargs):
167167
assert False
168168

169169

170-
class RecursiveCheckpointAdjoint(AbstractAdjoint):
170+
class RecursiveCheckpointAdjoint(AbstractAdjoint, strict=True):
171171
"""Backpropagate through [`diffrax.diffeqsolve`][] by differentiating the numerical
172172
solution directly. This is sometimes known as "discretise-then-optimise", or
173173
described as "backpropagation through the solver".
@@ -318,7 +318,7 @@ def loop(
318318
"""
319319

320320

321-
class DirectAdjoint(AbstractAdjoint):
321+
class DirectAdjoint(AbstractAdjoint, strict=True):
322322
"""A variant of [`diffrax.RecursiveCheckpointAdjoint`][]. The differences are that
323323
`DirectAdjoint`:
324324
@@ -434,7 +434,7 @@ def _frozenset(x: Union[object, Iterable[object]]) -> frozenset[object]:
434434
return frozenset(iter_x)
435435

436436

437-
class ImplicitAdjoint(AbstractAdjoint):
437+
class ImplicitAdjoint(AbstractAdjoint, strict=True):
438438
r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem).
439439
440440
This is used when solving towards a steady state, typically using
@@ -705,7 +705,7 @@ def __get(__aug):
705705
return a_y1, a_diff_args1, a_diff_terms1
706706

707707

708-
class BacksolveAdjoint(AbstractAdjoint):
708+
class BacksolveAdjoint(AbstractAdjoint, strict=True):
709709
"""Backpropagate through [`diffrax.diffeqsolve`][] by solving the continuous
710710
adjoint equations backwards-in-time. This is also sometimes known as
711711
"optimise-then-discretise", the "continuous adjoint method" or simply the "adjoint

diffrax/_brownian/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .._path import AbstractPath
99

1010

11-
class AbstractBrownianPath(AbstractPath):
11+
class AbstractBrownianPath(AbstractPath, strict=True):
1212
"""Abstract base class for all Brownian paths."""
1313

1414
levy_area: AbstractVar[LevyArea]

diffrax/_brownian/path.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .base import AbstractBrownianPath
2020

2121

22-
class UnsafeBrownianPath(AbstractBrownianPath):
22+
class UnsafeBrownianPath(AbstractBrownianPath, strict=True):
2323
"""Brownian simulation that is only suitable for certain cases.
2424
2525
This is a very quick way to simulate Brownian motion, but can only be used when all

diffrax/_brownian/tree.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
_Spline: TypeAlias = Literal["sqrt", "quad", "zero"]
6161

6262

63-
class _State(eqx.Module):
63+
class _State(eqx.Module, strict=True):
6464
level: IntScalarLike # level of the tree
6565
s: RealScalarLike # starting time of the interval
6666
w_s_u_su: FloatTriple # W_s, W_u, W_{s,u}
@@ -109,7 +109,7 @@ def _split_interval(
109109
return x_s, x_u, x_su
110110

111111

112-
class VirtualBrownianTree(AbstractBrownianPath):
112+
class VirtualBrownianTree(AbstractBrownianPath, strict=True):
113113
"""Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`.
114114
115115
Can be initialised with `levy_area` set to `""`, or `"space-time"`.

diffrax/_custom_types.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
LevyArea: TypeAlias = Literal["", "space-time"]
5757

5858

59-
class LevyVal(eqx.Module):
59+
class LevyVal(eqx.Module, strict=True):
6060
dt: PyTree
6161
W: PyTree
6262
H: Optional[PyTree]

diffrax/_event.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ._step_size_controller import AbstractAdaptiveStepSizeController
1111

1212

13-
class AbstractDiscreteTerminatingEvent(eqx.Module):
13+
class AbstractDiscreteTerminatingEvent(eqx.Module, strict=True):
1414
"""Evaluated at the end of each integration step. If true then the solve is stopped
1515
at that time.
1616
"""
@@ -30,7 +30,7 @@ def __call__(self, state, **kwargs) -> BoolScalarLike:
3030
"""
3131

3232

33-
class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent):
33+
class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent, strict=True):
3434
"""Terminates the solve if its condition is ever active."""
3535

3636
cond_fn: Callable[..., BoolScalarLike]
@@ -50,7 +50,7 @@ def __call__(self, state, **kwargs):
5050
"""
5151

5252

53-
class SteadyStateEvent(AbstractDiscreteTerminatingEvent):
53+
class SteadyStateEvent(AbstractDiscreteTerminatingEvent, strict=True):
5454
"""Terminates the solve once it reaches a steady state."""
5555

5656
rtol: Optional[float] = None

diffrax/_global_interpolation.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from ._path import AbstractPath
2424

2525

26-
class AbstractGlobalInterpolation(AbstractPath):
26+
class AbstractGlobalInterpolation(AbstractPath, strict=True):
2727
ts: AbstractVar[Real[Array, " times"]]
2828
ts_size: AbstractVar[IntScalarLike]
2929

@@ -52,7 +52,7 @@ def t1(self):
5252
return self.ts[-1]
5353

5454

55-
class LinearInterpolation(AbstractGlobalInterpolation):
55+
class LinearInterpolation(AbstractGlobalInterpolation, strict=True):
5656
"""Linearly interpolates some data `ys` over the interval $[t_0, t_1]$ with knots
5757
at `ts`.
5858
@@ -178,7 +178,7 @@ def derivative(self, t: RealScalarLike, left: bool = True) -> PyTree[Array]:
178178
"""
179179

180180

181-
class CubicInterpolation(AbstractGlobalInterpolation):
181+
class CubicInterpolation(AbstractGlobalInterpolation, strict=True):
182182
"""Piecewise cubic spline interpolation over the interval $[t_0, t_1]$."""
183183

184184
ts: Real[Array, " times"]
@@ -302,7 +302,7 @@ def derivative(
302302
"""
303303

304304

305-
class DenseInterpolation(AbstractGlobalInterpolation):
305+
class DenseInterpolation(AbstractGlobalInterpolation, strict=True):
306306
ts: Real[Array, " times"]
307307
# DenseInterpolations typically get `ts` and `infos` that are way longer than they
308308
# need to be, and padded with `nan`s. This means the normal way of measuring how

diffrax/_integrate.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,14 @@
4747
from ._term import AbstractTerm, MultiTerm, ODETerm, WrapTerm
4848

4949

50-
class SaveState(eqx.Module):
50+
class SaveState(eqx.Module, strict=True):
5151
saveat_ts_index: IntScalarLike
5252
ts: eqxi.MaybeBuffer[Real[Array, " times"]]
5353
ys: PyTree[eqxi.MaybeBuffer[Inexact[Array, "times ..."]]]
5454
save_index: IntScalarLike
5555

5656

57-
class State(eqx.Module):
57+
class State(eqx.Module, strict=True):
5858
# Evolving state during the solve
5959
y: PyTree[Array]
6060
tprev: FloatScalarLike

diffrax/_local_interpolation.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional, TYPE_CHECKING
22

3+
import equinox as eqx
34
import jax.numpy as jnp
45
import jax.tree_util as jtu
56
import numpy as np
@@ -17,11 +18,11 @@
1718
from ._path import AbstractPath
1819

1920

20-
class AbstractLocalInterpolation(AbstractPath):
21+
class AbstractLocalInterpolation(AbstractPath, strict=True):
2122
pass
2223

2324

24-
class LocalLinearInterpolation(AbstractLocalInterpolation):
25+
class LocalLinearInterpolation(AbstractLocalInterpolation, strict=True):
2526
t0: RealScalarLike
2627
t1: RealScalarLike
2728
y0: Y
@@ -39,7 +40,7 @@ def evaluate(
3940
return (coeff * (self.y1**ω - self.y0**ω)).call(jnp.asarray).ω
4041

4142

42-
class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation):
43+
class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation, strict=True):
4344
t0: RealScalarLike
4445
t1: RealScalarLike
4546
coeffs: PyTree[Shaped[Array, "4 ?*dims"], "Y"]
@@ -83,7 +84,9 @@ def _eval(_coeffs):
8384
return jtu.tree_map(_eval, self.coeffs)
8485

8586

86-
class FourthOrderPolynomialInterpolation(AbstractLocalInterpolation):
87+
class FourthOrderPolynomialInterpolation(
88+
AbstractLocalInterpolation, strict=eqx.StrictConfig(allow_abstract_name=True)
89+
):
8790
t0: RealScalarLike
8891
t1: RealScalarLike
8992
coeffs: PyTree[Shaped[Array, "5 ?*y"], "Y"]

diffrax/_path.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ._custom_types import RealScalarLike
1616

1717

18-
class AbstractPath(eqx.Module):
18+
class AbstractPath(eqx.Module, strict=True):
1919
"""Abstract base class for all paths.
2020
2121
Every path has a start point `t0` and an end point `t1`. In between these values

diffrax/_root_finder/_verychord.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _converged(factor: Scalar, tol: float) -> Bool[Array, ""]:
3030
return (factor > 0) & (factor < tol)
3131

3232

33-
class _VeryChordState(eqx.Module):
33+
class _VeryChordState(eqx.Module, strict=True):
3434
linear_state: tuple[lx.AbstractLinearOperator, PyTree[Any]]
3535
diff: Y
3636
diffsize: Scalar
@@ -39,7 +39,7 @@ class _VeryChordState(eqx.Module):
3939
step: Scalar
4040

4141

42-
class _NoAux(eqx.Module):
42+
class _NoAux(eqx.Module, strict=True):
4343
fn: Callable
4444

4545
def __call__(self, y, args):
@@ -48,7 +48,7 @@ def __call__(self, y, args):
4848
return out
4949

5050

51-
class VeryChord(optx.AbstractRootFinder):
51+
class VeryChord(optx.AbstractRootFinder, strict=True):
5252
"""The Chord method of root finding.
5353
5454
As `optimistix.Chord`, except that in Runge--Kutta methods, the linearisation point

diffrax/_saveat.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def _convert_ts(
2121
return jnp.asarray(ts)
2222

2323

24-
class SubSaveAt(eqx.Module):
24+
class SubSaveAt(eqx.Module, strict=True):
2525
"""Used for finer-grained control over what is saved. A PyTree of these should be
2626
passed to `SaveAt(subs=...)`.
2727
@@ -53,7 +53,7 @@ def __check_init__(self):
5353
"""
5454

5555

56-
class SaveAt(eqx.Module):
56+
class SaveAt(eqx.Module, strict=True):
5757
"""Determines what to save as output from the differential equation solve.
5858
5959
Instances of this class should be passed as the `saveat` argument of

diffrax/_solution.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Optional
22

3+
import equinox as eqx
34
import jax
45
import optimistix as optx
56
from jaxtyping import Array, Bool, PyTree, Real, Shaped
@@ -55,7 +56,7 @@ def update_result(old_result: RESULTS, new_result: RESULTS) -> RESULTS:
5556
return RESULTS.where(pred, old_result, out_result)
5657

5758

58-
class Solution(AbstractPath):
59+
class Solution(AbstractPath, strict=eqx.StrictConfig(allow_method_override=True)):
5960
"""The solution to a differential equation.
6061
6162
**Attributes:**

diffrax/_solver/base.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __instancecheck__(cls, obj):
4949
_set_metaclass = dict(metaclass=_MetaAbstractSolver)
5050

5151

52-
class AbstractSolver(eqx.Module, Generic[_SolverState], **_set_metaclass):
52+
class AbstractSolver(eqx.Module, Generic[_SolverState], strict=True, **_set_metaclass):
5353
"""Abstract base class for all differential equation solvers.
5454
5555
Subclasses should have a class-level attribute `terms`, specifying the PyTree
@@ -179,7 +179,7 @@ def func(
179179
"""
180180

181181

182-
class AbstractImplicitSolver(AbstractSolver[_SolverState]):
182+
class AbstractImplicitSolver(AbstractSolver[_SolverState], strict=True):
183183
"""Indicates that this is an implicit differential equation solver, and as such
184184
that it should take a root finder as an argument.
185185
"""
@@ -188,25 +188,25 @@ class AbstractImplicitSolver(AbstractSolver[_SolverState]):
188188
root_find_max_steps: AbstractVar[int]
189189

190190

191-
class AbstractItoSolver(AbstractSolver[_SolverState]):
191+
class AbstractItoSolver(AbstractSolver[_SolverState], strict=True):
192192
"""Indicates that when used as an SDE solver that this solver will converge to the
193193
Itô solution.
194194
"""
195195

196196

197-
class AbstractStratonovichSolver(AbstractSolver[_SolverState]):
197+
class AbstractStratonovichSolver(AbstractSolver[_SolverState], strict=True):
198198
"""Indicates that when used as an SDE solver that this solver will converge to the
199199
Stratonovich solution.
200200
"""
201201

202202

203-
class AbstractAdaptiveSolver(AbstractSolver[_SolverState]):
203+
class AbstractAdaptiveSolver(AbstractSolver[_SolverState], strict=True):
204204
"""Indicates that this solver provides error estimates, and that as such it may be
205205
used with an adaptive step size controller.
206206
"""
207207

208208

209-
class AbstractWrappedSolver(AbstractSolver[_SolverState]):
209+
class AbstractWrappedSolver(AbstractSolver[_SolverState], strict=True):
210210
"""Wraps another solver "transparently", in the sense that all `isinstance` checks
211211
will be forwarded on to the wrapped solver, e.g. when testing whether the solver is
212212
implicit/adaptive/SDE-compatible/etc.

diffrax/_solver/dopri5.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
)
3333

3434

35-
class _Dopri5Interpolation(FourthOrderPolynomialInterpolation):
35+
class _Dopri5Interpolation(FourthOrderPolynomialInterpolation, strict=True):
3636
c_mid: ClassVar[np.ndarray] = np.array(
3737
[
3838
6025192743 / 30085553152 / 2,

diffrax/_solver/dopri8.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@
188188
_vmap_polyval = jax.vmap(jnp.polyval, in_axes=(0, None))
189189

190190

191-
class _Dopri8Interpolation(AbstractLocalInterpolation):
191+
class _Dopri8Interpolation(AbstractLocalInterpolation, strict=True):
192192
t0: RealScalarLike
193193
t1: RealScalarLike
194194
y0: Y

diffrax/_solver/kencarp3.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
)
9090

9191

92-
class KenCarpInterpolation(AbstractLocalInterpolation):
92+
class AbstractKenCarpInterpolation(AbstractLocalInterpolation, strict=True):
9393
t0: RealScalarLike
9494
t1: RealScalarLike
9595
y0: Y
@@ -120,7 +120,7 @@ def evaluate(
120120
return (self.y0**ω + vector_tree_dot(coeffs, k) ** ω).ω
121121

122122

123-
class _KenCarp3Interpolation(KenCarpInterpolation):
123+
class _KenCarp3Interpolation(AbstractKenCarpInterpolation, strict=True):
124124
coeffs = np.array(
125125
[
126126
[-215264564351 / 13552729205753, 4655552711362 / 22874653954995],

diffrax/_solver/kencarp4.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from .._root_finder import VeryChord, with_stepsize_controller_tols
88
from .base import AbstractImplicitSolver
9-
from .kencarp3 import KenCarpInterpolation
9+
from .kencarp3 import AbstractKenCarpInterpolation
1010
from .runge_kutta import (
1111
AbstractRungeKutta,
1212
ButcherTableau,
@@ -102,7 +102,7 @@
102102
)
103103

104104

105-
class _KenCarp4Interpolation(KenCarpInterpolation):
105+
class _KenCarp4Interpolation(AbstractKenCarpInterpolation, strict=True):
106106
coeffs = np.array(
107107
[
108108
[

0 commit comments

Comments
 (0)