Skip to content

Commit

Permalink
fix missing local & tree map warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
Marmaduke Woodman committed Apr 16, 2024
1 parent b4f27ec commit ac1ee82
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions vbjax/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import jax.numpy as np

zero = 0
tmap = jax.tree_util.tree_map

def heun_step(x, dfun, dt, *args, add=zero, adhoc=None, return_euler=False):
"""Use a Heun scheme to step state with a right hand sides dfun(.)
Expand All @@ -16,15 +17,15 @@ def heun_step(x, dfun, dt, *args, add=zero, adhoc=None, return_euler=False):
adhoc = adhoc or (lambda x,*args: x)
d1 = dfun(x, *args)
if add is not zero:
xi = jax.tree_map(lambda x,d,a: x + dt*d + a, x, d1, add)
xi = tmap(lambda x,d,a: x + dt*d + a, x, d1, add)
else:
xi = jax.tree_map(lambda x,d: x + dt*d, x, d1)
xi = tmap(lambda x,d: x + dt*d, x, d1)
xi = adhoc(xi, *args)
d2 = dfun(xi, *args)
if add is not zero:
nx = jax.tree_map(lambda x,d1,d2,a: x + dt*0.5*(d1 + d2) + a, x, d1, d2, add)
nx = tmap(lambda x,d1,d2,a: x + dt*0.5*(d1 + d2) + a, x, d1, d2, add)
else:
nx = jax.tree_map(lambda x,d1,d2: x + dt*0.5*(d1 + d2), x, d1, d2)
nx = tmap(lambda x,d1,d2: x + dt*0.5*(d1 + d2), x, d1, d2)
nx = adhoc(nx, *args)
if return_euler:
return xi, nx
Expand All @@ -37,13 +38,13 @@ def _compute_noise(gfun, x, p, sqrt_dt, z_t):
noise = g * sqrt_dt * z_t
except TypeError: # one of them is a pytree
if isinstance(g, float): # z_t is a pytree, g is a scalar
noise = jax.tree_map(lambda z: g * sqrt_dt * z, z_t)
noise = tmap(lambda z: g * sqrt_dt * z, z_t)
# otherwise, both must be pytrees and they must match
elif not jax.tree_util.tree_all(jax.tree_util.tree_structure(g) ==
jax.tree_util.tree_structure(z_t)):
raise ValueError("gfun and z_t must have the same pytree structure.")
else:
noise = jax.tree_map(lambda g,z: g * sqrt_dt * z, g, z_t)
noise = tmap(lambda g,z: g * sqrt_dt * z, g, z_t)
return noise

def make_sde(dt, dfun, gfun, adhoc=None, return_euler=False, unroll=10):
Expand Down Expand Up @@ -113,10 +114,15 @@ def step(x, z_t, p):
def loop(x0, zs, p):
def op(x, z):
x = step(x, z, p)
# XXX gets unwieldy, how to improve?
if return_euler:
ex, x = x
else:
ex = None
return x, (ex, x)
_, xs = jax.lax.scan(op, x0, zs, unroll=unroll)
if not return_euler:
_, xs = xs
return xs

return step, loop
Expand Down Expand Up @@ -244,28 +250,28 @@ def make_sdde(dt, nh, dfun, gfun, unroll=1, zero_delays=False, adhoc=None):

def step(buf_t, z_t, p):
buf, t = buf_t
x = jax.tree_map(lambda buf: buf[nh + t], buf)
x = tmap(lambda buf: buf[nh + t], buf)
noise = _compute_noise(gfun, x, p, sqrt_dt, z_t)
d1 = dfun(buf, x, nh + t, p)
xi = jax.tree_map(lambda x,d,n: x + dt*d + n, x, d1, noise)
xi = tmap(lambda x,d,n: x + dt*d + n, x, d1, noise)
xi = adhoc(xi, p)
if heun:
if zero_delays:
# severe performance hit (5x+)
buf = jax.tree_map(lambda buf, xi: buf.at[nh + t + 1].set(xi), buf, xi)
buf = tmap(lambda buf, xi: buf.at[nh + t + 1].set(xi), buf, xi)
d2 = dfun(buf, xi, nh + t + 1, p)
nx = jax.tree_map(lambda x,d1,d2,n: x + dt*0.5*(d1 + d2) + n, x, d1, d2, noise)
nx = tmap(lambda x,d1,d2,n: x + dt*0.5*(d1 + d2) + n, x, d1, d2, noise)
nx = adhoc(nx, p)
else:
nx = xi
buf = jax.tree_map(lambda buf, nx: buf.at[nh + t + 1].set(nx), buf, nx)
buf = tmap(lambda buf, nx: buf.at[nh + t + 1].set(nx), buf, nx)
return (buf, t+1), nx

@jax.jit
def loop(buf, p, t=0):
"xt is the buffer, zt is (ts, zs), p is parameters."
op = lambda xt, tz: step(xt, tz, p)
dWt = jax.tree_map(lambda b: b[nh:], buf) # buf[nh:]
dWt = tmap(lambda b: b[nh:], buf) # buf[nh:]
(buf, _), nxs = jax.lax.scan(op, (buf, t), dWt, unroll=unroll)
return buf, nxs

Expand Down

0 comments on commit ac1ee82

Please sign in to comment.