Skip to content

Commit

Permalink
<"(((<3 refactor heun step
Browse files Browse the repository at this point in the history
  • Loading branch information
marmaduke woodman committed Nov 10, 2023
1 parent c7bc8ba commit a822d04
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 28 deletions.
3 changes: 2 additions & 1 deletion vbjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ def _use_many_cores():
n = mp.cpu_count()
value = '--xla_force_host_platform_device_count=%d' % n
os.environ['XLA_FLAGS'] = value
print(f'vbjax ( ˘ ³˘)ノ°゚º❍。 using {n} cores')
else:
sys.stderr.write('XLA_FLAGS already set\n')

_use_many_cores()

# import stuff
from .loops import make_sde, make_ode, make_dde, make_sdde
from .loops import make_sde, make_ode, make_dde, make_sdde, heun_step
from .shtlc import make_shtdiff
from .neural_mass import (
JRState, JRTheta, jr_dfun, jr_default_theta,
Expand Down
18 changes: 10 additions & 8 deletions vbjax/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
import jax.numpy as np


def heun_step(x, dfun, dt, *args, add=0):
"""Use a Heun scheme to step state with a right hand sides dfun(.)
and additional forcing term add.
"""
d1 = dfun(x, *args)
d2 = dfun(x + dt*d1 + add, *args)
return x + dt*0.5*(d1 + d2) + add

def make_sde(dt, dfun, gfun):
"""Use a stochastic Heun scheme to integrate autonomous stochastic
differential equations (SDEs).
Expand Down Expand Up @@ -59,10 +67,7 @@ def make_sde(dt, dfun, gfun):

def step(x, z_t, p):
noise = gfun(x, p) * sqrt_dt * z_t
d1 = dfun(x, p)
x1 = x + dt*d1 + noise
d2 = dfun(x1, p)
return x + dt*0.5*(d1 + d2) + noise
return heun_step(x, dfun, dt, p, add=noise)

@jax.jit
def loop(x0, zs, p):
Expand Down Expand Up @@ -109,10 +114,7 @@ def make_ode(dt, dfun):
"""

def step(x, t, p):
d1 = dfun(x, p)
x1 = x + dt*d1
d2 = dfun(x1, p)
return x + dt*0.5*(d1 + d2)
return heun_step(x, dfun, dt, p)

@jax.jit
def loop(x0, ts, p):
Expand Down
30 changes: 11 additions & 19 deletions vbjax/tests/test_jaxisms.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@ def loop(ns: dict, inps: np.ndarray):
assert ns2["int"] == 8


def _dict_step(sim, t):
x = sim['x']
z = sim['noise'] * jax.random.normal(sim['rng'], shape=x.shape)
f = lambda x: x - x**3/3 + sim['coupling'] * sim['weights']@x
x = vb.heun_step(x, f, sim['dt'])
sim['x'] = x
eeg = sim['eeg_gain'] @ x
return sim, eeg

def test_loop_dict2():

# a more complicated example
Expand All @@ -119,16 +128,8 @@ def test_loop_dict2():
'rng': jax.random.PRNGKey(42)
}

def step(sim, t):
x = sim['x']
z = sim['noise'] * jax.random.normal(sim['rng'], shape=x.shape)
x = x + sim['dt']*(x - x**3/3 + sim['coupling'] * sim['weights']@x) + z
sim['x'] = x
eeg = sim['eeg_gain'] @ x
return sim, eeg

def loop(sim, ts):
sim, eegs = jax.lax.scan(step, sim, ts)
sim, eegs = jax.lax.scan(_dict_step, sim, ts)
return eegs

params = {
Expand Down Expand Up @@ -162,17 +163,8 @@ def test_loop_dict_vmap():
'rng': jax.random.PRNGKey(42)
}


def step(sim, t):
x = sim['x']
z = sim['noise'] * jax.random.normal(sim['rng'], shape=x.shape)
x = x + sim['dt']*(x - x**3/3 + sim['coupling'] * sim['weights']@x) + z
sim['x'] = x
eeg = sim['eeg_gain'] @ x
return sim, eeg

def loop(sim, ts):
sim, eegs = jax.lax.scan(step, sim, ts)
sim, eegs = jax.lax.scan(_dict_step, sim, ts)
return eegs

# to vmap we need specific args to vmap over
Expand Down

0 comments on commit a822d04

Please sign in to comment.