Skip to content

Commit 3020bb5

Browse files
authored
Merge branch 'patrick-kidger:main' into main
2 parents 98a1354 + 0ee47c9 commit 3020bb5

File tree

112 files changed

+4403
-3132
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

112 files changed

+4403
-3132
lines changed

.flake8

Lines changed: 0 additions & 4 deletions
This file was deleted.

.github/workflows/release.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@ jobs:
1010
runs-on: ubuntu-latest
1111
steps:
1212
- name: Release
13-
uses: patrick-kidger/action_update_python_project@v1
13+
uses: patrick-kidger/action_update_python_project@v2
1414
with:
1515
python-version: "3.11"
1616
test-script: |
17-
python -m pip install pytest jax jaxlib equinox scipy optax
1817
cp -r ${{ github.workspace }}/test ./test
18+
cp ${{ github.workspace }}/pyproject.toml ./pyproject.toml
19+
python -m pip install -r ./test/requirements.txt
1920
python -m test
2021
pypi-token: ${{ secrets.pypi_token }}
2122
github-user: patrick-kidger

.github/workflows/run_tests.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ jobs:
2323
- name: Install dependencies
2424
run: |
2525
python -m pip install --upgrade pip
26-
python -m pip install pytest wheel scipy numpy optax jaxlib
26+
python -m pip install -r ./test/requirements.txt
27+
2728
2829
- name: Checks with pre-commit
2930
uses: pre-commit/[email protected]

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ dist/
66
site/
77
.all_objects.cache
88
.pymon
9+
.idea/

.isort.cfg

Lines changed: 0 additions & 6 deletions
This file was deleted.

.pre-commit-config.yaml

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
11
repos:
2-
- repo: https://github.com/ambv/black
3-
rev: 22.3.0
4-
hooks:
5-
- id: black
6-
- repo: https://github.com/nbQA-dev/nbQA
7-
rev: 1.6.3
8-
hooks:
9-
- id: nbqa-black
10-
- id: nbqa-isort
11-
- id: nbqa-flake8
12-
- repo: https://github.com/PyCQA/isort
13-
rev: 5.12.0
14-
hooks:
15-
- id: isort
16-
- repo: https://github.com/pycqa/flake8
17-
rev: 4.0.1
18-
hooks:
19-
- id: flake8
2+
- repo: https://github.com/astral-sh/ruff-pre-commit
3+
rev: v0.1.7
4+
hooks:
5+
- id: ruff # linter
6+
types_or: [ python, pyi, jupyter ]
7+
args: [ --fix ]
8+
- id: ruff-format # formatter
9+
types_or: [ python, pyi, jupyter ]
10+
- repo: https://github.com/RobertCraigie/pyright-python
11+
rev: v1.1.316
12+
hooks:
13+
- id: pyright
14+
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typing_extensions]

MANIFEST.in

Lines changed: 0 additions & 1 deletion
This file was deleted.

benchmarks/against_scan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def speedtest(fn, name):
3636
# INTEGRATE WITH scan
3737

3838

39-
@jax.checkpoint
39+
@jax.checkpoint # pyright: ignore
4040
def body(carry, t):
4141
u, v, dt = carry
4242
u = u + du(t, v, None) * dt

