Merge pull request #56 from ins-amu/dody
Add the dopa model
marmaduke woodman authored Feb 20, 2024
2 parents d46d648 + b470a6c commit fe025f0
Showing 9 changed files with 368 additions and 3 deletions.
29 changes: 29 additions & 0 deletions examples/dopa-sweep.py
@@ -0,0 +1,29 @@
import jax.numpy as jp
import vbjax as vb
from vbjax.app.dopa import sweep_node, sweep_network

# start with default parameters
params = vb.dopa_default_theta

# update params and sweep over Km and Vmax
params = params._replace(
    Eta=18.2,
    Km=jp.r_[100:200:32j],
    Vmax=jp.r_[1000:2000:32j],
)

# initial conditions
y0 = jp.array([0., -2.0, 0.0, 0.0, 0.0, 0.0])

# run sweep
end_time = 256.0
pkeys, ys = sweep_node(y0, params, T=end_time, cores=4)

# pkeys provides the names for the extra dims of ys result
print(pkeys, ys.shape)

# now similar for network sweep
n_nodes = 8
Ci, Ce, Cd = jp.zeros((3, n_nodes))
pkeys, ys = sweep_network(y0, params, Ci, Ce, Cd, T=end_time, cores=4)
print(pkeys, ys.shape)
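
The `ys` result carries one leading axis per swept parameter, named by `pkeys`; for the node sweep above that is `(32, 32, nt, 6)` with state order `(r, V, u, Sa, Sg, Dp)`, and the network sweep result carries an extra trailing `n_nodes` axis. A minimal sketch of summarizing the node sweep, assuming matplotlib is available and that `pkeys` entries stringify to parameter names:

```python
# a minimal sketch of summarizing the node sweep; assumes the shapes
# printed by the example, i.e. ys.shape == (32, 32, nt, 6)
import matplotlib.pyplot as plt

r = ys[..., 0]                                  # firing rate time series per grid point
r_ss = r[..., r.shape[-1] // 2:].mean(axis=-1)  # time average over the second half
plt.imshow(r_ss, origin='lower', aspect='auto')
plt.xlabel(f'{pkeys[1]} index')
plt.ylabel(f'{pkeys[0]} index')
plt.colorbar(label='mean r')
plt.savefig('dopa-sweep-r.png')
```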
9 changes: 8 additions & 1 deletion vbjax/__init__.py
@@ -27,6 +27,8 @@ def _use_many_cores():
    MPRState, MPRTheta, mpr_dfun, mpr_default_theta, mpr_r_positive,
    BOLDTheta, compute_bold_theta, bold_default_theta, bold_dfun,
    BVEPTheta, bvep_default_theta, bvep_dfun, DCMTheta, dcm_dfun,
    DopaTheta, dopa_dfun, dopa_default_theta, dopa_default_initial_state,
    dopa_net_dfun,
)
from .regmap import make_region_mapping
from .coupling import (
@@ -39,10 +41,15 @@ def _use_many_cores():
from .layers import make_dense_layers
from .diagnostics import shrinkage_zscore
from .embed import embed_neural_flow, embed_polynomial, embed_gradient, embed_autoregress
from .util import to_jax, to_np
from .util import to_jax, to_np, tuple_meshgrid, tuple_ravel, tuple_shard
from ._version import __version__

# some random setup for convenience
import jax
platform = jax.local_devices()[0].platform
is_gpu = platform == 'gpu'
is_cpu = platform == 'cpu'

from jax import random
from jax import numpy as np
key = random.PRNGKey(42)
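
These module-level flags let downstream code pick a strategy per JAX backend; the sweep functions added below branch on `vb.is_cpu` to shard parameter grids. A minimal sketch of how they can be used (nothing here beyond `vb.platform`, `vb.is_cpu`, and `vb.is_gpu` is part of the merged code):

```python
# a minimal sketch of the new convenience flags
import vbjax as vb

print('JAX backend:', vb.platform)            # 'cpu' or 'gpu'
batch_size = 32 if vb.is_gpu else 4           # hypothetical backend-dependent choice
```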
1 change: 1 addition & 0 deletions vbjax/app/dopa/__init__.py
@@ -0,0 +1 @@
from .sweep import sweep_network, sweep_node
84 changes: 84 additions & 0 deletions vbjax/app/dopa/sweep.py
@@ -0,0 +1,84 @@
# functions for running sweeps over the dopa model and collecting results

import jax
import jax.numpy as jp
import vbjax as vb


def sweep_node(init, params, T=10.0, dt=0.01, sigma=1e-3, seed=42, cores=4):
    "Run sweep for single dopa node on params matrix"

    # setup grid for parameters
    pkeys, pgrid = vb.tuple_meshgrid(params)
    pshape, pravel = vb.tuple_ravel(pgrid)

    # distribute params for cpu; doesn't work for now
    if vb.is_cpu:
        pravel = vb.tuple_shard(pravel, cores)

    # setup model
    f = lambda x, p: vb.dopa_dfun(x, (0, 0, 0), p)
    _, loop = vb.make_sde(dt, f, sigma)

    # assume same inits and noise for all params
    key = jax.random.PRNGKey(seed)
    nt = int(T / dt)
    dw = jax.random.normal(key, (nt, 6))

    # run sweep
    runv = jax.vmap(lambda p: loop(init, dw, p))
    run_params = jax.jit(jax.vmap(runv) if vb.is_cpu else runv)
    ys = run_params(pravel)

    # reshape the resulting time series
    # assert ys.shape == (pravel[0].size, nt, 6)
    ys = ys.reshape(pshape + (nt, 6))

    return pkeys, ys


def sweep_network(init, params, Ci, Ce, Cd,
                  T=10.0, dt=0.01, sigma=1e-3, seed=42, cores=4):
    "Run sweep for a network of dopa nodes on params matrix"

    # check & convert connectivities
    assert Ci.shape == Ce.shape == Cd.shape
    n_nodes = Ci.shape[0]
    Ci, Ce, Cd = [jp.array(_.astype('f')) for _ in (Ci, Ce, Cd)]

    # expand initial conditions if required
    if init.ndim == 1:
        init = jp.outer(init, jp.ones(n_nodes))

    # setup grid for parameters
    pkeys, pgrid = vb.tuple_meshgrid(params)
    pshape, pravel = vb.tuple_ravel(pgrid)

    # distribute params for cpu; doesn't work for now
    if vb.is_cpu:
        pravel = vb.tuple_shard(pravel, cores)

    # setup model
    _, loop = vb.make_sde(dt, vb.dopa_net_dfun, sigma)

    # assume same inits and noise for all params
    key = jax.random.PRNGKey(seed)
    nt = int(T / dt)
    dw = jax.random.normal(key, (nt, 6, n_nodes))

    # run sweep
    runv = jax.vmap(lambda p: loop(init, dw, (Ci, Ce, Cd, p)))
    run_params = jax.jit(jax.vmap(runv) if vb.is_cpu else runv)
    ys = run_params(pravel)

    # reshape the resulting time series
    # assert ys.shape == (pravel[0].size, nt, 6, n_nodes)
    ys = ys.reshape(pshape + (nt, 6, n_nodes))

    return pkeys, ys


if __name__ == '__main__':

    # TODO set up an argparser to have a cli e.g. on slurm
    pass
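
A hypothetical sketch of the CLI the TODO above has in mind; the flag names, swept parameters, and output handling are illustrative, not part of the merged code:

```python
# hypothetical CLI for running a sweep from slurm; argument names are illustrative
import argparse
import numpy as np
import jax.numpy as jp
import vbjax as vb
from vbjax.app.dopa import sweep_node

def main():
    parser = argparse.ArgumentParser(description='sweep the dopa model')
    parser.add_argument('--end-time', type=float, default=256.0)
    parser.add_argument('--cores', type=int, default=4)
    parser.add_argument('--out', default='dopa-sweep.npy')
    args = parser.parse_args()

    params = vb.dopa_default_theta._replace(
        Km=jp.r_[100:200:32j], Vmax=jp.r_[1000:2000:32j])
    y0 = jp.array(vb.dopa_default_initial_state)
    pkeys, ys = sweep_node(y0, params, T=args.end_time, cores=args.cores)
    np.save(args.out, np.asarray(ys))
    print(pkeys, ys.shape)

if __name__ == '__main__':
    main()
```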
Empty file added vbjax/app/vep/__init__.py
Empty file.
44 changes: 44 additions & 0 deletions vbjax/neural_mass.py
@@ -141,3 +141,47 @@ def dcm_dfun(x, u, p: DCMTheta):

# TODO other models
# TODO codim3 https://gist.github.com/maedoc/01cea5cad9c833c56349392ee7d9b627


DopaTheta = collections.namedtuple(
    typename='DopaTheta',
    field_names='a, b, c, ga, gg, Eta, Delta, Iext, Ea, Eg, Sja, Sjg, tauSa, tauSg, alpha, beta, ud, k, Km, Vmax, Bd, Ad, tau_Dp, wi, we, wd')

DopaState = collections.namedtuple(
    typename='DopaState',
    field_names='r V u Sa Sg Dp')

dopa_default_theta = DopaTheta(
    a=0.04, b=5., c=140., ga=12., gg=12.,
    Delta=1., Eta=18., Iext=0., Ea=0., Eg=-80., tauSa=5., tauSg=5., Sja=0.8, Sjg=1.2,
    ud=12., alpha=0.013, beta=.4, k=10e4, Vmax=1300., Km=150., Bd=0.2, Ad=1., tau_Dp=500.,
    wi=1.e-4, we=1.e-4, wd=1.e-4,
)

dopa_default_initial_state = DopaState(
    r=0.0, V=-2.0, u=0.0, Sa=0.0, Sg=0.0, Dp=0.0)

def dopa_dfun(y, cy, p: DopaTheta):
    "Adaptive QIF model with dopamine modulation."

    r, V, u, Sa, Sg, Dp = y
    c_inh, c_exc, c_dopa = cy
    a, b, c, ga, gg, Eta, Delta, Iext, Ea, Eg, Sja, Sjg, tauSa, tauSg, alpha, beta, ud, k, Vmax, Km, Bd, Ad, tau_Dp, *_ = p

    dr = 2. * a * r * V + b * r - ga * Sa * r - gg * Sg * r + (a * Delta) / np.pi
    dV = a * V**2 + b * V + c + Eta - (np.pi**2 * r**2) / a + (Ad * Dp + Bd) * ga * Sa * (Ea - V) + gg * Sg * (Eg - V) + Iext - u
    du = alpha * (beta * V - u) + ud * r
    dSa = -Sa / tauSa + Sja * c_exc
    dSg = -Sg / tauSg + Sjg * c_inh
    dDp = (k * c_dopa - Vmax * Dp / (Km + Dp)) / tau_Dp

    return np.array([dr, dV, du, dSa, dSg, dDp])

def dopa_net_dfun(y, p):
    "Canonical form for network of dopa nodes."
    Ci, Ce, Cd, node_params = p
    r = y[0]
    c_inh = node_params.wi * Ci @ r
    c_exc = node_params.we * Ce @ r
    c_dopa = node_params.wd * Cd @ r
    return dopa_dfun(y, (c_inh, c_exc, c_dopa), node_params)
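
For reference, a minimal single-node simulation of `dopa_dfun` follows the same pattern as the sweep code above: zero afferent coupling stands in for an isolated node, and the scalar `gfun` is the noise amplitude, as in the tests below. A sketch:

```python
# minimal single-node sketch; zero coupling tuple isolates the node
import jax
import jax.numpy as jp
import vbjax as vb

f = lambda y, p: vb.dopa_dfun(y, (0., 0., 0.), p)
_, loop = vb.make_sde(dt=0.01, dfun=f, gfun=1e-3)   # scalar gfun = noise amplitude
dw = jax.random.normal(jax.random.PRNGKey(0), (1000, 6))
y0 = jp.array(vb.dopa_default_initial_state)
ys = loop(y0, dw, vb.dopa_default_theta)            # (1000, 6) trajectory
```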
140 changes: 140 additions & 0 deletions vbjax/tests/test_models.py
@@ -0,0 +1,140 @@
import numpy as np
import jax.numpy as jp
import vbjax as vb


def true_dopa():

    a = 0.04
    b = 5.
    c = 140.
    ga = 12.
    gg = 12.
    Delta = 1.
    Eta = 18.
    Iext = 0.
    Ea = 0.
    Eg = -80.
    tauSa = 5.
    tauSg = 5.
    Sja = 0.8
    Sjg = 1.2
    ud = 12.
    alpha = 0.013
    beta = .4
    k = 10e4
    Vmax = 1300.
    Km = 150.
    Bd = 0.2
    Ad = 1.
    tau_Dp = 500.
    params = np.array([a, b, c, ga, gg, Eta, Delta, Iext, Ea, Eg, Sja, Sjg, tauSa, tauSg, alpha, beta, ud, k, Vmax, Km, Bd, Ad, tau_Dp])

    n_nodes = 8
    conn_inhibitor, conn_excitator, conn_dopamine = np.random.randn(3, n_nodes, n_nodes)**2

    dt = 0.01
    t0 = 0.0
    tf = 1.0
    ckk = 1e-4    # coupling scaling
    sigma = 1e-3  # noise amplitude; sigma=0 recovers the deterministic Heun method
    r0 = np.full(n_nodes, 0.1)
    V0 = np.full(n_nodes, -70.0)
    u0 = np.full(n_nodes, 0.0)
    Sa0 = np.full(n_nodes, 0.0)
    Sg0 = np.full(n_nodes, 0.0)
    Dp0 = np.full(n_nodes, 0.05)
    y0 = np.concatenate((r0, V0, u0, Sa0, Sg0, Dp0))

    def aQIFdopa(y, t, params, coupling_inhibitor, coupling_excitator, coupling_dopamine):
        r = y[0*n_nodes : 1*n_nodes]
        V = y[1*n_nodes : 2*n_nodes]
        u = y[2*n_nodes : 3*n_nodes]
        Sa = y[3*n_nodes : 4*n_nodes]
        Sg = y[4*n_nodes : 5*n_nodes]
        Dp = y[5*n_nodes : 6*n_nodes]
        a, b, c, ga, gg, Eta, Delta, Iext, Ea, Eg, Sja, Sjg, tauSa, tauSg, alpha, beta, ud, k, Vmax, Km, Bd, Ad, tau_Dp = params
        c_inh = coupling_inhibitor
        c_exc = coupling_excitator
        c_dopa = coupling_dopamine

        dydt = np.concatenate((
            2. * a * r * V + b * r - ga * Sa * r - gg * Sg * r + (a * Delta) / np.pi,
            a * V**2 + b * V + c + Eta - (np.pi**2 * r**2) / a + (Ad * Dp + Bd) * ga * Sa * (Ea - V) + gg * Sg * (Eg - V) + Iext - u,
            alpha * (beta * V - u) + ud * r,
            -Sa / tauSa + Sja * c_exc,
            -Sg / tauSg + Sjg * c_inh,
            (k * c_dopa - Vmax * Dp / (Km + Dp)) / tau_Dp
        )).flatten()

        return dydt

    def network(y, t, ckk, params):
        r = y[0*n_nodes : 1*n_nodes]

        aff_inhibitor = conn_inhibitor @ r * ckk
        aff_excitator = conn_excitator @ r * ckk
        aff_dopamine = conn_dopamine @ r * ckk

        dx = aQIFdopa(y, t, params, aff_inhibitor, aff_excitator, aff_dopamine)
        return dx

    def heun_SDE(network, y0, t0, t_max, dt, params, ckk, sigma):
        num_steps = int((t_max - t0) / dt)
        y_all = np.empty((num_steps, len(y0)))
        t_all = np.empty((num_steps, ))
        stochastic_matrix = np.random.normal(0, 1, (len(y0), num_steps))
        t = t0; i = 0
        t_all[i] = t0
        y_all[i, :] = y0
        y = y0
        for step in range(num_steps):
            dw = stochastic_matrix[:, step] * sigma * np.sqrt(dt)
            dy1 = network(y, t, ckk, params)
            ye = y + dt * dy1 + dw
            y = y + 0.5 * dt * (dy1 + network(ye, t + dt, ckk, params)) + dw
            t = t + dt
            t_all[i] = t
            y_all[i, :] = y
            i += 1
        return y_all, t_all, stochastic_matrix.T

    y1, t1, dw = heun_SDE(network, y0, t0, tf, dt, params, ckk, sigma)
    return y1, t1, dw, ckk, params, conn_inhibitor, conn_excitator, conn_dopamine, n_nodes, r0, V0, u0, Sa0, Sg0, Dp0, network, dt, sigma

def test_dopa():

    y1, t1, dw, ckk, params, conn_inhibitor, conn_excitator, conn_dopamine, n_nodes, r0, V0, u0, Sa0, Sg0, Dp0, network, dt, sigma = true_dopa()

    _, loop = vb.make_sde(dt=dt, dfun=vb.dopa_net_dfun, gfun=sigma)
    j_y0 = jp.array([r0, V0, u0, Sa0, Sg0, Dp0])
    j_params = vb.DopaTheta(*params, wi=ckk, we=ckk, wd=ckk)
    j_Ci, j_Ce, j_Cd = [jp.array(_) for _ in (conn_inhibitor, conn_excitator, conn_dopamine)]
    j_dw = jp.array(dw).reshape(-1, 6, n_nodes)
    assert j_dw.shape == (t1.size, 6, n_nodes)
    j_y2 = loop(j_y0, j_dw, (j_Ci, j_Ce, j_Cd, j_params))

    # compare derivatives
    for i in range(t1.size):
        dy1 = network(y1[i], t1[i], ckk, params).reshape((6, -1))
        dy2 = vb.dopa_net_dfun(y1[i].reshape((6, -1)), (j_Ci, j_Ce, j_Cd, j_params))
        for j in range(6):
            np.testing.assert_allclose(dy1[j], dy2[j], rtol=1e-5, atol=1e-5)

    # compare trajectories
    y1_ = y1.reshape((-1, 6, n_nodes))
    if False:
        # do plots
        import matplotlib.pyplot as pl
        for i in range(6):
            pl.subplot(3, 2, i + 1)
            pl.plot(t1, y1_[:, i], 'k', alpha=0.2)
            pl.plot(t1, j_y2[:, i], 'r', alpha=0.2)
            pl.grid(1)
            np.testing.assert_allclose(y1_[:, i], j_y2[:, i])
        pl.savefig('dopa.png', dpi=300)
    else:
        # don't bother with plots, just assert each variable is close
        for i in range(6):
            np.testing.assert_allclose(y1_[:, i], j_y2[:, i], rtol=1e-5, atol=1e-5)

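For reference, the pure-NumPy `heun_SDE` above (and the `vb.make_sde` loop it is tested against) implements the stochastic Heun scheme, with the same noise increment applied in both the predictor and corrector steps:

```math
\begin{aligned}
\Delta W_n &= \sigma \sqrt{\Delta t}\,\xi_n, \qquad \xi_n \sim \mathcal{N}(0, 1) \\
\tilde{y} &= y_n + \Delta t\, f(y_n, t_n) + \Delta W_n \\
y_{n+1} &= y_n + \tfrac{\Delta t}{2}\,\big(f(y_n, t_n) + f(\tilde{y}, t_{n+1})\big) + \Delta W_n
\end{aligned}
```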
2 changes: 1 addition & 1 deletion vbjax/tests/test_sparse.py
@@ -20,7 +20,7 @@ def _test_spmv(spmv, A, n):
    numpy.testing.assert_allclose(jb, nb, 1e-4, 1e-6)

    # now its gradient
    jax.test_util.check_grads(spmv, (jx,), order=1, modes=('rev',))
    jax.test_util.check_grads(spmv, (jx,), order=1, modes=('rev',), atol=0.02, rtol=0.002)


def test_csr_scipy():