diff --git a/README.md b/README.md index 4a8a7c7..cda3190 100644 --- a/README.md +++ b/README.md @@ -200,6 +200,9 @@ TODO ## Development +New ideas or even documenting tricks (like how Jax works) should go +into the test suite, and there are some ideas floating there before +making it into the library itself. ``` git clone https://github.com/ins-amu/vbjax cd vbjax diff --git a/vbjax/tests/test_loops.py b/vbjax/tests/test_loops.py index f04d980..758cd36 100644 --- a/vbjax/tests/test_loops.py +++ b/vbjax/tests/test_loops.py @@ -14,7 +14,7 @@ def test_sde(): def test_ode(): - f = lambda x,_: -x + f = lambda x, _: -x dt = 0.1 _, run = vb.make_ode(dt, f) x0 = np.r_[:32].astype('f') diff --git a/vbjax/tests/test_uloop.py b/vbjax/tests/test_uloop.py new file mode 100644 index 0000000..2f75803 --- /dev/null +++ b/vbjax/tests/test_uloop.py @@ -0,0 +1,73 @@ +""" +Tests and benchmarks for a more generic uber loop. + +- constant args +- time-dep args +- jit-time constants vs args +- monitors +- jax.checkpoint + +""" + +import jax +import jax.numpy as np +import vbjax as vb + + +def _heun_step(x, dt, f, args): + f1 = f(x, args) + f2 = f(x + dt*f1, args) + return x + dt*0.5*(f1 + f2) + +def _dfun1(xy, args): + SC, a, k, stim = args + x, y = xy + c = np.dot(SC, x)*k + dx = 5.0*(x - x*x*x/3 + y) + dy = 0.2*(a - x + stim + c) + return np.array([dx, dy]) + + +def make_loop(dt, dfun, constants=()): + def loop(initial_state, parameters=(), t_parameters=()): + def step(state, t_parameters): + args = constants + parameters + t_parameters + next_state = _heun_step(state, dt, dfun, args) + return next_state, next_state + _, states = jax.lax.scan(step, initial_state, t_parameters) + return states + return loop + + +# can we go for the fully generic graph idea? +# stim(t) -> node -> monitor? +def make_loop_graph(steps, constants): + def loop(initials, params, Tdeps): + def body(states, tdeps): + nexts = tuple( + s(*args) + for s, *args in zip(steps, constants, params, states, tdeps) + ) + return nexts, nexts + _, sol = jax.lax.scan(body, initials, Tdeps) + return sol + return loop +# would be more flexible if we had a dict or struct +# e.g. jax dataclasses + + + + + + + + + + + + + + + + +