Skip to content

Commit 37775d7

Browse files
patrick-kidgerthibmonsel
authored andcommitted
In progress commit on branch delay.
1 parent fe1ca9a commit 37775d7

24 files changed

+12272
-17
lines changed

diffrax/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
)
88
from .autocitation import citation, citation_rules
99
from .brownian import AbstractBrownianPath, UnsafeBrownianPath, VirtualBrownianTree
10+
from .delays import Delays
1011
from .event import (
1112
AbstractDiscreteTerminatingEvent,
1213
DiscreteTerminatingEvent,

diffrax/adjoint.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def loop(
117117
solver,
118118
stepsize_controller,
119119
discrete_terminating_event,
120+
delays,
120121
saveat,
121122
t0,
122123
t1,
@@ -522,13 +523,15 @@ def _loop_backsolve_bwd(
522523
solver,
523524
stepsize_controller,
524525
discrete_terminating_event,
526+
delays,
525527
saveat,
526528
t0,
527529
t1,
528530
dt0,
529531
max_steps,
530532
throw,
531533
init_state,
534+
y0_history,
532535
):
533536
assert discrete_terminating_event is None
534537

@@ -566,6 +569,8 @@ def _loop_backsolve_bwd(
566569
adjoint=self,
567570
solver=solver,
568571
stepsize_controller=stepsize_controller,
572+
discrete_terminating_event=discrete_terminating_event,
573+
delays=delays,
569574
terms=adjoint_terms,
570575
dt0=None if dt0 is None else -dt0,
571576
max_steps=max_steps,
@@ -745,6 +750,7 @@ def loop(
745750
passed_solver_state,
746751
passed_controller_state,
747752
discrete_terminating_event,
753+
delays,
748754
**kwargs,
749755
):
750756
if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
@@ -790,6 +796,10 @@ def loop(
790796
raise NotImplementedError(
791797
"`diffrax.BacksolveAdjoint` is not compatible with events."
792798
)
799+
if delays is not None:
800+
raise NotImplementedError(
801+
"Cannot use `delays` with `adjoint=BacksolveAdjoint()`"
802+
)
793803

794804
y = init_state.y
795805
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
@@ -804,6 +814,7 @@ def loop(
804814
init_state=init_state,
805815
solver=solver,
806816
discrete_terminating_event=discrete_terminating_event,
817+
delays=delays,
807818
**kwargs,
808819
)
809820
final_state = _only_transpose_ys(final_state)

0 commit comments

Comments
 (0)