benchmarks/brownian_tree_times.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
"""
2+
v0.5.0 introduced a new implementation for `diffrax.VirtualBrownianTree` that is
3+
additionally capable of computing Levy area.
4+
5+
Here we check the speed of the new implementation against the old implementation, to be
6+
sure that it is still fast.
7+
"""
8+
9+
import timeit
10+
from typing import cast, Optional, Union
11+
from typing_extensions import TypeAlias
12+
13+
import equinox as eqx
14+
import equinox.internal as eqxi
15+
import jax
16+
import jax.lax as lax
17+
import jax.numpy as jnp
18+
import jax.random as jr
19+
import jax.tree_util as jtu
20+
import lineax.internal as lxi
21+
import numpy as np
22+
from diffrax import AbstractBrownianPath, VirtualBrownianTree
23+
from jaxtyping import Array, Float, PRNGKeyArray, PyTree, Real
24+
25+
26+
RealScalarLike: TypeAlias = Real[Union[int, float, Array, np.ndarray], ""]
27+
28+
29+
class _State(eqx.Module):
30+
s: RealScalarLike
31+
t: RealScalarLike
32+
u: RealScalarLike
33+
w_s: Float[Array, " *shape"]
34+
w_t: Float[Array, " *shape"]
35+
w_u: Float[Array, " *shape"]
36+
key: PRNGKeyArray
37+
38+
39+
class OldVBT(AbstractBrownianPath):
40+
t0: RealScalarLike
41+
t1: RealScalarLike
42+
tol: RealScalarLike
43+
shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
44+
key: PRNGKeyArray
45+
46+
def __init__(
47+
self,
48+
t0: RealScalarLike,
49+
t1: RealScalarLike,
50+
tol: RealScalarLike,
51+
shape: tuple[int, ...],
52+
key: PRNGKeyArray,
53+
levy_area: str,
54+
):
55+
assert levy_area == ""
56+
self.t0 = t0
57+
self.t1 = t1
58+
self.tol = tol
59+
self.shape = jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype())
60+
self.key = key
61+
62+
@property
63+
def levy_area(self):
64+
assert False
65+
66+
@eqx.filter_jit
67+
def evaluate(
68+
self,
69+
t0: RealScalarLike,
70+
t1: Optional[RealScalarLike] = None,
71+
left: bool = True,
72+
use_levy: bool = False,
73+
) -> PyTree[Array]:
74+
del left, use_levy
75+
t0 = eqxi.nondifferentiable(t0, name="t0")
76+
if t1 is None:
77+
return self._evaluate(t0)
78+
else:
79+
t1 = cast(RealScalarLike, eqxi.nondifferentiable(t1, name="t1"))
80+
return jtu.tree_map(
81+
lambda x, y: x - y,
82+
self._evaluate(t1),
83+
self._evaluate(t0),
84+
)
85+
86+
def _evaluate(self, τ: RealScalarLike) -> PyTree[Array]:
87+
map_func = lambda key, struct: self._evaluate_leaf(key, τ, struct)
88+
return jtu.tree_map(map_func, self.key, self.shape)
89+
90+
def _brownian_bridge(self, s, t, u, w_s, w_u, key, shape, dtype):
91+
mean = w_s + (w_u - w_s) * ((t - s) / (u - s))
92+
var = (u - t) * (t - s) / (u - s)
93+
std = jnp.sqrt(var)
94+
return mean + std * jr.normal(key, shape, dtype)
95+
96+
def _evaluate_leaf(
97+
self,
98+
key,
99+
τ: RealScalarLike,
100+
struct: jax.ShapeDtypeStruct,
101+
) -> Array:
102+
shape, dtype = struct.shape, struct.dtype
103+
104+
cond = self.t0 < self.t1
105+
t0 = jnp.where(cond, self.t0, self.t1).astype(dtype)
106+
t1 = jnp.where(cond, self.t1, self.t0).astype(dtype)
107+
108+
t0 = eqxi.error_if(
109+
t0,
110+
τ < t0,
111+
"Cannot evaluate VirtualBrownianTree outside of its range [t0, t1].",
112+
)
113+
t1 = eqxi.error_if(
114+
t1,
115+
τ > t1,
116+
"Cannot evaluate VirtualBrownianTree outside of its range [t0, t1].",
117+
)
118+
τ = jnp.clip(τ, t0, t1).astype(dtype)
119+
120+
key, init_key = jr.split(key, 2)
121+
thalf = t0 + 0.5 * (t1 - t0)
122+
w_t1 = jr.normal(init_key, shape, dtype) * jnp.sqrt(t1 - t0)
123+
w_thalf = self._brownian_bridge(t0, thalf, t1, 0, w_t1, key, shape, dtype)
124+
init_state = _State(
125+
s=t0,
126+
t=thalf,
127+
u=t1,
128+
w_s=jnp.zeros_like(w_t1),
129+
w_t=w_thalf,
130+
w_u=w_t1,
131+
key=key,
132+
)
133+
134+
def _cond_fun(_state):
135+
return (_state.u - _state.s) > self.tol
136+
137+
def _body_fun(_state):
138+
_key1, _key2 = jr.split(_state.key, 2)
139+
_cond = τ > _state.t
140+
_s = jnp.where(_cond, _state.t, _state.s)
141+
_u = jnp.where(_cond, _state.u, _state.t)
142+
_w_s = jnp.where(_cond, _state.w_t, _state.w_s)
143+
_w_u = jnp.where(_cond, _state.w_u, _state.w_t)
144+
_key = jnp.where(_cond, _key1, _key2)
145+
_t = _s + 0.5 * (_u - _s)
146+
_w_t = self._brownian_bridge(_s, _t, _u, _w_s, _w_u, _key, shape, dtype)
147+
return _State(s=_s, t=_t, u=_u, w_s=_w_s, w_t=_w_t, w_u=_w_u, key=_key)
148+
149+
final_state = lax.while_loop(_cond_fun, _body_fun, init_state)
150+
151+
s = final_state.s
152+
u = final_state.u
153+
w_s = final_state.w_s
154+
w_t = final_state.w_t
155+
w_u = final_state.w_u
156+
rescaled_τ = (τ - s) / (u - s)
157+
158+
A = jnp.array([[2, -4, 2], [-3, 4, -1], [1, 0, 0]])
159+
coeffs = jnp.tensordot(A, jnp.stack([w_s, w_t, w_u]), axes=1)
160+
return jnp.polyval(coeffs, rescaled_τ)
161+
162+
163+
key = jr.PRNGKey(0)
164+
t0, t1 = 0.3, 20.3
165+
166+
167+
def time_tree(tree_cls, num_ts, tol, levy_area):
168+
tree = tree_cls(t0=t0, t1=t1, tol=tol, shape=(10,), key=key, levy_area=levy_area)
169+
170+
if num_ts == 1:
171+
ts = 11.2
172+
173+
@jax.jit
174+
@eqx.debug.assert_max_traces(max_traces=1)
175+
def run(_ts):
176+
return tree.evaluate(_ts, use_levy=True)
177+
else:
178+
ts = jnp.linspace(t0, t1, num_ts)
179+
180+
@jax.jit
181+
@eqx.debug.assert_max_traces(max_traces=1)
182+
def run(_ts):
183+
return jax.vmap(lambda _t: tree.evaluate(_t, use_levy=True))(_ts)
184+
185+
return min(
186+
timeit.repeat(lambda: jax.block_until_ready(run(ts)), number=1, repeat=100)
187+
)
188+
189+
190+
for levy_area in ("", "space-time"):
191+
print(f"- {levy_area=}")
192+
for tol in (2**-3, 2**-12):
193+
print(f"-- {tol=}")
194+
for num_ts in (1, 100):
195+
print(f"--- {num_ts=}")
196+
if levy_area == "":
197+
print(f"Old: {time_tree(OldVBT, num_ts, tol, levy_area):.5f}")
198+
print(f"new: {time_tree(VirtualBrownianTree, num_ts, tol, levy_area):.5f}")
199+
print("")

