From ac1ee824a4476d908f717825c8a960b5d29278df Mon Sep 17 00:00:00 2001 From: Marmaduke Woodman Date: Tue, 16 Apr 2024 12:33:45 +0200 Subject: [PATCH] fix missing local & tree map warnings --- vbjax/loops.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/vbjax/loops.py b/vbjax/loops.py index f1aeda5..edb1e52 100644 --- a/vbjax/loops.py +++ b/vbjax/loops.py @@ -8,6 +8,7 @@ import jax.numpy as np zero = 0 +tmap = jax.tree_util.tree_map def heun_step(x, dfun, dt, *args, add=zero, adhoc=None, return_euler=False): """Use a Heun scheme to step state with a right hand sides dfun(.) @@ -16,15 +17,15 @@ def heun_step(x, dfun, dt, *args, add=zero, adhoc=None, return_euler=False): adhoc = adhoc or (lambda x,*args: x) d1 = dfun(x, *args) if add is not zero: - xi = jax.tree_map(lambda x,d,a: x + dt*d + a, x, d1, add) + xi = tmap(lambda x,d,a: x + dt*d + a, x, d1, add) else: - xi = jax.tree_map(lambda x,d: x + dt*d, x, d1) + xi = tmap(lambda x,d: x + dt*d, x, d1) xi = adhoc(xi, *args) d2 = dfun(xi, *args) if add is not zero: - nx = jax.tree_map(lambda x,d1,d2,a: x + dt*0.5*(d1 + d2) + a, x, d1, d2, add) + nx = tmap(lambda x,d1,d2,a: x + dt*0.5*(d1 + d2) + a, x, d1, d2, add) else: - nx = jax.tree_map(lambda x,d1,d2: x + dt*0.5*(d1 + d2), x, d1, d2) + nx = tmap(lambda x,d1,d2: x + dt*0.5*(d1 + d2), x, d1, d2) nx = adhoc(nx, *args) if return_euler: return xi, nx @@ -37,13 +38,13 @@ def _compute_noise(gfun, x, p, sqrt_dt, z_t): noise = g * sqrt_dt * z_t except TypeError: # one of them is a pytree if isinstance(g, float): # z_t is a pytree, g is a scalar - noise = jax.tree_map(lambda z: g * sqrt_dt * z, z_t) + noise = tmap(lambda z: g * sqrt_dt * z, z_t) # otherwise, both must be pytrees and they must match elif not jax.tree_util.tree_all(jax.tree_util.tree_structure(g) == jax.tree_util.tree_structure(z_t)): raise ValueError("gfun and z_t must have the same pytree structure.") else: - noise = jax.tree_map(lambda g,z: g * sqrt_dt * z, g, z_t) + noise = tmap(lambda g,z: g * sqrt_dt * z, g, z_t) return noise def make_sde(dt, dfun, gfun, adhoc=None, return_euler=False, unroll=10): @@ -113,10 +114,15 @@ def step(x, z_t, p): def loop(x0, zs, p): def op(x, z): x = step(x, z, p) + # XXX gets unwieldy, how to improve? if return_euler: ex, x = x + else: + ex = None return x, (ex, x) _, xs = jax.lax.scan(op, x0, zs, unroll=unroll) + if not return_euler: + _, xs = xs return xs return step, loop @@ -244,28 +250,28 @@ def make_sdde(dt, nh, dfun, gfun, unroll=1, zero_delays=False, adhoc=None): def step(buf_t, z_t, p): buf, t = buf_t - x = jax.tree_map(lambda buf: buf[nh + t], buf) + x = tmap(lambda buf: buf[nh + t], buf) noise = _compute_noise(gfun, x, p, sqrt_dt, z_t) d1 = dfun(buf, x, nh + t, p) - xi = jax.tree_map(lambda x,d,n: x + dt*d + n, x, d1, noise) + xi = tmap(lambda x,d,n: x + dt*d + n, x, d1, noise) xi = adhoc(xi, p) if heun: if zero_delays: # severe performance hit (5x+) - buf = jax.tree_map(lambda buf, xi: buf.at[nh + t + 1].set(xi), buf, xi) + buf = tmap(lambda buf, xi: buf.at[nh + t + 1].set(xi), buf, xi) d2 = dfun(buf, xi, nh + t + 1, p) - nx = jax.tree_map(lambda x,d1,d2,n: x + dt*0.5*(d1 + d2) + n, x, d1, d2, noise) + nx = tmap(lambda x,d1,d2,n: x + dt*0.5*(d1 + d2) + n, x, d1, d2, noise) nx = adhoc(nx, p) else: nx = xi - buf = jax.tree_map(lambda buf, nx: buf.at[nh + t + 1].set(nx), buf, nx) + buf = tmap(lambda buf, nx: buf.at[nh + t + 1].set(nx), buf, nx) return (buf, t+1), nx @jax.jit def loop(buf, p, t=0): "xt is the buffer, zt is (ts, zs), p is parameters." op = lambda xt, tz: step(xt, tz, p) - dWt = jax.tree_map(lambda b: b[nh:], buf) # buf[nh:] + dWt = tmap(lambda b: b[nh:], buf) # buf[nh:] (buf, _), nxs = jax.lax.scan(op, (buf, t), dWt, unroll=unroll) return buf, nxs