diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index b41426d3..690201c6 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -487,14 +487,24 @@ def maybe_inplace(i, u, x): return eqxi.buffer_at_set(x, i, u, pred=keep_step) def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: - if subsaveat.steps: + if subsaveat.steps != 0: + save_step = (state.num_accepted_steps % subsaveat.steps) == 0 + should_save = keep_step & save_step + + def save_fn(tprev, y, args): + return lax.cond( + should_save, + lambda: subsaveat.fn(tprev, y, args), + lambda: jnp.zeros_like(save_state.ys[0]), + ) + ts = maybe_inplace(save_state.save_index, tprev, save_state.ts) ys = jtu.tree_map( ft.partial(maybe_inplace, save_state.save_index), - subsaveat.fn(tprev, y, args), + save_fn(tprev, y, args), save_state.ys, ) - save_index = save_state.save_index + jnp.where(keep_step, 1, 0) + save_index = save_state.save_index + jnp.where(should_save, 1, 0) save_state = eqx.tree_at( lambda s: [s.ts, s.ys, s.save_index], save_state, @@ -505,7 +515,6 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: save_state = jtu.tree_map( save_steps, saveat.subs, save_state, is_leaf=_is_subsaveat ) - if saveat.dense: dense_ts = maybe_inplace(dense_save_index + 1, tprev, dense_ts) dense_infos = jtu.tree_map( @@ -1229,16 +1238,20 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState: out_size += 1 if subsaveat.ts is not None: out_size += len(subsaveat.ts) - if subsaveat.steps: + if subsaveat.steps != 0: # We have no way of knowing how many steps we'll actually end up taking, and # XLA doesn't support dynamic shapes. So we just have to allocate the # maximum amount of steps we can possibly take. if max_steps is None: raise ValueError( - "`max_steps=None` is incompatible with saving at `steps=True`" + "`max_steps=None` is incompatible with saving at `steps=n`" ) - out_size += max_steps - if subsaveat.t1 and not subsaveat.steps: + out_size += max_steps // subsaveat.steps + if subsaveat.t1 and ( + (max_steps is None) + or (subsaveat.steps == 0) + or (max_steps % subsaveat.steps != 0) + ): out_size += 1 saveat_ts_index = 0 save_index = 0 diff --git a/diffrax/_saveat.py b/diffrax/_saveat.py index 6ee373de..fed6d3e3 100644 --- a/diffrax/_saveat.py +++ b/diffrax/_saveat.py @@ -29,11 +29,29 @@ class SubSaveAt(eqx.Module): relatively niche feature and most users will probably not need to use `SubSaveAt`.) """ - t0: bool = False - t1: bool = False - ts: Optional[Real[Array, " times"]] = eqx.field(default=None, converter=_convert_ts) - steps: bool = False - fn: Callable = save_y + t0: bool + t1: bool + ts: Optional[Real[Array, " times"]] + steps: int + fn: Callable + + def __init__( + self, + *, + t0: bool = False, + t1: bool = False, + ts: Union[None, Sequence[RealScalarLike], Real[Array, " times"]] = None, + steps: Union[bool, int] = 0, + fn: Callable = save_y, + ): + self.t0 = t0 + self.t1 = t1 + self.ts = _convert_ts(ts) + if isinstance(steps, bool): + self.steps = 1 if steps else 0 + else: + self.steps = steps + self.fn = fn def __check_init__(self): if not self.t0 and not self.t1 and self.ts is None and not self.steps: @@ -45,7 +63,8 @@ def __check_init__(self): - `t0`: If `True`, save the initial input `y0`. - `t1`: If `True`, save the output at `t1`. - `ts`: Some array of times at which to save the output. -- `steps`: If `True`, save the output at every step of the numerical solver. +- `steps`: If `n>0`, save the output at every `n`th step of the numerical solver. + `0` means no saving. - `fn`: A function `fn(t, y, args)` which specifies what to save into `sol.ys` when using `t0`, `t1`, `ts` or `steps`. Defaults to `fn(t, y, args) -> y`, so that the evolving solution is saved. This can be useful to save only statistics of your @@ -72,7 +91,7 @@ def __init__( t0: bool = False, t1: bool = False, ts: Union[None, Sequence[RealScalarLike], Real[Array, " times"]] = None, - steps: bool = False, + steps: Union[bool, int] = False, fn: Callable = save_y, subs: PyTree[SubSaveAt] = None, dense: bool = False, @@ -101,7 +120,8 @@ def __init__( - `t0`: If `True`, save the initial input `y0`. - `t1`: If `True`, save the output at `t1`. - `ts`: Some array of times at which to save the output. -- `steps`: If `True`, save the output at every step of the numerical solver. +- `steps`: If `n>0`, save the output at every `n`th step of the numerical solver. + `0` means no saving. - `dense`: If `True`, save dense output, that can later be evaluated at any part of the interval $[t_0, t_1]$ via `sol = diffeqsolve(...); sol.evaluate(...)`. diff --git a/test/test_event.py b/test/test_event.py index 80f0102c..67ae45da 100644 --- a/test/test_event.py +++ b/test/test_event.py @@ -564,7 +564,7 @@ def cond_fn_2(t, y, args, **kwargs): @pytest.mark.parametrize("steps", (1, 2, 3, 4, 5)) -def test_event_save_steps(steps): +def test_event_save_all_steps(steps): term = diffrax.ODETerm(lambda t, y, args: (1.0, 1.0)) solver = diffrax.Tsit5() t0 = 0 @@ -601,8 +601,8 @@ def run(saveat): diffrax.SaveAt(steps=True, t1=True, t0=True), diffrax.SaveAt(steps=True, fn=lambda t, y, args: (y[0], y[1] + thr)), ] - num_steps = [steps, steps, steps + 1, steps] - yevents = [(thr, 0), (thr, 0), (thr, 0), (thr, thr)] + num_steps = [steps, steps, steps + 1, steps, 0] + yevents = [(thr, 0), (thr, 0), (thr, 0), (thr, thr), (thr, 0)] for saveat, n, yevent in zip(saveats, num_steps, yevents): ts, ys = run(saveat) diff --git a/test/test_integrate.py b/test/test_integrate.py index 15d83f3e..cfcaadfd 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -334,6 +334,7 @@ def get_dt_and_controller(level): diffrax.SaveAt(t1=True), diffrax.SaveAt(ts=[3.5, 0.7]), diffrax.SaveAt(steps=True), + diffrax.SaveAt(steps=2), diffrax.SaveAt(dense=True), ), ) diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py index 8ddca38d..59ca2fbd 100644 --- a/test/test_saveat_solution.py +++ b/test/test_saveat_solution.py @@ -111,7 +111,7 @@ def test_saveat_solution(): assert sol.stats["num_steps"] > 0 assert sol.result == diffrax.RESULTS.successful - saveat = diffrax.SaveAt(steps=True) + saveat = diffrax.SaveAt(steps=1) sol = _integrate(saveat) assert sol.t0 == _t0 assert sol.t1 == _t1 @@ -131,6 +131,49 @@ def test_saveat_solution(): assert sol.stats["num_steps"] > 0 assert sol.result == diffrax.RESULTS.successful + saveat = diffrax.SaveAt(steps=2) + sol = _integrate(saveat) + assert sol.t0 == _t0 + assert sol.t1 == _t1 + n = (4096 - 1) // 2 + 1 + assert sol.ts.shape == (n,) # pyright: ignore + assert sol.ys.shape == (n, 1) # pyright: ignore + _ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts) + with jax.numpy_rank_promotion("allow"): + _ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None] + _ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys) + assert tree_allclose(sol.ys, _ys) + assert sol.controller_state is None + assert sol.solver_state is None + with pytest.raises(ValueError): + sol.evaluate(0.2, 0.8) + with pytest.raises(ValueError): + sol.derivative(0.2) + assert sol.stats["num_steps"] > 0 + assert sol.result == diffrax.RESULTS.successful + + saveat = diffrax.SaveAt(steps=2, t1=True) + sol = _integrate(saveat) + assert sol.t0 == _t0 + assert sol.t1 == _t1 + n = (4096 - 1) // 2 + 1 + assert sol.ts.shape == (n,) # pyright: ignore + assert sol.ys.shape == (n, 1) # pyright: ignore + _ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts) + with jax.numpy_rank_promotion("allow"): + _ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None] + _ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys) + print(_ys) + assert tree_allclose(sol.ys, _ys) + assert sol.controller_state is None + assert sol.solver_state is None + with pytest.raises(ValueError): + sol.evaluate(0.2, 0.8) + with pytest.raises(ValueError): + sol.derivative(0.2) + assert sol.stats["num_steps"] > 0 + assert sol.result == diffrax.RESULTS.successful + saveat = diffrax.SaveAt(dense=True) sol = _integrate(saveat) assert sol.t0 == _t0 @@ -147,6 +190,45 @@ def test_saveat_solution(): assert sol.result == diffrax.RESULTS.successful +def test_saveat_solution_skip_steps(): + def _step_integrate(saveat: diffrax.SaveAt): + with jax.disable_jit(): + term = diffrax.ODETerm(lambda t, y, args: -0.5 * y) + ts = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + sol_ts = diffrax.diffeqsolve( + term, + t0=ts[0], + t1=ts[-1], + y0=jnp.array([1.0]), + dt0=None, + solver=diffrax.Euler(), + saveat=saveat, + stepsize_controller=diffrax.StepTo(ts=ts), + max_steps=10, + ).ts + assert sol_ts is not None + return sol_ts[jnp.isfinite(sol_ts)] + + saveat = diffrax.SaveAt(steps=2) + ts = _step_integrate(saveat) + assert jnp.allclose(ts, jnp.array([1.0, 3.0, 5.0])) + saveat = diffrax.SaveAt(steps=2, t1=True) + ts = _step_integrate(saveat) + assert jnp.allclose(ts, jnp.array([1.0, 3.0, 5.0, 6.0])) + saveat = diffrax.SaveAt(steps=2, t1=True, t0=True) + ts = _step_integrate(saveat) + assert jnp.allclose(ts, jnp.array([0.0, 1.0, 3.0, 5.0, 6.0])) + saveat = diffrax.SaveAt(steps=3) + ts = _step_integrate(saveat) + assert jnp.allclose(ts, jnp.array([1.0, 4.0])) + saveat = diffrax.SaveAt(steps=3, t1=True) + ts = _step_integrate(saveat) + assert jnp.allclose(ts, jnp.array([1.0, 4.0, 6.0])) + saveat = diffrax.SaveAt(steps=3, t1=True, t0=True) + ts = _step_integrate(saveat) + assert jnp.allclose(ts, jnp.array([0.0, 1.0, 4.0, 6.0])) + + @pytest.mark.parametrize("subs", [True, False]) def test_t0_eq_t1(subs): y0 = jnp.array([2.0]) @@ -164,7 +246,7 @@ def test_t0_eq_t1(subs): get2 = diffrax.SubSaveAt( t0=True, ts=ts, - steps=True, + steps=1, ) subs = (get0, get1, get2) saveat = diffrax.SaveAt(subs=subs) @@ -220,7 +302,7 @@ def _solve(tf): get2 = diffrax.SubSaveAt( t0=True, ts=ts, - steps=True, + steps=1, fn=lambda t, y, args: jnp.where(jnp.isinf(y), 3.0, 4.0), ) subs = (get0, get1, get2) @@ -294,7 +376,7 @@ def test_subsaveat(adjoint, multi_subs, with_fn, getkey): subsaveat_kwargs: dict = dict() get2 = diffrax.SubSaveAt(t0=True, ts=jnp.linspace(0.5, 1.5, 3), **subsaveat_kwargs) if multi_subs: - get0 = diffrax.SubSaveAt(steps=True, fn=lambda _, y, __: y[0]) + get0 = diffrax.SubSaveAt(steps=1, fn=lambda _, y, __: y[0]) get1 = diffrax.SubSaveAt( ts=jnp.linspace(0, 1, 5), t1=True, fn=lambda _, y, __: y[1] )