Skip to content

Commit

Permalink
(ノ☉ヮ⚆)ノ ⌒*:・゚✧ testing some uber loop ideas
Browse files Browse the repository at this point in the history
  • Loading branch information
marmaduke woodman committed Nov 8, 2023
1 parent 74dc9e7 commit 6eb964b
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 1 deletion.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ TODO

## Development

New ideas or even documenting tricks (like how Jax works) should go
into the test suite, and there are some ideas floating there before
making it into the library itself.
```
git clone https://github.com/ins-amu/vbjax
cd vbjax
Expand Down
2 changes: 1 addition & 1 deletion vbjax/tests/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_sde():


def test_ode():
f = lambda x,_: -x
f = lambda x, _: -x
dt = 0.1
_, run = vb.make_ode(dt, f)
x0 = np.r_[:32].astype('f')
Expand Down
73 changes: 73 additions & 0 deletions vbjax/tests/test_uloop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Tests and benchmarks for a more generic uber loop.
- constant args
- time-dep args
- jit-time constants vs args
- monitors
- jax.checkpoint
"""

import jax
import jax.numpy as np
import vbjax as vb


def _heun_step(x, dt, f, args):
f1 = f(x, args)
f2 = f(x + dt*f1, args)
return x + dt*0.5*(f1 + f2)

def _dfun1(xy, args):
SC, a, k, stim = args
x, y = xy
c = np.dot(SC, x)*k
dx = 5.0*(x - x*x*x/3 + y)
dy = 0.2*(a - x + stim + c)
return np.array([dx, dy])


def make_loop(dt, dfun, constants=()):
def loop(initial_state, parameters=(), t_parameters=()):
def step(state, t_parameters):
args = constants + parameters + t_parameters
next_state = _heun_step(state, dt, dfun, args)
return next_state, next_state
_, states = jax.lax.scan(step, initial_state, t_parameters)
return states
return loop


# can we go for the fully generic graph idea?
# stim(t) -> node -> monitor?
def make_loop_graph(steps, constants):
def loop(initials, params, Tdeps):
def body(states, tdeps):
nexts = tuple(
s(*args)
for s, *args in zip(steps, constants, params, states, tdeps)
)
return nexts, nexts
_, sol = jax.lax.scan(body, initials, Tdeps)
return sol
return loop
# would be more flexible if we had a dict or struct
# e.g. jax dataclasses

















0 comments on commit 6eb964b

Please sign in to comment.