
Commit 6d7854b

Merge branch 'patrick-kidger:main' into main
2 parents: 3020bb5 + d6d09dc


71 files changed (+6838, -1015 lines)

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions

@@ -1,14 +1,14 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.1.7
+    rev: v0.2.2
     hooks:
+      - id: ruff-format # formatter
+        types_or: [ python, pyi, jupyter ]
       - id: ruff # linter
         types_or: [ python, pyi, jupyter ]
         args: [ --fix ]
-      - id: ruff-format # formatter
-        types_or: [ python, pyi, jupyter ]
   - repo: https://github.com/RobertCraigie/pyright-python
-    rev: v1.1.316
+    rev: v1.1.350
     hooks:
       - id: pyright
         additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typing_extensions]

README.md

Lines changed: 18 additions & 21 deletions

@@ -61,24 +61,21 @@ If you found this library useful in academic research, please cite: [(arXiv link
 
 ## See also: other libraries in the JAX ecosystem
 
-[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays.
-
-[Equinox](https://github.com/patrick-kidger/equinox): neural networks.
-
-[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
-
-[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
-
-[Lineax](https://github.com/google/lineax): linear solvers.
-
-[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
-
-[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
-
-[sympy2jax](https://github.com/google/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
-
-[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models.
-
-[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
-
-[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)
+**Always useful**
+[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
+[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.
+
+**Deep learning**
+[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
+[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
+[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
+
+**Scientific computing**
+[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
+[Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
+[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
+[sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
+[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)
+
+**Awesome JAX**
+[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.

benchmarks/brownian_tree_times.py

Lines changed: 7 additions & 6 deletions

@@ -1,6 +1,6 @@
 """
 v0.5.0 introduced a new implementation for `diffrax.VirtualBrownianTree` that is
-additionally capable of computing Levy area.
+additionally capable of computing Lévy area.
 
 Here we check the speed of the new implementation against the old implementation, to be
 sure that it is still fast.
@@ -10,6 +10,7 @@
 from typing import cast, Optional, Union
 from typing_extensions import TypeAlias
 
+import diffrax
 import equinox as eqx
 import equinox.internal as eqxi
 import jax
@@ -50,9 +51,9 @@ def __init__(
         tol: RealScalarLike,
         shape: tuple[int, ...],
         key: PRNGKeyArray,
-        levy_area: str,
+        levy_area: type[diffrax.AbstractBrownianIncrement] = diffrax.BrownianIncrement,
     ):
-        assert levy_area == ""
+        assert levy_area == diffrax.BrownianIncrement
         self.t0 = t0
         self.t1 = t1
         self.tol = tol
@@ -187,13 +188,13 @@ def run(_ts):
     )
 
 
-for levy_area in ("", "space-time"):
+for levy_area in (diffrax.BrownianIncrement, diffrax.SpaceTimeLevyArea):
     print(f"- {levy_area=}")
     for tol in (2**-3, 2**-12):
         print(f"-- {tol=}")
-        for num_ts in (1, 100):
+        for num_ts in (1, 10000):
            print(f"--- {num_ts=}")
-            if levy_area == "":
+            if levy_area == diffrax.BrownianIncrement:
                print(f"Old: {time_tree(OldVBT, num_ts, tol, levy_area):.5f}")
            print(f"new: {time_tree(VirtualBrownianTree, num_ts, tol, levy_area):.5f}")
            print("")

benchmarks/small_neural_ode.py

Lines changed: 6 additions & 6 deletions

@@ -20,7 +20,7 @@
 class FuncTorch(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.func = torch.jit.script(  # pyright: ignore
+        self.func = torch.jit.script(
             torch.nn.Sequential(
                 torch.nn.Linear(4, 32),
                 torch.nn.Softplus(),
@@ -30,7 +30,7 @@ def __init__(self):
         )
 
     def forward(self, t, y):
-        return self.func(y)  # pyright: ignore
+        return self.func(y)
 
 
 class FuncJax(eqx.Module):
@@ -177,10 +177,10 @@ def run(multiple, grad, batch_size=64, t1=100):
     with torch.no_grad():
         func_jax = neural_ode_diffrax.func.func
         func_torch = neural_ode_torch.func.func
-        func_torch[0].weight.copy_(torch.tensor(np.asarray(func_jax.layers[0].weight)))  # pyright: ignore
-        func_torch[0].bias.copy_(torch.tensor(np.asarray(func_jax.layers[0].bias)))  # pyright: ignore
-        func_torch[2].weight.copy_(torch.tensor(np.asarray(func_jax.layers[1].weight)))  # pyright: ignore
-        func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias)))  # pyright: ignore
+        func_torch[0].weight.copy_(torch.tensor(np.asarray(func_jax.layers[0].weight)))
+        func_torch[0].bias.copy_(torch.tensor(np.asarray(func_jax.layers[0].bias)))
+        func_torch[2].weight.copy_(torch.tensor(np.asarray(func_jax.layers[1].weight)))
+        func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias)))
 
     y0_jax = jr.normal(jr.PRNGKey(1), (batch_size, 4))
     y0_torch = torch.tensor(np.asarray(y0_jax))

diffrax/__init__.py

Lines changed: 29 additions & 4 deletions

@@ -13,11 +13,22 @@
     UnsafeBrownianPath as UnsafeBrownianPath,
     VirtualBrownianTree as VirtualBrownianTree,
 )
-from ._custom_types import LevyVal as LevyVal
+from ._custom_types import (
+    AbstractBrownianIncrement as AbstractBrownianIncrement,
+    AbstractSpaceTimeLevyArea as AbstractSpaceTimeLevyArea,
+    AbstractSpaceTimeTimeLevyArea as AbstractSpaceTimeTimeLevyArea,
+    BrownianIncrement as BrownianIncrement,
+    SpaceTimeLevyArea as SpaceTimeLevyArea,
+    SpaceTimeTimeLevyArea as SpaceTimeTimeLevyArea,
+)
 from ._event import (
-    AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent,
-    DiscreteTerminatingEvent as DiscreteTerminatingEvent,
-    SteadyStateEvent as SteadyStateEvent,
+    # Deliberately not provided with `X as X` as these are now deprecated, so we'd like
+    # static type checkers to warn about using them.
+    AbstractDiscreteTerminatingEvent,  # noqa: F401
+    DiscreteTerminatingEvent,  # noqa: F401
+    Event as Event,
+    steady_state_event as steady_state_event,
+    SteadyStateEvent,  # noqa: F401
 )
 from ._global_interpolation import (
     AbstractGlobalInterpolation as AbstractGlobalInterpolation,
@@ -37,6 +48,12 @@
 )
 from ._misc import adjoint_rms_seminorm as adjoint_rms_seminorm
 from ._path import AbstractPath as AbstractPath
+from ._progress_meter import (
+    AbstractProgressMeter as AbstractProgressMeter,
+    NoProgressMeter as NoProgressMeter,
+    TextProgressMeter as TextProgressMeter,
+    TqdmProgressMeter as TqdmProgressMeter,
+)
 from ._root_finder import (
     VeryChord as VeryChord,
     with_stepsize_controller_tols as with_stepsize_controller_tols,
@@ -59,6 +76,7 @@
     AbstractRungeKutta as AbstractRungeKutta,
     AbstractSDIRK as AbstractSDIRK,
     AbstractSolver as AbstractSolver,
+    AbstractSRK as AbstractSRK,
     AbstractStratonovichSolver as AbstractStratonovichSolver,
     AbstractWrappedSolver as AbstractWrappedSolver,
     Bosh3 as Bosh3,
@@ -68,6 +86,7 @@
     Dopri8 as Dopri8,
     Euler as Euler,
     EulerHeun as EulerHeun,
+    GeneralShARK as GeneralShARK,
     HalfSolver as HalfSolver,
     Heun as Heun,
     ImplicitEuler as ImplicitEuler,
@@ -83,8 +102,14 @@
     MultiButcherTableau as MultiButcherTableau,
     Ralston as Ralston,
     ReversibleHeun as ReversibleHeun,
+    SEA as SEA,
     SemiImplicitEuler as SemiImplicitEuler,
+    ShARK as ShARK,
     Sil3 as Sil3,
+    SlowRK as SlowRK,
+    SPaRK as SPaRK,
+    SRA1 as SRA1,
+    StochasticButcherTableau as StochasticButcherTableau,
     StratonovichMilstein as StratonovichMilstein,
     Tsit5 as Tsit5,
 )
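
A minimal usage sketch of the progress meters exported above, assuming `diffeqsolve` accepts a `progress_meter` argument with `NoProgressMeter` as the silent default; the toy ODE is purely illustrative:

```python
import diffrax

# Illustrative only: solve dy/dt = -y on [0, 10] while reporting progress.
term = diffrax.ODETerm(lambda t, y, args: -y)
sol = diffrax.diffeqsolve(
    term,
    diffrax.Tsit5(),
    t0=0.0,
    t1=10.0,
    dt0=0.01,
    y0=1.0,
    # TextProgressMeter prints percentage updates; TqdmProgressMeter (if tqdm
    # is installed) shows a progress bar; NoProgressMeter is the silent default.
    progress_meter=diffrax.TextProgressMeter(),
)
print(sol.ys)
```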

diffrax/_adjoint.py

Lines changed: 24 additions & 10 deletions

@@ -1,8 +1,8 @@
 import abc
 import functools as ft
 import warnings
-from collections.abc import Iterable
-from typing import Any, Optional, Union
+from collections.abc import Callable, Iterable
+from typing import Any, cast, Optional, Union
 
 import equinox as eqx
 import equinox.internal as eqxi
@@ -20,6 +20,9 @@
 from ._term import AbstractTerm, AdjointTerm
 
 
+ω = cast(Callable, ω)
+
+
 def _is_none(x):
     return x is None
 
@@ -118,7 +121,7 @@ def loop(
         terms,
         solver,
         stepsize_controller,
-        discrete_terminating_event,
+        event,
         saveat,
         t0,
         t1,
@@ -128,6 +131,7 @@ def loop(
         init_state,
         passed_solver_state,
         passed_controller_state,
+        progress_meter,
     ) -> Any:
         """Runs the main solve loop. Subclasses can override this to provide custom
         backpropagation behaviour; see for example the implementation of
@@ -425,6 +429,14 @@ def _solve(inputs):
     )
 
 
+# Unwrap jaxtyping decorator during tests, so that these are global functions.
+# This is needed to ensure `optx.implicit_jvp` is happy.
+if _vf.__globals__["__name__"].startswith("jaxtyping"):
+    _vf = _vf.__wrapped__  # pyright: ignore[reportFunctionMemberAccess]
+if _solve.__globals__["__name__"].startswith("jaxtyping"):
+    _solve = _solve.__wrapped__  # pyright: ignore[reportFunctionMemberAccess]
+
+
 def _frozenset(x: Union[object, Iterable[object]]) -> frozenset[object]:
     try:
         iter_x = iter(x)  # pyright: ignore
@@ -438,7 +450,8 @@ class ImplicitAdjoint(AbstractAdjoint):
     r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem).
 
     This is used when solving towards a steady state, typically using
-    [`diffrax.SteadyStateEvent`][]. In this case, the output of the solver is $y(θ)$
+    [`diffrax.Event`][] where the condition function is obtained by calling
+    [`diffrax.steady_state_event`][]. In this case, the output of the solver is $y(θ)$
     for which $f(t, y(θ), θ) = 0$. (Where $θ$ corresponds to all parameters found
     through `terms` and `args`, but not `y0`.) Then we can skip backpropagating through
     the solver and instead directly compute
@@ -551,23 +564,24 @@ def _loop_backsolve_bwd(
     self,
     solver,
     stepsize_controller,
-    discrete_terminating_event,
+    event,
    saveat,
    t0,
    t1,
    dt0,
    max_steps,
    throw,
    init_state,
+    progress_meter,
 ):
-    assert discrete_terminating_event is None
+    assert event is None
 
    #
    # Unpack our various arguments. Delete a lot of things just to make sure we're not
    # using them later.
    #
 
-    del perturbed, init_state, t1
+    del perturbed, init_state, t1, progress_meter
    ts, ys = residuals
    del residuals
    grad_final_state, _ = grad_final_state__aux_stats
@@ -774,7 +788,7 @@ def loop(
         init_state,
         passed_solver_state,
         passed_controller_state,
-        discrete_terminating_event,
+        event,
         **kwargs,
     ):
         if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
@@ -816,7 +830,7 @@ def loop(
                 "`diffrax.BacksolveAdjoint` is only compatible with solvers that take "
                 "a single term."
             )
-        if discrete_terminating_event is not None:
+        if event is not None:
             raise NotImplementedError(
                 "`diffrax.BacksolveAdjoint` is not compatible with events."
             )
@@ -833,7 +847,7 @@ def loop(
             saveat=saveat,
             init_state=init_state,
             solver=solver,
-            discrete_terminating_event=discrete_terminating_event,
+            event=event,
             **kwargs,
         )
         final_state = _only_transpose_ys(final_state)
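
A minimal sketch of the steady-state pattern the updated `ImplicitAdjoint` docstring describes, assuming `diffeqsolve` now takes `event=` in place of `discrete_terminating_event=`; the vector field and tolerances are placeholders:

```python
import jax.numpy as jnp

import diffrax

# Illustrative only: integrate dy/dt = -(y - 2) until it reaches its steady
# state y = 2, then differentiate through the result with the implicit
# function theorem rather than by backpropagating through the solver.
term = diffrax.ODETerm(lambda t, y, args: -(y - 2.0))
sol = diffrax.diffeqsolve(
    term,
    diffrax.Tsit5(),
    t0=0.0,
    t1=jnp.inf,
    dt0=None,
    y0=1.0,
    max_steps=None,
    stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-7),
    event=diffrax.Event(diffrax.steady_state_event()),
    adjoint=diffrax.ImplicitAdjoint(),
)
```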
