1
1
import abc
2
2
import functools as ft
3
3
import warnings
4
- from typing import Any , Optional
4
+ from collections .abc import Iterable
5
+ from typing import Any , Optional , Union
5
6
6
7
import equinox as eqx
7
8
import equinox .internal as eqxi
8
9
import jax
9
10
import jax .lax as lax
10
11
import jax .numpy as jnp
11
12
import jax .tree_util as jtu
13
+ import lineax as lx
14
+ import optimistix .internal as optxi
12
15
from equinox .internal import ω
13
16
14
- from ._ad import implicit_jvp
15
17
from ._heuristics import is_sde , is_unsafe_sde
16
18
from ._saveat import save_y , SaveAt , SubSaveAt
17
19
from ._solver import AbstractItoSolver , AbstractRungeKutta , AbstractStratonovichSolver
@@ -384,7 +386,7 @@ def loop(
384
386
return final_state
385
387
386
388
387
- def _vf (ys , residual , args__terms , closure ):
389
+ def _vf (ys , residual , inputs ):
388
390
state_no_y , _ = residual
389
391
t = state_no_y .tprev
390
392
@@ -393,14 +395,12 @@ def _unpack(_y):
393
395
return _y1
394
396
395
397
y = jtu .tree_map (_unpack , ys )
396
- args , terms = args__terms
397
- _ , _ , solver , _ , _ = closure
398
+ args , terms , _ , _ , solver , _ , _ = inputs
398
399
return solver .func (terms , t , y , args )
399
400
400
401
401
- def _solve (args__terms , closure ):
402
- args , terms = args__terms
403
- self , kwargs , solver , saveat , init_state = closure
402
+ def _solve (inputs ):
403
+ args , terms , self , kwargs , solver , saveat , init_state = inputs
404
404
final_state , aux_stats = self ._loop (
405
405
** kwargs ,
406
406
args = args ,
@@ -423,6 +423,15 @@ def _solve(args__terms, closure):
423
423
)
424
424
425
425
426
+ def _frozenset (x : Union [object , Iterable [object ]]) -> frozenset [object ]:
427
+ try :
428
+ iter_x = iter (x ) # pyright: ignore
429
+ except TypeError :
430
+ return frozenset ([x ])
431
+ else :
432
+ return frozenset (iter_x )
433
+
434
+
426
435
class ImplicitAdjoint (AbstractAdjoint ):
427
436
r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem).
428
437
@@ -433,8 +442,16 @@ class ImplicitAdjoint(AbstractAdjoint):
433
442
the solver and instead directly compute
434
443
$\frac{\mathrm{d}y}{\mathrm{d}θ} = - (\frac{\mathrm{d}f}{\mathrm{d}y})^{-1}\frac{\mathrm{d}f}{\mathrm{d}θ}$
435
444
via the implicit function theorem.
445
+
446
+ Observe that this involves solving a linear system with matrix given by the Jacobian
447
+ `df/dy`.
436
448
""" # noqa: E501
437
449
450
+ linear_solver : lx .AbstractLinearSolver = lx .AutoLinearSolver (well_posed = None )
451
+ tags : frozenset [object ] = eqx .field (
452
+ default_factory = frozenset , converter = _frozenset , static = True
453
+ )
454
+
438
455
def loop (
439
456
self ,
440
457
* ,
@@ -459,8 +476,10 @@ def loop(
459
476
init_state = _nondiff_solver_controller_state (
460
477
self , init_state , passed_solver_state , passed_controller_state
461
478
)
462
- closure = (self , kwargs , solver , saveat , init_state )
463
- ys , residual = implicit_jvp (_solve , _vf , (args , terms ), closure )
479
+ inputs = (args , terms , self , kwargs , solver , saveat , init_state )
480
+ ys , residual = optxi .implicit_jvp (
481
+ _solve , _vf , inputs , self .tags , self .linear_solver
482
+ )
464
483
465
484
final_state_no_ys , aux_stats = residual
466
485
# Note that `final_state.save_state` has type PyTree[SaveState]. To access `.ys`
@@ -473,6 +492,15 @@ def loop(
473
492
return final_state , aux_stats
474
493
475
494
495
+ ImplicitAdjoint .__init__ .__doc__ = """**Arguments:**
496
+
497
+ - `linear_solver`: A [Lineax](https://github.com/google/lineax) solver for solving the
498
+ linear system.
499
+ - `tags`: Any Lineax [tags](https://docs.kidger.site/lineax/api/tags/) describing the
500
+ Jacobian matrix `df/dy`.
501
+ """
502
+
503
+
476
504
# Compute derivatives with respect to the first argument:
477
505
# - y, corresponding to the initial state;
478
506
# - args, corresponding to explicit parameters;
0 commit comments