Skip to content

Commit

Permalink
simplify hcp example
Browse files Browse the repository at this point in the history
  • Loading branch information
marmaduke woodman committed Feb 15, 2024
1 parent 628d034 commit d46d648
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 29 deletions.
48 changes: 25 additions & 23 deletions examples/delays-hcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,27 @@ def load(fname):

# setup delays
dt = 0.1
dh = vb.make_delay_helper( jp.log(W+1), L, dt=dt)
dh = vb.make_delay_helper( weights=jp.log(W+1), lengths=L, dt=dt)

# define parameters
from collections import namedtuple
Params = namedtuple('Params', 'dh theta k')
# define parameters for dfun
params = {
'dh': dh,
'theta': vb.mpr_default_theta,
'k': 0.01
}

# define our model
def dfun(buf, rv, t: int, p: Params):
crv = vb.delay_apply(p.dh, t, buf) # compute delay coupling
return vb.mpr_dfun(rv, p.k*crv, p.theta) # compute dynamics

def ensure_r_positive(rv, _):
r, v = rv
return jp.array([ r*(r>0), v ])
def dfun(buf, rv, t: int, p):
crv = vb.delay_apply(p['dh'], t, buf) # compute delay coupling
return vb.mpr_dfun(rv, p['k']*crv, p['theta']) # compute dynamics

# buf should cover all delays + noise for time steps to take
chunk_len = int(10 / dt) # 10 ms
buf = jp.zeros((dh.max_lag + chunk_len, 2, dh.n_from))
buf = buf.at[:dh.max_lag+1].add( jp.r_[0.1,-2.0].reshape(2,1) )

# compile model and enable continuations
_, run_chunk = vb.make_sdde(dt, dh.max_lag, dfun, gfun=1e-3, unroll=10, adhoc=ensure_r_positive)
_, run_chunk = vb.make_sdde(dt, dh.max_lag, dfun, gfun=1e-3, unroll=10, adhoc=vb.mpr_r_positive)
cont_chunk = vb.make_continuation(run_chunk, chunk_len, dh.max_lag, dh.n_from, n_svar=2, stochastic=True)

# setup time avg and bold monitors
Expand All @@ -62,25 +61,28 @@ def ensure_r_positive(rv, _):
bold_samp = vb.make_offline(bold_step, bold_samp)

# run chunk w/ monitors
def chunk_ta_bold(bufs, key):
p, buf, ta_buf, bold_buf = bufs
buf, rv = cont_chunk(buf, p, key)
ta_buf, ta = ta_samp(ta_buf, rv)
bold_buf, bold = bold_samp(bold_buf, rv)
return (p, buf, ta_buf, bold_buf), (ta, bold)
def chunk_ta_bold(sim, key):
sim['buf'], rv = cont_chunk(sim['buf'], sim['params'], key)
sim['ta_buf'], ta = ta_samp(sim['ta_buf'], rv)
sim['bold_buf'], bold = bold_samp(sim['bold_buf'], rv)
return sim, (ta, bold)

@jax.jit
def run_one_second(bufs, key):
def run_one_second(sim, key):
keys = jax.random.split(key, 100) # 100 * 10 ms
return jax.lax.scan(chunk_ta_bold, bufs, keys)
return jax.lax.scan(chunk_ta_bold, sim, keys)

# pack buffers and run it one minute
params = Params(dh, vb.mpr_default_theta, 0.01)
bufs = params, buf, ta_buf, bold_buf
sim = {
'params': params,
'buf': buf,
'ta_buf': ta_buf,
'bold_buf': bold_buf
}
ta, bold = [], []
keys = jax.random.split(jax.random.PRNGKey(42), 60)
for i, key in enumerate(tqdm.tqdm(keys)):
bufs, (ta_i, bold_i) = run_one_second(bufs, key)
sim, (ta_i, bold_i) = run_one_second(sim, key)
ta.append(ta_i)
bold.append(bold_i)
ta = jp.array(ta).reshape((-1, 2, 70))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies = [
"jaxlib",
"scipy",
"numpyro",
"jax-dataclasses",
]
classifiers = [
"Programming Language :: Python :: 3",
Expand All @@ -40,7 +41,6 @@ dev = [
"grip",
"python-lsp-server[all]",
"jedi-language-server",
"jax-dataclasses",
]

[project.urls]
Expand Down
2 changes: 1 addition & 1 deletion vbjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _use_many_cores():
from .shtlc import make_shtdiff
from .neural_mass import (
JRState, JRTheta, jr_dfun, jr_default_theta,
MPRState, MPRTheta, mpr_dfun, mpr_default_theta,
MPRState, MPRTheta, mpr_dfun, mpr_default_theta, mpr_r_positive,
BOLDTheta, compute_bold_theta, bold_default_theta, bold_dfun,
BVEPTheta, bvep_default_theta, bvep_dfun, DCMTheta, dcm_dfun,
)
Expand Down
13 changes: 9 additions & 4 deletions vbjax/coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,21 @@

DelayHelper = namedtuple('DelayHelper', 'Wt lags ix_lag_from max_lag n_to n_from')

def make_delay_helper(W, L, dt=0.1, v_c=10.0) -> DelayHelper:
n_to, n_from = W.shape
lags = jp.floor(L / v_c / dt).astype('i')
def make_delay_helper(weights, lengths, dt=0.1, v_c=10.0) -> DelayHelper:
"""Construct a helper with auxiliary variables for applying
delays to a buffer.
"""
n_to, n_from = weights.shape
lags = jp.floor(lengths / v_c / dt).astype('i')
ix_lag_from = jp.tile(jp.r_[:n_from], (n_to, 1))
max_lag = lags.max() + 1
Wt = W.T[:,:,None] # enable bcast for coupling vars
Wt = weights.T[:,:,None] # enable bcast for coupling vars
dh = DelayHelper(Wt, lags, ix_lag_from, max_lag, n_to, n_from)
return dh

def delay_apply(dh: DelayHelper, t, buf):
"""Apply delays to buffer `buf` at time `t`.
"""
return (dh.Wt * buf[t - dh.lags, :, dh.ix_lag_from]).sum(axis=1).T

# TODO impl sparse delay_apply
Expand Down
4 changes: 4 additions & 0 deletions vbjax/neural_mass.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ def mpr_dfun(ys, c, p):
r + p.I + I_c - (np.pi ** 2) * (r ** 2) * (p.tau ** 2))
])

def mpr_r_positive(rv, _):
r, v = rv
return np.array([ r*(r>0), v ])


BOLDTheta = collections.namedtuple(
typename='BOLDTheta',
Expand Down

0 comments on commit d46d648

Please sign in to comment.