Skip to content

Commit

Permalink
enable integrators to use pytree states
Browse files Browse the repository at this point in the history
  • Loading branch information
marmaduke woodman committed Feb 13, 2024
1 parent 86ad3a7 commit c98d0a2
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 13 deletions.
48 changes: 37 additions & 11 deletions vbjax/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import jax
import jax.tree_util
import jax.numpy as np


Expand All @@ -13,11 +14,35 @@ def heun_step(x, dfun, dt, *args, add=0, adhoc=None):
"""
adhoc = adhoc or (lambda x,*args: x)
d1 = dfun(x, *args)
xi = adhoc(x + dt*d1 + add, *args)
if add:
xi = jax.tree_map(lambda x,d,a: x + dt*d + a, x, d1, add)
else:
xi = jax.tree_map(lambda x,d: x + dt*d, x, d1)
xi = adhoc(xi, *args)
d2 = dfun(xi, *args)
nx = adhoc(x + dt*0.5*(d1 + d2) + add, *args)
if add:
nx = jax.tree_map(lambda x,d1,d2,a: x + dt*0.5*(d1 + d2) + a, x, d1, d2, add)
else:
nx = jax.tree_map(lambda x,d1,d2: x + dt*0.5*(d1 + d2), x, d1, d2)
nx = adhoc(nx, *args)
return nx


def _compute_noise(gfun, x, p, sqrt_dt, z_t):
g = gfun(x, p)
try: # maybe g & z_t are just arrays
noise = g * sqrt_dt * z_t
except TypeError: # one of them is a pytree
if isinstance(g, float): # z_t is a pytree, g is a scalar
noise = jax.tree_map(lambda z: g * sqrt_dt * z, z_t)
# otherwise, both must be pytrees and they must match
elif not jax.tree_util.tree_all(jax.tree_util.tree_structure(g) ==
jax.tree_util.tree_structure(z_t)):
raise ValueError("gfun and z_t must have the same pytree structure.")
else:
noise = jax.tree_map(lambda g,z: g * sqrt_dt * z, g, z_t)
return noise

def make_sde(dt, dfun, gfun, adhoc=None):
"""Use a stochastic Heun scheme to integrate autonomous stochastic
differential equations (SDEs).
Expand Down Expand Up @@ -72,7 +97,7 @@ def make_sde(dt, dfun, gfun, adhoc=None):
gfun = lambda *_: sig

def step(x, z_t, p):
noise = gfun(x, p) * sqrt_dt * z_t
noise = _compute_noise(gfun, x, p, sqrt_dt, z_t)
return heun_step(x, dfun, dt, p, add=noise, adhoc=adhoc)

@jax.jit
Expand Down Expand Up @@ -207,27 +232,28 @@ def make_sdde(dt, nh, dfun, gfun, unroll=1, zero_delays=False, adhoc=None):

def step(buf_t, z_t, p):
buf, t = buf_t
x = buf[nh + t]
noise = gfun(x, p) * sqrt_dt * z_t
x = jax.tree_map(lambda buf: buf[nh + t], buf)
noise = _compute_noise(gfun, x, p, sqrt_dt, z_t)
d1 = dfun(buf, x, nh + t, p)
xi = adhoc(x + dt*d1 + noise, p)
xi = jax.tree_map(lambda x,d,n: x + dt*d + n, x, d1, noise)
xi = adhoc(xi, p)
if heun:
if zero_delays:
# severe performance hit (5x+)
buf = buf.at[nh + t + 1].set(xi)
buf = jax.tree_map(lambda buf, xi: buf.at[nh + t + 1].set(xi), buf, xi)
d2 = dfun(buf, xi, nh + t + 1, p)
nx = adhoc(x + dt*0.5*(d1 + d2) + noise, p)
nx = jax.tree_map(lambda x,d1,d2,n: x + dt*0.5*(d1 + d2) + n, x, d1, d2, noise)
nx = adhoc(nx, p)
else:
nx = xi
buf = buf.at[nh + t + 1].set(nx)
buf = jax.tree_map(lambda buf, nx: buf.at[nh + t + 1].set(nx), buf, nx)
return (buf, t+1), nx

@jax.jit
def loop(buf, p, t=0):
"xt is the buffer, zt is (ts, zs), p is parameters."
op = lambda xt, tz: step(xt, tz, p)
print(nh)
dWt = buf[nh:]
dWt = jax.tree_map(lambda b: b[nh:], buf) # buf[nh:]
(buf, _), nxs = jax.lax.scan(op, (buf, t), dWt, unroll=unroll)
return buf, nxs

Expand Down
35 changes: 33 additions & 2 deletions vbjax/tests/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,42 @@ def dfun(xt, x, t, p):


def test_sdde():
import vbjax as vb, jax.numpy as np
def dfun(xt, x, t, p):
return -xt[t - 5]
_, sdde = vb.make_sdde(1.0, 5, dfun, 0.01)
sdde(vb.randn(20)+10, None)


# TODO theta method? https://gist.github.com/maedoc/c47acb9d346e31017e05324ffc4582c1
# TODO theta method? https://gist.github.com/maedoc/c47acb9d346e31017e05324ffc4582c1

def test_heun_pytree():
from collections import namedtuple
State = namedtuple('State', 'x y')
def f(x: State, p):
return State(x.y, -x.x)
dt = 0.1

# first test with ode
_, loop = vb.make_ode(dt, f)
x = np.ones(32)
y = np.zeros(32)
x0 = State(x, y)
xs = loop(x0, np.r_[:64], None)
assert xs.x.shape == (64, 32)
assert xs.y.shape == (64, 32)

# then test with sde
_, loop = vb.make_sde(dt, f, 1e-2)
z = State(x=vb.randn(100, 32), y=np.zeros((100, 32)))
xs = loop(x0, z, None)
assert xs.x.shape == z.x.shape

# now with sdde
def f(xs: State, x: State, t, p):
return State(xs.y[t-3], -x.x)
nh = 5
_, loop = vb.make_sdde(dt, nh, f, 1e-2)
_, xs = loop(xs, None)
assert xs.x.shape == z.x[:-nh].shape

# TODO test also w/ gfun generating pytree

0 comments on commit c98d0a2

Please sign in to comment.