Skip to content

Commit 7e7d1b4

Browse files
committed
typing fixes and notebook updates
1 parent b8f4d77 commit 7e7d1b4

File tree

4 files changed

+52
-95
lines changed

4 files changed

+52
-95
lines changed

diffrax/_delays.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
)
1919
from optimistix._custom_types import Aux, Fn, Y
2020

21-
from ._custom_types import IntScalarLike, RealScalarLike, VF
21+
from ._custom_types import BoolScalarLike, IntScalarLike, RealScalarLike
2222
from ._global_interpolation import DenseInterpolation
2323
from ._local_interpolation import AbstractLocalInterpolation
2424
from ._term import VectorFieldWrapper
@@ -32,7 +32,7 @@ class _FixedPointState(eqx.Module, strict=True):
3232
class ModifiedFixedPointIteration(AbstractFixedPointSolver):
3333
rtol: float
3434
atol: float
35-
implicit_step: bool
35+
implicit_step: BoolScalarLike
3636
max_steps: int = eqx.field(static=True)
3737
norm: Callable[[PyTree], Scalar] = rms_norm
3838

@@ -95,7 +95,7 @@ def postprocess(
9595

9696

9797
class Delays(eqx.Module):
98-
"""Module that incorportes all the information needed for integrating DDEs"""
98+
"""Module that incorporates all the information needed for integrating DDEs"""
9999

100100
delays: PyTree[Callable]
101101
initial_discontinuities: Optional[Array] = jnp.array([0.0])
@@ -121,7 +121,7 @@ class HistoryVectorField(eqx.Module):
121121
- `delays` : DDE's different deviated arguments
122122
"""
123123

124-
vector_field: VF
124+
vector_field: Callable[..., PyTree]
125125
t0: RealScalarLike
126126
tprev: RealScalarLike
127127
tnext: RealScalarLike

diffrax/_integrate.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class State(eqx.Module):
113113
event_mask: Optional[PyTree[BoolScalarLike]]
114114
num_dde_implicit_step: IntScalarLike
115115
num_dde_explicit_step: IntScalarLike
116-
discontinuities: Optional[eqxi.MaybeBuffer[Float[Array, " times_plus_1"]]] # noqa: F821
116+
discontinuities: Optional[eqxi.MaybeBuffer[ArrayLike]] # noqa: F821
117117
discontinuities_save_index: Optional[IntScalarLike]
118118
# Output that is .at[].set() updated during the solve (and their indices)
119119

@@ -347,7 +347,6 @@ def body_fun_aux(state):
347347
state.solver_state,
348348
state.made_jump,
349349
)
350-
implicit_step = False
351350
else:
352351
min_delay = []
353352
flat_delays = jtu.tree_leaves(delays.delays)
@@ -423,6 +422,7 @@ def get_struct_dense_info(init_state):
423422
assert jnp.result_type(keep_step) is jnp.dtype(bool)
424423
# Finding all of the potential discontinuity roots
425424
discont_update = False
425+
num_dde_explicit_step = num_dde_implicit_step = 0
426426
if delays is not None:
427427
# _part_maybe_find_discontinuity = ft.partial(
428428
# maybe_find_discontinuity,
@@ -467,11 +467,11 @@ def get_struct_dense_info(init_state):
467467

468468
# Count the number of steps in DDEs, just for statistical purposes
469469
num_dde_implicit_step = state.num_dde_implicit_step + (
470-
keep_step & implicit_step
471-
)
472-
num_dde_explicit_step = state.num_dde_explicit_step + (
473-
keep_step & jnp.invert(implicit_step)
470+
jnp.where(keep_step, 1, 0) & jnp.where(implicit_step, 1, 0) # type: ignore
474471
)
472+
num_dde_explicit_step = state.num_dde_explicit_step + jnp.where(
473+
keep_step, 1, 0
474+
) & jnp.where(jnp.invert(implicit_step), 1, 0) # type: ignore
475475

476476
assert jnp.result_type(keep_step) is jnp.dtype(bool)
477477

@@ -694,7 +694,9 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i):
694694
discontinuities = maybe_inplace_delay(
695695
discontinuities_save_index + 1, tnext, discontinuities
696696
)
697-
discontinuities_save_index = discontinuities_save_index + discont_update
697+
discontinuities_save_index = discontinuities_save_index + jnp.where(
698+
discont_update, 1, 0
699+
)
698700

699701
new_state = State(
700702
y=y,
@@ -717,8 +719,8 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i):
717719
event_dense_info=event_dense_info,
718720
event_values=event_values,
719721
event_mask=event_mask,
720-
num_dde_explicit_step=num_dde_explicit_step, # type: ignore
721-
num_dde_implicit_step=num_dde_implicit_step, # type: ignore
722+
num_dde_explicit_step=num_dde_explicit_step,
723+
num_dde_implicit_step=num_dde_implicit_step,
722724
discontinuities=discontinuities, # type: ignore
723725
discontinuities_save_index=discontinuities_save_index,
724726
)
@@ -932,7 +934,7 @@ def diffeqsolve(
932934
t0: RealScalarLike,
933935
t1: RealScalarLike,
934936
dt0: Optional[RealScalarLike],
935-
y0: PyTree[ArrayLike],
937+
y0: Union[PyTree[ArrayLike], Callable[[RealScalarLike], PyTree[ArrayLike]]],
936938
args: PyTree[Any] = None,
937939
*,
938940
saveat: SaveAt = SaveAt(t1=True),

examples/dde.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@
102102
"cell_type": "markdown",
103103
"metadata": {},
104104
"source": [
105-
"In our case we only have 0 since $y\\prime(t=0^{-}) \\neq y\\prime(t=0^{+})$ because $y\\prime(t=0^{+}) = - 2 \\alpha$ and $y\\prime(t=0^{-}) = 0$. \n",
105+
"In our case we only have 0 since $y^\\prime(t=0^{-}) \\neq y^\\prime(t=0^{+})$ because $y^\\prime(t=0^{+}) = - 2 \\alpha$ and $y^\\prime(t=0^{-}) = 0$. \n",
106106
"We choose $\\tau=1$."
107107
]
108108
},
@@ -169,7 +169,7 @@
169169
"name": "stdout",
170170
"output_type": "stream",
171171
"text": [
172-
"Integration took in 0.0011103153228759766 seconds.\n"
172+
"Integration took in 0.0004780292510986328 seconds.\n"
173173
]
174174
},
175175
{
@@ -212,7 +212,7 @@
212212
"name": "stdout",
213213
"output_type": "stream",
214214
"text": [
215-
"Integration took in 0.0009412765502929688 seconds.\n"
215+
"Integration took in 0.0006003379821777344 seconds.\n"
216216
]
217217
},
218218
{

examples/neural_dde.ipynb

Lines changed: 33 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@
243243
" dataset_size=256,\n",
244244
" batch_size=128,\n",
245245
" lr_strategy=(3e-3,),\n",
246-
" steps_strategy=(120,),\n",
246+
" steps_strategy=(500,),\n",
247247
" length_strategy=(1.0,),\n",
248248
" width_size=32,\n",
249249
" depth=3,\n",
@@ -310,83 +310,38 @@
310310
"name": "stdout",
311311
"output_type": "stream",
312312
"text": [
313-
"Step: 0, Loss: 1.401884913444519, Computation time: 14.423743486404419\n",
314-
"Step: 1, Loss: 1.1716960668563843, Computation time: 0.00439143180847168\n",
315-
"Step: 2, Loss: 1.244404673576355, Computation time: 0.003352642059326172\n",
316-
"Step: 3, Loss: 1.152793049812317, Computation time: 0.0034520626068115234\n",
317-
"Step: 4, Loss: 1.2006347179412842, Computation time: 0.0023109912872314453\n",
318-
"Step: 5, Loss: 1.134800910949707, Computation time: 0.003433704376220703\n",
319-
"Step: 6, Loss: 1.0922808647155762, Computation time: 0.00430607795715332\n",
320-
"Step: 7, Loss: 1.0771468877792358, Computation time: 0.0031037330627441406\n",
321-
"Step: 8, Loss: 1.1282471418380737, Computation time: 0.0026216506958007812\n",
322-
"Step: 9, Loss: 1.0664700269699097, Computation time: 0.0036363601684570312\n",
323-
"Step: 10, Loss: 0.9904348254203796, Computation time: 0.0024759769439697266\n",
324-
"Step: 11, Loss: 1.0465335845947266, Computation time: 0.002873659133911133\n",
325-
"Step: 12, Loss: 1.0017845630645752, Computation time: 0.004104137420654297\n",
326-
"Step: 13, Loss: 1.0248370170593262, Computation time: 0.002857208251953125\n",
327-
"Step: 14, Loss: 0.9131743311882019, Computation time: 0.0031785964965820312\n",
328-
"Step: 15, Loss: 0.9832042455673218, Computation time: 0.003298044204711914\n",
329-
"Step: 16, Loss: 0.9668886661529541, Computation time: 0.002664327621459961\n",
330-
"Step: 17, Loss: 0.9930294752120972, Computation time: 0.0028023719787597656\n",
331-
"Step: 18, Loss: 0.9732811450958252, Computation time: 0.004416465759277344\n",
332-
"Step: 19, Loss: 0.9101904630661011, Computation time: 0.0021491050720214844\n",
333-
"Step: 20, Loss: 0.9569293260574341, Computation time: 0.003906965255737305\n",
334-
"Step: 21, Loss: 1.000235676765442, Computation time: 0.0031125545501708984\n",
335-
"Step: 22, Loss: 0.948157787322998, Computation time: 0.0025708675384521484\n",
336-
"Step: 23, Loss: 0.9466567039489746, Computation time: 0.003038644790649414\n",
337-
"Step: 24, Loss: 0.9593256115913391, Computation time: 0.003378629684448242\n",
338-
"Step: 25, Loss: 0.9038560390472412, Computation time: 0.003355264663696289\n",
339-
"Step: 26, Loss: 0.9106528162956238, Computation time: 0.0029053688049316406\n",
340-
"Step: 27, Loss: 0.8822335004806519, Computation time: 0.0023543834686279297\n",
341-
"Step: 28, Loss: 0.8793952465057373, Computation time: 0.002279043197631836\n",
342-
"Step: 29, Loss: 0.8758573532104492, Computation time: 0.0025527477264404297\n",
343-
"Step: 30, Loss: 0.8163420557975769, Computation time: 0.0034754276275634766\n",
344-
"Step: 31, Loss: 0.7726603150367737, Computation time: 0.0041882991790771484\n",
345-
"Step: 32, Loss: 0.7940323352813721, Computation time: 0.0024156570434570312\n",
346-
"Step: 33, Loss: 0.7175382375717163, Computation time: 0.0027916431427001953\n",
347-
"Step: 34, Loss: 0.7028713226318359, Computation time: 0.0023674964904785156\n",
348-
"Step: 35, Loss: 0.6886805295944214, Computation time: 0.0023441314697265625\n",
349-
"Step: 36, Loss: 0.6005609035491943, Computation time: 0.002215862274169922\n",
350-
"Step: 37, Loss: 0.5269209742546082, Computation time: 0.003516674041748047\n",
351-
"Step: 38, Loss: 0.4020206332206726, Computation time: 0.0025489330291748047\n",
352-
"Step: 39, Loss: 0.3255264461040497, Computation time: 0.0038254261016845703\n",
353-
"Step: 40, Loss: 0.3398251533508301, Computation time: 0.0026879310607910156\n",
354-
"Step: 41, Loss: 0.23914429545402527, Computation time: 0.0031821727752685547\n",
355-
"Step: 42, Loss: 0.14592041075229645, Computation time: 0.002313852310180664\n",
356-
"Step: 43, Loss: 0.13987970352172852, Computation time: 0.002420186996459961\n",
357-
"Step: 44, Loss: 0.1373867690563202, Computation time: 0.002546548843383789\n",
358-
"Step: 45, Loss: 0.16126586496829987, Computation time: 0.0025119781494140625\n",
359-
"Step: 46, Loss: 0.11544477194547653, Computation time: 0.0023653507232666016\n",
360-
"Step: 47, Loss: 0.061478693038225174, Computation time: 0.0027551651000976562\n",
361-
"Step: 48, Loss: 0.042316604405641556, Computation time: 0.0026128292083740234\n",
362-
"Step: 49, Loss: 0.09910032153129578, Computation time: 0.0021860599517822266\n",
363-
"Step: 50, Loss: 0.08195476979017258, Computation time: 0.0030493736267089844\n",
364-
"Step: 51, Loss: 0.036347728222608566, Computation time: 0.0036330223083496094\n",
365-
"Step: 52, Loss: 0.04329509660601616, Computation time: 0.002357959747314453\n",
366-
"Step: 53, Loss: 0.0774608924984932, Computation time: 0.003058910369873047\n",
367-
"Step: 54, Loss: 0.07515737414360046, Computation time: 0.0029230117797851562\n",
368-
"Step: 55, Loss: 0.06952822208404541, Computation time: 0.004669904708862305\n",
369-
"Step: 56, Loss: 0.04167735204100609, Computation time: 0.004377126693725586\n",
370-
"Step: 57, Loss: 0.043842863291502, Computation time: 0.002735614776611328\n",
371-
"Step: 58, Loss: 0.06545793265104294, Computation time: 0.003109455108642578\n",
372-
"Step: 59, Loss: 0.055285390466451645, Computation time: 0.003002643585205078\n",
373-
"Step: 60, Loss: 0.031119057908654213, Computation time: 0.0036330223083496094\n",
374-
"Step: 61, Loss: 0.03198694810271263, Computation time: 0.0037326812744140625\n",
375-
"Step: 62, Loss: 0.039938874542713165, Computation time: 0.003992319107055664\n",
376-
"Step: 63, Loss: 0.045956145972013474, Computation time: 0.0034766197204589844\n",
377-
"Step: 64, Loss: 0.03436319902539253, Computation time: 0.0034666061401367188\n",
378-
"Step: 65, Loss: 0.025405289605259895, Computation time: 0.003632783889770508\n",
379-
"Step: 66, Loss: 0.023711830377578735, Computation time: 0.002114534378051758\n",
380-
"Step: 67, Loss: 0.03284265473484993, Computation time: 0.0043947696685791016\n",
381-
"Step: 68, Loss: 0.03023228421807289, Computation time: 0.003482341766357422\n",
382-
"Step: 69, Loss: 0.021171605214476585, Computation time: 0.0029871463775634766\n",
383-
"Step: 70, Loss: 0.017744384706020355, Computation time: 0.003162860870361328\n",
384-
"Step: 71, Loss: 0.022380519658327103, Computation time: 0.004060029983520508\n",
385-
"Step: 72, Loss: 0.02576189674437046, Computation time: 0.0023908615112304688\n",
386-
"Step: 73, Loss: 0.019239962100982666, Computation time: 0.002658367156982422\n",
387-
"Step: 74, Loss: 0.016447639092803, Computation time: 0.003663301467895508\n",
388-
"Step: 75, Loss: 0.01782667636871338, Computation time: 0.0035588741302490234\n",
389-
"Step: 76, Loss: 0.02072978764772415, Computation time: 0.0031981468200683594\n"
313+
"Step: 0, Loss: 1.401884913444519, Computation time: 14.241726636886597\n",
314+
"Step: 5, Loss: 1.134800910949707, Computation time: 1.192063570022583\n",
315+
"Step: 10, Loss: 0.9904348254203796, Computation time: 1.9642481803894043\n",
316+
"Step: 15, Loss: 0.9832042455673218, Computation time: 2.6581308841705322\n",
317+
"Step: 20, Loss: 0.9569293260574341, Computation time: 3.8219752311706543\n",
318+
"Step: 25, Loss: 0.9038560390472412, Computation time: 3.6977927684783936\n",
319+
"Step: 30, Loss: 0.8163420557975769, Computation time: 35.42212176322937\n",
320+
"Step: 35, Loss: 0.6886805295944214, Computation time: 6.909891843795776\n",
321+
"Step: 40, Loss: 0.3398251533508301, Computation time: 5.504199266433716\n",
322+
"Step: 45, Loss: 0.16126586496829987, Computation time: 5.103270769119263\n",
323+
"Step: 50, Loss: 0.08195476979017258, Computation time: 5.78333044052124\n",
324+
"Step: 55, Loss: 0.06952822208404541, Computation time: 10.413585901260376\n",
325+
"Step: 60, Loss: 0.031119057908654213, Computation time: 7.735013723373413\n",
326+
"Step: 65, Loss: 0.025405289605259895, Computation time: 7.267561912536621\n",
327+
"Step: 70, Loss: 0.017744384706020355, Computation time: 6.29371190071106\n",
328+
"Step: 75, Loss: 0.01782667636871338, Computation time: 25.47852373123169\n",
329+
"Step: 80, Loss: 0.015500097535550594, Computation time: 6.827113628387451\n",
330+
"Step: 85, Loss: 0.011661469005048275, Computation time: 7.4889373779296875\n",
331+
"Step: 90, Loss: 0.00916498713195324, Computation time: 5.9714484214782715\n",
332+
"Step: 95, Loss: 0.010490368120372295, Computation time: 6.338549613952637\n",
333+
"Step: 100, Loss: 0.007394495420157909, Computation time: 6.486926078796387\n",
334+
"Step: 105, Loss: 0.007423707749694586, Computation time: 7.48775839805603\n",
335+
"Step: 110, Loss: 0.007470142096281052, Computation time: 6.819472312927246\n",
336+
"Step: 115, Loss: 0.0059195710346102715, Computation time: 6.100851058959961\n",
337+
"Step: 120, Loss: 0.005787399597465992, Computation time: 6.755363941192627\n",
338+
"Step: 125, Loss: 0.005841915961354971, Computation time: 6.580211639404297\n",
339+
"Step: 130, Loss: 0.006159897893667221, Computation time: 6.810395956039429\n",
340+
"Step: 135, Loss: 0.0052039725705981255, Computation time: 7.518130779266357\n",
341+
"Step: 140, Loss: 0.0053937858901917934, Computation time: 6.242229223251343\n",
342+
"Step: 145, Loss: 0.00430111913010478, Computation time: 5.085596799850464\n",
343+
"Step: 150, Loss: 0.004397619515657425, Computation time: 5.2558135986328125\n",
344+
"Step: 155, Loss: 0.004206538666039705, Computation time: 7.212834596633911\n"
390345
]
391346
}
392347
],

0 commit comments

Comments
 (0)