Skip to content

Commit fe9c473

Browse files
Now using Optimistix's implementation of implicit_jvp
1 parent 3fc76a3 commit fe9c473

File tree

3 files changed

+40
-126
lines changed

3 files changed

+40
-126
lines changed

diffrax/_ad.py

Lines changed: 0 additions & 115 deletions
This file was deleted.

diffrax/_adjoint.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
import abc
22
import functools as ft
33
import warnings
4-
from typing import Any, Optional
4+
from collections.abc import Iterable
5+
from typing import Any, Optional, Union
56

67
import equinox as eqx
78
import equinox.internal as eqxi
89
import jax
910
import jax.lax as lax
1011
import jax.numpy as jnp
1112
import jax.tree_util as jtu
13+
import lineax as lx
14+
import optimistix.internal as optxi
1215
from equinox.internal import ω
1316

14-
from ._ad import implicit_jvp
1517
from ._heuristics import is_sde, is_unsafe_sde
1618
from ._saveat import save_y, SaveAt, SubSaveAt
1719
from ._solver import AbstractItoSolver, AbstractRungeKutta, AbstractStratonovichSolver
@@ -384,7 +386,7 @@ def loop(
384386
return final_state
385387

386388

387-
def _vf(ys, residual, args__terms, closure):
389+
def _vf(ys, residual, inputs):
388390
state_no_y, _ = residual
389391
t = state_no_y.tprev
390392

@@ -393,14 +395,12 @@ def _unpack(_y):
393395
return _y1
394396

395397
y = jtu.tree_map(_unpack, ys)
396-
args, terms = args__terms
397-
_, _, solver, _, _ = closure
398+
args, terms, _, _, solver, _, _ = inputs
398399
return solver.func(terms, t, y, args)
399400

400401

401-
def _solve(args__terms, closure):
402-
args, terms = args__terms
403-
self, kwargs, solver, saveat, init_state = closure
402+
def _solve(inputs):
403+
args, terms, self, kwargs, solver, saveat, init_state = inputs
404404
final_state, aux_stats = self._loop(
405405
**kwargs,
406406
args=args,
@@ -423,6 +423,15 @@ def _solve(args__terms, closure):
423423
)
424424

425425

426+
def _frozenset(x: Union[object, Iterable[object]]) -> frozenset[object]:
427+
try:
428+
iter_x = iter(x) # pyright: ignore
429+
except TypeError:
430+
return frozenset([x])
431+
else:
432+
return frozenset(iter_x)
433+
434+
426435
class ImplicitAdjoint(AbstractAdjoint):
427436
r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem).
428437
@@ -433,8 +442,16 @@ class ImplicitAdjoint(AbstractAdjoint):
433442
the solver and instead directly compute
434443
$\frac{\mathrm{d}y}{\mathrm{d}θ} = - (\frac{\mathrm{d}f}{\mathrm{d}y})^{-1}\frac{\mathrm{d}f}{\mathrm{d}θ}$
435444
via the implicit function theorem.
445+
446+
Observe that this involves solving a linear system with matrix given by the Jacobian
447+
`df/dy`.
436448
""" # noqa: E501
437449

450+
linear_solver: lx.AbstractLinearSolver = lx.AutoLinearSolver(well_posed=None)
451+
tags: frozenset[object] = eqx.field(
452+
default_factory=frozenset, converter=_frozenset, static=True
453+
)
454+
438455
def loop(
439456
self,
440457
*,
@@ -459,8 +476,10 @@ def loop(
459476
init_state = _nondiff_solver_controller_state(
460477
self, init_state, passed_solver_state, passed_controller_state
461478
)
462-
closure = (self, kwargs, solver, saveat, init_state)
463-
ys, residual = implicit_jvp(_solve, _vf, (args, terms), closure)
479+
inputs = (args, terms, self, kwargs, solver, saveat, init_state)
480+
ys, residual = optxi.implicit_jvp(
481+
_solve, _vf, inputs, self.tags, self.linear_solver
482+
)
464483

465484
final_state_no_ys, aux_stats = residual
466485
# Note that `final_state.save_state` has type PyTree[SaveState]. To access `.ys`
@@ -473,6 +492,15 @@ def loop(
473492
return final_state, aux_stats
474493

475494

495+
ImplicitAdjoint.__init__.__doc__ = """**Arguments:**
496+
497+
- `linear_solver`: A [Lineax](https://github.com/google/lineax) solver for solving the
498+
linear system.
499+
- `tags`: Any Lineax [tags](https://docs.kidger.site/lineax/api/tags/) describing the
500+
Jacobian matrix `df/dy`.
501+
"""
502+
503+
476504
# Compute derivatives with respect to the first argument:
477505
# - y, corresponding to the initial state;
478506
# - args, corresponding to explicit parameters;

docs/api/adjoints.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax
3737

3838
::: diffrax.ImplicitAdjoint
3939
selection:
40-
members: false
40+
members:
41+
- __init__
4142

4243
::: diffrax.DirectAdjoint
4344
selection:

0 commit comments

Comments
 (0)