benchmarks/compile_times.py

Lines changed: 18 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import functools as ft
22
import timeit
3+
from typing import cast
34

45
import diffrax as dfx
56
import equinox as eqx
67
import jax
78
import jax.numpy as jnp
89
import jax.random as jr
10+
from jaxtyping import Array
911

1012

1113
def _weight(in_, out, key):
@@ -30,7 +32,7 @@ def __call__(self, t, y, args):
3032
return jnp.stack(y)
3133

3234

33-
def run(inline: bool, scan_stages: bool, grad: bool, adjoint_name: str):
35+
def run(inline: bool, grad: bool, adjoint_name: str):
3436
if adjoint_name == "direct":
3537
adjoint = dfx.DirectAdjoint()
3638
elif adjoint_name == "recursive":
@@ -48,7 +50,7 @@ def run(inline: bool, scan_stages: bool, grad: bool, adjoint_name: str):
4850
if not inline:
4951
vf = eqx.internal.noinline(vf)
5052
term = dfx.ODETerm(vf)
51-
solver = dfx.Dopri8(scan_stages=scan_stages)
53+
solver = dfx.Dopri8()
5254
stepsize_controller = dfx.PIDController(rtol=1e-3, atol=1e-6)
5355
t0 = 0
5456
t1 = 1
@@ -68,43 +70,25 @@ def solve(y0):
6870
adjoint=adjoint,
6971
max_steps=16**2,
7072
)
71-
return jnp.sum(sol.ys)
73+
return jnp.sum(cast(Array, sol.ys))
7274

