Skip to content

Commit

Permalink
[ ] adhoc bounds support
Browse files Browse the repository at this point in the history
  • Loading branch information
Marmaduke Woodman committed Dec 7, 2023
1 parent 1465109 commit cf44ef1
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 25 deletions.
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@

## Installation

Installs with `pip install "vbjax"`, but you can use the source,
Installs with `pip install "vbjax"`, but for the latest features,
you can use the source,
```bash
git clone https://github.com/ins-amu/vbjax
cd vbjax
pip install .[dev]
pip install -e ".[dev]"
```
You're encouraged to have the source handy to consult and change, but you can also just
```bash
pip install git+https://github.com/ins-amu/vbjax
```

The primary additional dependency of vbjax is
[JAX](github.com/google/jax), which itself depends only on
NumPy, SciPy & opt-einsum, so it should be safe to add to your
Expand Down
6 changes: 5 additions & 1 deletion examples/delays-hcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,12 @@ def dfun(buf, rv, t: int, p):
crv = (Wt * buf[t - lags, :, ix_lag_from]).sum(axis=1).T
return vb.mpr_dfun(rv, k*crv, mpr_theta)

def rgt0(rv, p):
r, v = rv
return jp.array([ r*(r>0), v ])

# compile dfun w/ heun sdde for running a chunk
_, run_chunk = vb.make_sdde(dt, max_lag, dfun, gfun=1e-3, unroll=4)
_, run_chunk = vb.make_sdde(dt, max_lag, dfun, gfun=1e-3, unroll=4, adhoc=rgt0)

# we'll run chunk of time this long
chunk_len = 1000
Expand Down
61 changes: 39 additions & 22 deletions vbjax/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@
import jax.numpy as np


def heun_step(x, dfun, dt, *args, add=0):
def heun_step(x, dfun, dt, *args, add=0, adhoc=None):
"""Use a Heun scheme to step state with a right hand sides dfun(.)
and additional forcing term add.
"""
adhoc = adhoc or (lambda x,*args: x)
d1 = dfun(x, *args)
d2 = dfun(x + dt*d1 + add, *args)
return x + dt*0.5*(d1 + d2) + add
xi = adhoc(x + dt*d1 + add, *args)
d2 = dfun(xi, *args)
nx = adhoc(x + dt*0.5*(d1 + d2) + add, *args)
return nx

def make_sde(dt, dfun, gfun):
def make_sde(dt, dfun, gfun, adhoc=None):
"""Use a stochastic Heun scheme to integrate autonomous stochastic
differential equations (SDEs).
Expand All @@ -31,6 +34,9 @@ def make_sde(dt, dfun, gfun):
of the stochastic differential equation. If a numerical value is
provided, this is used as a constant diffusion coefficient for additive
linear SDE.
adhoc : function or None
Function of the form `f(x, p)` that allows making adhoc corrections
to states after a step.
Returns
=======
Expand Down Expand Up @@ -67,7 +73,7 @@ def make_sde(dt, dfun, gfun):

def step(x, z_t, p):
noise = gfun(x, p) * sqrt_dt * z_t
return heun_step(x, dfun, dt, p, add=noise)
return heun_step(x, dfun, dt, p, add=noise, adhoc=adhoc)

@jax.jit
def loop(x0, zs, p):
Expand All @@ -79,7 +85,7 @@ def op(x, z):
return step, loop


def make_ode(dt, dfun):
def make_ode(dt, dfun, adhoc=None):
"""Use a Heun scheme to integrate autonomous ordinary differential
equations (ODEs).
Expand All @@ -90,6 +96,9 @@ def make_ode(dt, dfun):
dfun : function
Function of the form `dfun(x, p)` that computes derivatives of the
ordinary differential equations.
adhoc : function or None
Function of the form `f(x, p)` that allows making adhoc corrections
to states after a step.
Returns
=======
Expand All @@ -114,7 +123,7 @@ def make_ode(dt, dfun):
"""

def step(x, t, p):
return heun_step(x, dfun, dt, p)
return heun_step(x, dfun, dt, p, adhoc=adhoc)

@jax.jit
def loop(x0, ts, p):
Expand All @@ -126,12 +135,12 @@ def op(x, t):
return step, loop


def make_dde(dt, nh, dfun, unroll=10):
def make_dde(dt, nh, dfun, unroll=10, adhoc=None):
"Invokes make_sdde w/ gfun 0."
return make_sdde(dt, nh, dfun, 0, unroll)
return make_sdde(dt, nh, dfun, 0, unroll, adhoc=adhoc)


def make_sdde(dt, nh, dfun, gfun, unroll=1, zero_delays=False):
def make_sdde(dt, nh, dfun, gfun, unroll=1, zero_delays=False, adhoc=None):
"""Use a stochastic Heun scheme to integrate autonomous
stochastic delay differential equations (SDEs).
Expand All @@ -149,6 +158,9 @@ def make_sdde(dt, nh, dfun, gfun, unroll=1, zero_delays=False):
of the stochastic differential equation. If a numerical value is
provided, this is used as a constant diffusion coefficient for additive
linear SDE.
adhoc : function or None
Function of the form `f(x,p)` that allows making adhoc corrections after
each step.
Returns
=======
Expand Down Expand Up @@ -189,28 +201,33 @@ def make_sdde(dt, nh, dfun, gfun, unroll=1, zero_delays=False):
sig = gfun
gfun = lambda *_: sig

def step(xt_t, z_t, p):
xt, t = xt_t
x = xt[nh + t]
if adhoc is None:
adhoc = lambda x,p : x

def step(buf_t, z_t, p):
buf, t = buf_t
x = buf[nh + t]
noise = gfun(x, p) * sqrt_dt * z_t
d1 = dfun(xt, x, t, p)
xi = x + dt*d1 + noise
d1 = dfun(buf, x, nh + t, p)
xi = adhoc(x + dt*d1 + noise, p)
if heun:
if zero_delays:
# severe performance hit (5x+)
xt = xt.at[nh + t + 1].set(xi)
d2 = dfun(xt, xi, t+1, p)
nx = x + dt*0.5*(d1 + d2) + noise
buf = buf.at[nh + t + 1].set(xi)
d2 = dfun(buf, xi, nh + t + 1, p)
nx = adhoc(x + dt*0.5*(d1 + d2) + noise, p)
else:
nx = xi
xt = xt.at[nh + t + 1].set(nx)
return (xt,t+1), nx
buf = buf.at[nh + t + 1].set(nx)
return (buf, t+1), nx

@jax.jit
def loop(xt, p, t=0):
def loop(buf, p, t=0):
"xt is the buffer, zt is (ts, zs), p is parameters."
op = lambda xt, tz: step(xt, tz, p)
return jax.lax.scan(op, (xt,t), xt[nh:], unroll=unroll)[0]
dWt = buf[nh:]
(buf, _), nxs = jax.lax.scan(op, (buf, t), dWt, unroll=unroll)
return buf, nxs

return step, loop

0 comments on commit cf44ef1

Please sign in to comment.