Skip to content

Added steps=n logic for skipping steps #626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
29 changes: 21 additions & 8 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,14 +487,24 @@ def maybe_inplace(i, u, x):
return eqxi.buffer_at_set(x, i, u, pred=keep_step)

def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
if subsaveat.steps:
if subsaveat.steps != 0:
save_step = (state.num_accepted_steps % subsaveat.steps) == 0
should_save = keep_step & save_step

def save_fn(tprev, y, args):
return lax.cond(
should_save,
lambda: subsaveat.fn(tprev, y, args),
lambda: jnp.zeros_like(save_state.ys[0]),
)

ts = maybe_inplace(save_state.save_index, tprev, save_state.ts)
ys = jtu.tree_map(
ft.partial(maybe_inplace, save_state.save_index),
subsaveat.fn(tprev, y, args),
save_fn(tprev, y, args),
save_state.ys,
)
save_index = save_state.save_index + jnp.where(keep_step, 1, 0)
save_index = save_state.save_index + jnp.where(should_save, 1, 0)
save_state = eqx.tree_at(
lambda s: [s.ts, s.ys, s.save_index],
save_state,
Expand All @@ -505,7 +515,6 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
save_state = jtu.tree_map(
save_steps, saveat.subs, save_state, is_leaf=_is_subsaveat
)

if saveat.dense:
dense_ts = maybe_inplace(dense_save_index + 1, tprev, dense_ts)
dense_infos = jtu.tree_map(
Expand Down Expand Up @@ -1229,16 +1238,20 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
out_size += 1
if subsaveat.ts is not None:
out_size += len(subsaveat.ts)
if subsaveat.steps:
if subsaveat.steps != 0:
# We have no way of knowing how many steps we'll actually end up taking, and
# XLA doesn't support dynamic shapes. So we just have to allocate the
# maximum amount of steps we can possibly take.
if max_steps is None:
raise ValueError(
"`max_steps=None` is incompatible with saving at `steps=True`"
"`max_steps=None` is incompatible with saving at `steps=n`"
)
out_size += max_steps
if subsaveat.t1 and not subsaveat.steps:
out_size += max_steps // subsaveat.steps
if subsaveat.t1 and (
(max_steps is None)
or (subsaveat.steps == 0)
or (max_steps % subsaveat.steps != 0)
):
out_size += 1
saveat_ts_index = 0
save_index = 0
Expand Down
36 changes: 28 additions & 8 deletions diffrax/_saveat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,29 @@ class SubSaveAt(eqx.Module):
relatively niche feature and most users will probably not need to use `SubSaveAt`.)
"""

t0: bool = False
t1: bool = False
ts: Optional[Real[Array, " times"]] = eqx.field(default=None, converter=_convert_ts)
steps: bool = False
fn: Callable = save_y
t0: bool
t1: bool
ts: Optional[Real[Array, " times"]]
steps: int
fn: Callable

def __init__(
self,
*,
t0: bool = False,
t1: bool = False,
ts: Union[None, Sequence[RealScalarLike], Real[Array, " times"]] = None,
steps: Union[bool, int] = 0,
fn: Callable = save_y,
):
self.t0 = t0
self.t1 = t1
self.ts = _convert_ts(ts)
if isinstance(steps, bool):
self.steps = 1 if steps else 0
else:
self.steps = steps
self.fn = fn

def __check_init__(self):
if not self.t0 and not self.t1 and self.ts is None and not self.steps:
Expand All @@ -45,7 +63,8 @@ def __check_init__(self):
- `t0`: If `True`, save the initial input `y0`.
- `t1`: If `True`, save the output at `t1`.
- `ts`: Some array of times at which to save the output.
- `steps`: If `True`, save the output at every step of the numerical solver.
- `steps`: If `n>0`, save the output at every `n`th step of the numerical solver.
`0` means no saving.
- `fn`: A function `fn(t, y, args)` which specifies what to save into `sol.ys` when
using `t0`, `t1`, `ts` or `steps`. Defaults to `fn(t, y, args) -> y`, so that the
evolving solution is saved. This can be useful to save only statistics of your
Expand All @@ -72,7 +91,7 @@ def __init__(
t0: bool = False,
t1: bool = False,
ts: Union[None, Sequence[RealScalarLike], Real[Array, " times"]] = None,
steps: bool = False,
steps: Union[bool, int] = False,
fn: Callable = save_y,
subs: PyTree[SubSaveAt] = None,
dense: bool = False,
Expand Down Expand Up @@ -101,7 +120,8 @@ def __init__(
- `t0`: If `True`, save the initial input `y0`.
- `t1`: If `True`, save the output at `t1`.
- `ts`: Some array of times at which to save the output.
- `steps`: If `True`, save the output at every step of the numerical solver.
- `steps`: If `n>0`, save the output at every `n`th step of the numerical solver.
`0` means no saving.
- `dense`: If `True`, save dense output, that can later be evaluated at any part of
the interval $[t_0, t_1]$ via `sol = diffeqsolve(...); sol.evaluate(...)`.

Expand Down
6 changes: 3 additions & 3 deletions test/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ def cond_fn_2(t, y, args, **kwargs):


@pytest.mark.parametrize("steps", (1, 2, 3, 4, 5))
def test_event_save_steps(steps):
def test_event_save_all_steps(steps):
term = diffrax.ODETerm(lambda t, y, args: (1.0, 1.0))
solver = diffrax.Tsit5()
t0 = 0
Expand Down Expand Up @@ -601,8 +601,8 @@ def run(saveat):
diffrax.SaveAt(steps=True, t1=True, t0=True),
diffrax.SaveAt(steps=True, fn=lambda t, y, args: (y[0], y[1] + thr)),
]
num_steps = [steps, steps, steps + 1, steps]
yevents = [(thr, 0), (thr, 0), (thr, 0), (thr, thr)]
num_steps = [steps, steps, steps + 1, steps, 0]
yevents = [(thr, 0), (thr, 0), (thr, 0), (thr, thr), (thr, 0)]

for saveat, n, yevent in zip(saveats, num_steps, yevents):
ts, ys = run(saveat)
Expand Down
1 change: 1 addition & 0 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def get_dt_and_controller(level):
diffrax.SaveAt(t1=True),
diffrax.SaveAt(ts=[3.5, 0.7]),
diffrax.SaveAt(steps=True),
diffrax.SaveAt(steps=2),
diffrax.SaveAt(dense=True),
),
)
Expand Down
90 changes: 86 additions & 4 deletions test/test_saveat_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_saveat_solution():
assert sol.stats["num_steps"] > 0
assert sol.result == diffrax.RESULTS.successful

saveat = diffrax.SaveAt(steps=True)
saveat = diffrax.SaveAt(steps=1)
sol = _integrate(saveat)
assert sol.t0 == _t0
assert sol.t1 == _t1
Expand All @@ -131,6 +131,49 @@ def test_saveat_solution():
assert sol.stats["num_steps"] > 0
assert sol.result == diffrax.RESULTS.successful

saveat = diffrax.SaveAt(steps=2)
sol = _integrate(saveat)
Comment on lines +134 to +135
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also add tests for SaveAt(steps=n, t1=True) and SaveAt(steps=n, t1=False) when the number of steps does not divide n. Then check that in the first case with e.g. n=2 and 5 steps that we store [0 2 4 5] and in the second case we store [0 2 4]

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the tests, but I can not make it work. See below

assert sol.t0 == _t0
assert sol.t1 == _t1
n = (4096 - 1) // 2 + 1
assert sol.ts.shape == (n,) # pyright: ignore
assert sol.ys.shape == (n, 1) # pyright: ignore
_ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts)
with jax.numpy_rank_promotion("allow"):
_ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None]
_ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys)
assert tree_allclose(sol.ys, _ys)
assert sol.controller_state is None
assert sol.solver_state is None
with pytest.raises(ValueError):
sol.evaluate(0.2, 0.8)
with pytest.raises(ValueError):
sol.derivative(0.2)
assert sol.stats["num_steps"] > 0
assert sol.result == diffrax.RESULTS.successful

saveat = diffrax.SaveAt(steps=2, t1=True)
sol = _integrate(saveat)
assert sol.t0 == _t0
assert sol.t1 == _t1
n = (4096 - 1) // 2 + 1
assert sol.ts.shape == (n,) # pyright: ignore
assert sol.ys.shape == (n, 1) # pyright: ignore
_ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts)
with jax.numpy_rank_promotion("allow"):
_ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None]
_ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys)
print(_ys)
assert tree_allclose(sol.ys, _ys)
assert sol.controller_state is None
assert sol.solver_state is None
with pytest.raises(ValueError):
sol.evaluate(0.2, 0.8)
with pytest.raises(ValueError):
sol.derivative(0.2)
assert sol.stats["num_steps"] > 0
assert sol.result == diffrax.RESULTS.successful

saveat = diffrax.SaveAt(dense=True)
sol = _integrate(saveat)
assert sol.t0 == _t0
Expand All @@ -147,6 +190,45 @@ def test_saveat_solution():
assert sol.result == diffrax.RESULTS.successful


def test_saveat_solution_skip_steps():
def _step_integrate(saveat: diffrax.SaveAt):
with jax.disable_jit():
term = diffrax.ODETerm(lambda t, y, args: -0.5 * y)
ts = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
sol_ts = diffrax.diffeqsolve(
term,
t0=ts[0],
t1=ts[-1],
y0=jnp.array([1.0]),
dt0=None,
solver=diffrax.Euler(),
saveat=saveat,
stepsize_controller=diffrax.StepTo(ts=ts),
max_steps=10,
).ts
assert sol_ts is not None
return sol_ts[jnp.isfinite(sol_ts)]

saveat = diffrax.SaveAt(steps=2)
ts = _step_integrate(saveat)
assert jnp.allclose(ts, jnp.array([1.0, 3.0, 5.0]))
saveat = diffrax.SaveAt(steps=2, t1=True)
ts = _step_integrate(saveat)
assert jnp.allclose(ts, jnp.array([1.0, 3.0, 5.0, 6.0]))
saveat = diffrax.SaveAt(steps=2, t1=True, t0=True)
ts = _step_integrate(saveat)
assert jnp.allclose(ts, jnp.array([0.0, 1.0, 3.0, 5.0, 6.0]))
saveat = diffrax.SaveAt(steps=3)
ts = _step_integrate(saveat)
assert jnp.allclose(ts, jnp.array([1.0, 4.0]))
saveat = diffrax.SaveAt(steps=3, t1=True)
ts = _step_integrate(saveat)
assert jnp.allclose(ts, jnp.array([1.0, 4.0, 6.0]))
saveat = diffrax.SaveAt(steps=3, t1=True, t0=True)
ts = _step_integrate(saveat)
assert jnp.allclose(ts, jnp.array([0.0, 1.0, 4.0, 6.0]))


@pytest.mark.parametrize("subs", [True, False])
def test_t0_eq_t1(subs):
y0 = jnp.array([2.0])
Expand All @@ -164,7 +246,7 @@ def test_t0_eq_t1(subs):
get2 = diffrax.SubSaveAt(
t0=True,
ts=ts,
steps=True,
steps=1,
)
subs = (get0, get1, get2)
saveat = diffrax.SaveAt(subs=subs)
Expand Down Expand Up @@ -220,7 +302,7 @@ def _solve(tf):
get2 = diffrax.SubSaveAt(
t0=True,
ts=ts,
steps=True,
steps=1,
fn=lambda t, y, args: jnp.where(jnp.isinf(y), 3.0, 4.0),
)
subs = (get0, get1, get2)
Expand Down Expand Up @@ -294,7 +376,7 @@ def test_subsaveat(adjoint, multi_subs, with_fn, getkey):
subsaveat_kwargs: dict = dict()
get2 = diffrax.SubSaveAt(t0=True, ts=jnp.linspace(0.5, 1.5, 3), **subsaveat_kwargs)
if multi_subs:
get0 = diffrax.SubSaveAt(steps=True, fn=lambda _, y, __: y[0])
get0 = diffrax.SubSaveAt(steps=1, fn=lambda _, y, __: y[0])
get1 = diffrax.SubSaveAt(
ts=jnp.linspace(0, 1, 5), t1=True, fn=lambda _, y, __: y[1]
)
Expand Down