7375
solve_ = ft.partial(solve, jnp.array([1.0]))
7476
compile_time = timeit.timeit(solve_, number=1)
75-
print(
76-
f"{inline=}, {scan_stages=}, {grad=}, adjoint={adjoint_name}, {compile_time=}"
77-
)
77+
print(f"{inline=}, {grad=}, adjoint={adjoint_name}, {compile_time=}")
7878

7979

80-
run(inline=False, scan_stages=False, grad=False, adjoint_name="direct")
81-
run(inline=False, scan_stages=False, grad=False, adjoint_name="recursive")
82-
run(inline=False, scan_stages=False, grad=False, adjoint_name="backsolve")
80+
run(inline=False, grad=False, adjoint_name="direct")
81+
run(inline=False, grad=False, adjoint_name="recursive")
82+
run(inline=False, grad=False, adjoint_name="backsolve")
8383

84-
run(inline=False, scan_stages=False, grad=True, adjoint_name="direct")
85-
run(inline=False, scan_stages=False, grad=True, adjoint_name="recursive")
86-
run(inline=False, scan_stages=False, grad=True, adjoint_name="backsolve")
84+
run(inline=False, grad=True, adjoint_name="direct")
85+
run(inline=False, grad=True, adjoint_name="recursive")
86+
run(inline=False, grad=True, adjoint_name="backsolve")
8787

88-
run(inline=False, scan_stages=True, grad=False, adjoint_name="direct")
89-
run(inline=False, scan_stages=True, grad=False, adjoint_name="recursive")
90-
run(inline=False, scan_stages=True, grad=False, adjoint_name="backsolve")
88+
run(inline=True, grad=False, adjoint_name="direct")
89+
run(inline=True, grad=False, adjoint_name="recursive")
90+
run(inline=True, grad=False, adjoint_name="backsolve")
9191

92-
run(inline=False, scan_stages=True, grad=True, adjoint_name="direct")
93-
run(inline=False, scan_stages=True, grad=True, adjoint_name="recursive")
94-
run(inline=False, scan_stages=True, grad=True, adjoint_name="backsolve")
95-
96-
run(inline=True, scan_stages=False, grad=False, adjoint_name="direct")
97-
run(inline=True, scan_stages=False, grad=False, adjoint_name="recursive")
98-
run(inline=True, scan_stages=False, grad=False, adjoint_name="backsolve")
99-
100-
run(inline=True, scan_stages=False, grad=True, adjoint_name="direct")
101-
run(inline=True, scan_stages=False, grad=True, adjoint_name="recursive")
102-
run(inline=True, scan_stages=False, grad=True, adjoint_name="backsolve")
103-
104-
run(inline=True, scan_stages=True, grad=False, adjoint_name="direct")
105-
run(inline=True, scan_stages=True, grad=False, adjoint_name="recursive")
106-
run(inline=True, scan_stages=True, grad=False, adjoint_name="backsolve")
107-
108-
run(inline=True, scan_stages=True, grad=True, adjoint_name="direct")
109-
run(inline=True, scan_stages=True, grad=True, adjoint_name="recursive")
110-
run(inline=True, scan_stages=True, grad=True, adjoint_name="backsolve")
92+
run(inline=True, grad=True, adjoint_name="direct")
93+
run(inline=True, grad=True, adjoint_name="recursive")
94+
run(inline=True, grad=True, adjoint_name="backsolve")

0 commit comments

Comments
 (0)