diff --git a/vbjax/__init__.py b/vbjax/__init__.py
index 48bb135..0ace263 100644
--- a/vbjax/__init__.py
+++ b/vbjax/__init__.py
@@ -21,6 +21,7 @@ def _use_many_cores():
 
 # import stuff
 from .loops import make_sde, make_ode, make_dde, make_sdde, heun_step, make_continuation
+from .noise_generator import make_noise_generator, spectral_exponent
 from .shtlc import make_shtdiff
 from .neural_mass import (
     JRState, JRTheta, jr_dfun, jr_default_theta,
@@ -38,10 +39,15 @@ def _use_many_cores():
 from .sparse import make_spmv, csr_to_jax_bcoo, make_sg_spmv
 from .monitor import (
     make_timeavg, make_bold, make_gain, make_offline, make_cov, make_fc)
-from .layers import make_dense_layers
+from .layers import (make_dense_layers, create_degrees, create_masks,
+                     MaskedLayer, MaskedMLP, OutputLayer)
+from .ml_models import GaussianMADE, MAF
 from .diagnostics import shrinkage_zscore
 from .embed import embed_neural_flow, embed_polynomial, embed_gradient, embed_autoregress
 from .util import to_jax, to_np, tuple_meshgrid, tuple_ravel, tuple_shard
+from .train_utils import (eval_model, train_step, log_likelihood_MADE,
+                          log_likelihood_MAF, grad_func)
+
 from ._version import __version__
 
 # some random setup for convenience
diff --git a/vbjax/layers.py b/vbjax/layers.py
index 5de0cf5..41b84b0 100644
--- a/vbjax/layers.py
+++ b/vbjax/layers.py
@@ -1,4 +1,9 @@
 import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from typing import Callable, Sequence
+from jaxlib.xla_extension import ArrayImpl
+import jax.random as random
 
 def make_dense_layers(in_dim, latent_dims=[10], out_dim=None,
                       init_scl=0.1, extra_in=0,
@@ -23,4 +28,146 @@ def fwd(params, x):
             x = act_fn(weights[i] @ x + biases[i])
         return weights[-1] @ x + biases[-1]
-    return (weights, biases), fwd
\ No newline at end of file
+    return (weights, biases), fwd
+
+def create_degrees(key, n_inputs, n_hiddens, input_order, mode):
+    """
+    Generates a degree for each hidden and input unit. A unit with degree d can only receive input from units with
+    degree less than d.
+    :param key: a jax.random.PRNGKey used for random orderings and degree assignments
+    :param n_inputs: the number of inputs
+    :param n_hiddens: a list with the number of hidden units
+    :param input_order: the order of the inputs; can be 'random', 'sequential', or an array of an explicit order
+    :param mode: the strategy for assigning degrees to hidden nodes: can be 'random' or 'sequential'
+    :return: list of degrees
+    """
+
+    degrees = []
+
+    # create degrees for inputs
+    if isinstance(input_order, str):
+
+        if input_order == 'random':
+            degrees_0 = jax.random.permutation(key, jnp.arange(1, n_inputs + 1))
+
+        elif input_order == 'sequential':
+            degrees_0 = jnp.arange(1, n_inputs + 1)
+
+        else:
+            raise ValueError('invalid input order')
+
+    else:
+        input_order = jnp.array(input_order)
+        assert jnp.all(jnp.sort(input_order) == jnp.arange(1, n_inputs + 1)), 'invalid input order'
+        degrees_0 = input_order
+    degrees.append(degrees_0)
+
+    # create degrees for hiddens
+    if mode == 'random':
+        for N in n_hiddens:
+            min_prev_degree = min(jnp.min(degrees[-1]), n_inputs - 1)
+            degrees_l = jax.random.randint(key, shape=(N,), minval=min_prev_degree, maxval=n_inputs)
+            degrees.append(degrees_l)
+
+    elif mode == 'sequential':
+        for N in n_hiddens:
+            degrees_l = jnp.arange(N) % max(1, n_inputs - 1) + min(1, n_inputs - 1)
+            degrees.append(degrees_l)
+
+    else:
+        raise ValueError('invalid mode')
+
+    return degrees
+
+
+def create_masks(degrees):
+    """
+    Creates the binary masks that make the connectivity autoregressive.
+    :param degrees: a list of degrees for every layer
+    :return: list of all masks, as jnp arrays
+    """
+
+    Ms = []
+
+    for l, (d0, d1) in enumerate(zip(degrees[:-1], degrees[1:])):
+        M = d0[:, jnp.newaxis] <= d1
+        # M = theano.shared(M.astype(dtype), name='M' + str(l+1), borrow=True)
+        Ms.append(M)
+
+    Mmp = degrees[-1][:, jnp.newaxis] < degrees[0]
+    # Mmp = theano.shared(Mmp.astype(dtype), name='Mmp', borrow=True)
+
+    return Ms, Mmp
+
+
+class MaskedLayer(nn.Module):
+    features: int
+    mask: ArrayImpl
+    kernel_init: Callable = lambda key, shape, mask: random.normal(key, shape=shape)*1e-3*mask
+    bias_init: Callable = nn.initializers.zeros_init()
+
+    @nn.compact
+    def __call__(self, inputs):
+        kernel = self.param('kernel',
+                            self.kernel_init,  # initialization function
+                            (inputs.shape[-1], self.features), self.mask)  # shape info.
+        y = jnp.dot(inputs, kernel*self.mask)
+        bias = self.param('bias', self.bias_init, (self.features,))
+        y = y + bias
+        return y
+
+
+class OutputLayer(nn.Module):
+    out_dim: int
+    out_mask: ArrayImpl
+    kernel_init: Callable = lambda key, shape, out_mask: jax.random.normal(key, shape=shape)/jnp.sqrt(shape[0])*out_mask
+    bias_init: Callable = nn.initializers.zeros_init()
+
+    @nn.compact
+    def __call__(self, inputs):
+        kernel_m = self.param('kernel_m',
+                              self.kernel_init,  # initialization function
+                              (inputs.shape[-1], self.out_dim), self.out_mask)  # shape info.
+        kernel_logp = self.param('kernel_logp',
+                                 self.kernel_init,  # initialization function
+                                 (inputs.shape[-1], self.out_dim), self.out_mask)
+        m = jnp.dot(inputs, kernel_m*self.out_mask)
+        logp = jnp.dot(inputs, kernel_logp*self.out_mask)
+        bias_m = self.param('bias_m', self.bias_init, (self.out_dim,))
+        bias_logp = self.param('bias_logp', self.bias_init, (self.out_dim,))
+        m = m + bias_m
+        logp = logp + bias_logp
+        return m, logp
+
+
+class MaskedMLP(nn.Module):
+    n_hiddens: Sequence[int]
+    act_fn: Callable
+    masks: Sequence[ArrayImpl]
+
+    def setup(self):
+        self.hidden = [MaskedLayer(mask.shape[1], mask) for mask in self.masks]
+
+    def __call__(self, inputs):
+        x = inputs
+        for i, layer in enumerate(self.hidden):
+            x = layer(x)
+            # if i != len(self.hidden) - 1:
+            x = self.act_fn(x)
+        return x
+
+
+class FourierLayer(nn.Module):
+    kernel_init: Callable = lambda key, shape: jax.random.normal(key, shape=shape)/jnp.sqrt(shape[0])
+    bias_init: Callable = nn.initializers.zeros_init()
+
+    @nn.compact
+    def __call__(self, inputs):
+        kernel = self.param('kernel',
+                            self.kernel_init,  # initialization function
+                            (inputs.shape[-1], 1))  # shape info.
+
+        bias_m = self.param('bias', self.bias_init, (1,))
+        m = jnp.dot(inputs, kernel)
+        m = m + bias_m
+        return m
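Reviewer note: a minimal sketch of how the new masked primitives in layers.py compose; the sizes, data, and 'sequential' settings below are illustrative, not taken from the package.

import jax
import jax.numpy as jnp
from flax import linen as nn
from vbjax.layers import create_degrees, create_masks, MaskedMLP

key = jax.random.PRNGKey(0)
# degrees and masks for a 3-input MLP with two hidden layers
degrees = create_degrees(key, n_inputs=3, n_hiddens=[8, 8],
                         input_order='sequential', mode='sequential')
masks, out_mask = create_masks(degrees)

mlp = MaskedMLP(n_hiddens=[8, 8], act_fn=nn.relu, masks=masks)
x = jnp.ones((5, 3))                 # (batch, n_inputs)
params = mlp.init(key, x)
h = mlp.apply(params, x)             # (5, 8) features with autoregressive connectivity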
diff --git a/vbjax/ml_models.py b/vbjax/ml_models.py
new file mode 100644
index 0000000..a494aad
--- /dev/null
+++ b/vbjax/ml_models.py
@@ -0,0 +1,515 @@
+import jax.numpy as jnp
+from flax import linen as nn
+from typing import Callable, Sequence, Optional, Any
+from collections import namedtuple, defaultdict
+from jax._src.prng import PRNGKeyArrayImpl
+import jax.random as random
+from vbjax.layers import MaskedMLP, OutputLayer, create_degrees, create_masks
+import jax
+from flax.linen.initializers import zeros
+import tqdm
+from .neural_mass import bold_dfun, bold_default_theta, mpr_default_theta
+from flax.core.frozen_dict import freeze, unfreeze
+
+DelayHelper = namedtuple('DelayHelper', 'Wt lags ix_lag_from max_lag n_to n_from')
+
+class GaussianMADE(nn.Module):
+    key: PRNGKeyArrayImpl
+    in_dim: int
+    n_hiddens: Sequence[int]
+    act_fn: Callable
+    input_order: str = 'sequential'
+    mode: str = 'sequential'
+
+    def setup(self):
+        self.degrees = create_degrees(self.key, self.in_dim, self.n_hiddens, input_order=self.input_order, mode=self.mode)
+        self.masks, self.out_mask = create_masks(self.degrees)
+        self.mlp = MaskedMLP(self.n_hiddens, self.act_fn, self.masks)
+        self.output_layer = OutputLayer(self.in_dim, self.out_mask)
+
+    def __call__(self, inputs):
+        h = self.mlp(inputs)
+        m, logp = self.output_layer(h)
+        return m, logp
+
+    def gen(self, key, shape, u=None):
+        x = jnp.zeros(shape)
+        u = random.normal(key, shape) if u is None else u
+
+        for i in range(1, shape[1] + 1):
+            h = self.mlp(x)
+            m, logp = self.output_layer(h)
+            idx = jnp.argwhere(self.degrees[0] == i)[0, 0]
+            x = x.at[:, idx].set(m[:, idx] + jnp.exp(jnp.minimum(-0.5 * logp[:, idx], 10.0)) * u[:, idx])
+        return x
+
+
+class MAF(nn.Module):
+    key: PRNGKeyArrayImpl
+    in_dim: int
+    n_hiddens: Sequence[int]
+    act_fn: Callable
+    n_mades: int
+    input_order: Optional[Sequence] = None
+    mode: str = 'sequential'
+
+    def setup(self):
+        input_order = jnp.arange(1, self.in_dim + 1) if self.input_order is None else jnp.array(self.input_order)
+        self.mades = [GaussianMADE(random.fold_in(self.key, i), self.in_dim, self.n_hiddens, self.act_fn, input_order=input_order[::((-1)**(i%2))], mode=self.mode) for i in range(self.n_mades)]
+
+    def __call__(self, inputs):
+        u = inputs
+        logdet_dudx = 0
+        for made in self.mades:
+            ms, logp = made(u)
+            u = jnp.exp(0.5 * logp) * (u - ms)
+            logdet_dudx += 0.5 * jnp.sum(logp, axis=1)
+        return u, logdet_dudx
+
+    def gen(self, key, shape, u=None):
+        x = random.normal(key, shape) if u is None else u
+
+        for made in self.mades[::-1]:
+            x = made.gen(key, shape, x)
+        return x
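Reviewer note: a sketch of the intended calling convention for the two density models added above; the toy data, layer widths and sample counts are made up.

import jax
import jax.numpy as jnp
from flax import linen as nn
from vbjax.ml_models import GaussianMADE, MAF

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (128, 2))                  # toy 2-d data

made = GaussianMADE(key, in_dim=2, n_hiddens=[16, 16], act_fn=nn.relu)
made_params = made.init(key, x)
m, logp = made.apply(made_params, x)                  # per-dimension mean and log-precision

maf = MAF(key, in_dim=2, n_hiddens=[16, 16], act_fn=nn.relu, n_mades=3)
maf_params = maf.init(key, x)
u, logdet = maf.apply(maf_params, x)                  # base-space variables and log-det term
samples = maf.apply(maf_params, key, (64, 2), method=maf.gen)   # inverse pass, draws samples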
+ ] + initial_cond = fixed_initial_cond if fixed_initial_cond.any() else initial_cond + # horizon is set at the start of the buffer because rolled at the start of chunk + buf = buf.at[int(1/self.dt):,:,:self.nst_vars].add( initial_cond ) + return buf + + def chunk(self, module, buf, stimulus, key): + nh = int(self.tvb_p['dh'].max_lag) + buf = jnp.roll(buf, -int(1/self.dt), axis=0) + buf = self.noise_fill(buf, nh, key) + dWt = buf[nh+1:] # initialize carry noise filled + # jax.debug.print('stim {x}', x=stimulus) + # pass time count to the scanned integrator + t_count = jnp.tile(jnp.arange(int(1/self.dt))[...,None,None], (self.tvb_p['dh'].n_from, 1)) # (buf_len, regions, state_vars) + + stim = jnp.zeros(t_count.shape) + # stimulus = jnp.repeat(stimulus, int(1/self.dt))[...,None] + stim = stim.at[:,:,:].set(jnp.tile(stimulus[...,None], self.tvb_p['dh'].n_from)[...,None]) if self.training else stim.at[:,self.node_stim,:].set(stimulus[...,None]) + stim_t_count = jnp.c_[t_count, stim] + # jax.debug.print('stim_t_count {x}', x=stim_t_count.shape) + buf, rv = module(buf, dWt, stim_t_count) + return buf, rv + + def bold_monitor(self, module, bold_buf, rv, p=bold_default_theta): + t_count = jnp.tile(jnp.arange(rv.shape[0])[...,None, None,None], (4, self.tvb_p['dh'].n_from, 2)) # (buf_len, regions, state_vars) + bold_buf, bold = module(bold_buf, rv, t_count) + s, f, v, q = bold_buf + return bold_buf, p.v0 * (p.k1 * (1. - q) + p.k2 * (1. - q / v) + p.k3 * (1. - v)) + + + + @nn.compact + def __call__(self, inputs, g=0, sim_len=0, seed=42, initial_cond=jnp.array([]), mlp=True): + if inputs==None: + inputs = jnp.ones((1, self.nst_vars)) + region_pars = inputs + key = jax.random.PRNGKey(seed) + # buf = self.initialize_buffer(key, initial_cond) + + if mlp: + nmm = lambda x, xs, *args: self.dfun(self.dfun_pars, x, xs, *args) + tvb_dfun = self.fwd(nmm, region_pars, g) + else: + nmm = lambda x, xs, *args: self.dfun(x, xs, *args) + tvb_dfun = self.fwd(nmm, region_pars, g) + + # nmm = lambda x, xs, *args: self.dfun(self.dfun_pars, x, xs, *args) if mlp else self.dfun.__call__ + # tvb_dfun = self.fwd(nmm, region_pars, g) + + module = self.integrator(tvb_dfun, self.step, self.adhoc, self.dt, nh=int(self.tvb_p['dh'].max_lag)) + run_chunk = nn.scan(self.chunk.__call__) + run_sim = nn.scan(run_chunk) + + buf = self.initialize_buffer(key, initial_cond) + + chunksize = int((self.stimulus.shape[0]/sim_len)) if jnp.any(self.stimulus) else 1000 + stimulus = stimulus.reshape((sim_len, int(chunksize*self.dt), -1)) if jnp.any(self.stimulus) else jnp.zeros((sim_len, int(chunksize*self.dt), 1)) + # jax.debug.print('buf {x}', x=buf.shape) + buf, rv = run_sim(module, buf, stimulus, jax.random.split(key, (sim_len, int(chunksize*self.dt)))) + + # jax.debug.print('rv {x}', x=rv[0].shape) + # jax.debug.print('rv {x}', x=rv[1].shape) + # dummy_adhoc_bold = lambda x: x + # bold_dfun_p = lambda sfvq, x: bold_dfun(sfvq, x, bold_default_theta) + # module = self.integrator(bold_dfun_p, Heun_step, dummy_adhoc_bold, self.dt/10000, nh=int(self.tvb_p['dh'].max_lag), p=1) + # run_bold = nn.scan(self.bold_monitor.__call__) + + # bold_buf = jnp.ones((4, self.tvb_p['dh'].n_from, 1)) + # bold_buf = bold_buf.at[0].set(1.) 
+
+
+class TVB(nn.Module):
+    tvb_p: namedtuple
+    dfun: Callable
+    nst_vars: int
+    n_pars: int
+    dfun_pars: Optional[defaultdict] = jnp.array([])
+    dt: float = 0.1
+    integrator: Optional[Callable] = Integrator
+    step: Callable = Buffer_step
+    adhoc: Callable = lambda x: x
+    gfun: Callable = lambda x: x
+    stimulus: Optional[Sequence] = jnp.array([])
+    node_stim = 0
+    training: bool = False
+
+    def delay_apply(self, dh: DelayHelper, t, buf):
+        return (dh.Wt * buf[t - dh.lags, dh.ix_lag_from, :]).sum(axis=1)
+
+    def fwd(self, nmm, region_pars, g):
+        def tvb_dfun(buf, x, t, stim):
+            coupled_x = self.delay_apply(self.tvb_p['dh'], t, buf[..., :self.nst_vars])
+            coupling_term = coupled_x[:, :1]  # firing rate coupling only for QIF
+            # jax.debug.print('stim {x}', x=stim[:,1:])
+            # jax.debug.print('x {x} r_pars {y} coupling {z}', x=x[0], y=region_pars[0], z=coupling_term[0])
+            return nmm(x, region_pars, g*coupling_term + stim[:, 1:])
+        return tvb_dfun
+
+    def noise_fill(self, buf, nh, key):
+        dWt = jax.random.normal(key, buf[nh+1:].transpose(0,2,1).shape)
+        dWt = dWt.transpose(0,2,1)
+        noise = self.gfun(dWt, jnp.sqrt(self.dt))
+        buf = buf.at[nh+1:].set(noise)
+        return buf
+
+    def initialize_buffer(self, key, fixed_initial_cond):
+        dh = self.tvb_p['dh']
+        nh = int(dh.max_lag)
+        buf = jnp.zeros((nh + int(1/self.dt) + 1, dh.n_from, self.nst_vars))
+        initial_cond = jnp.c_[
+            jax.random.uniform(key=key, shape=(dh.n_from, 1), minval=0.1, maxval=2.0),
+            jax.random.uniform(key=key, shape=(dh.n_from, 1), minval=-2., maxval=1.5)
+        ]
+        initial_cond = fixed_initial_cond if fixed_initial_cond.any() else initial_cond
+        # horizon is set at the start of the buffer because it is rolled at the start of each chunk
+        buf = buf.at[int(1/self.dt):, :, :self.nst_vars].add(initial_cond)
+        return buf
+
+    def chunk(self, module, buf, stimulus, key):
+        nh = int(self.tvb_p['dh'].max_lag)
+        buf = jnp.roll(buf, -int(1/self.dt), axis=0)
+        buf = self.noise_fill(buf, nh, key)
+        dWt = buf[nh+1:]  # initialize carry noise filled
+        # jax.debug.print('stim {x}', x=stimulus)
+        # pass time count to the scanned integrator
+        t_count = jnp.tile(jnp.arange(int(1/self.dt))[..., None, None], (self.tvb_p['dh'].n_from, 1))  # (buf_len, regions, state_vars)
+
+        stim = jnp.zeros(t_count.shape)
+        # stimulus = jnp.repeat(stimulus, int(1/self.dt))[...,None]
+        stim = stim.at[:,:,:].set(jnp.tile(stimulus[..., None], self.tvb_p['dh'].n_from)[..., None]) if self.training else stim.at[:, self.node_stim, :].set(stimulus[..., None])
+        stim_t_count = jnp.c_[t_count, stim]
+        # jax.debug.print('stim_t_count {x}', x=stim_t_count.shape)
+        buf, rv = module(buf, dWt, stim_t_count)
+        return buf, rv
+
+    def bold_monitor(self, module, bold_buf, rv, p=bold_default_theta):
+        t_count = jnp.tile(jnp.arange(rv.shape[0])[..., None, None, None], (4, self.tvb_p['dh'].n_from, 2))  # (buf_len, regions, state_vars)
+        bold_buf, bold = module(bold_buf, rv, t_count)
+        s, f, v, q = bold_buf
+        return bold_buf, p.v0 * (p.k1 * (1. - q) + p.k2 * (1. - q / v) + p.k3 * (1. - v))
+
+    @nn.compact
+    def __call__(self, inputs, g=0, sim_len=0, seed=42, initial_cond=jnp.array([]), mlp=True):
+        if inputs is None:
+            inputs = jnp.ones((1, self.nst_vars))
+        region_pars = inputs
+        key = jax.random.PRNGKey(seed)
+        # buf = self.initialize_buffer(key, initial_cond)
+
+        if mlp:
+            nmm = lambda x, xs, *args: self.dfun(self.dfun_pars, x, xs, *args)
+            tvb_dfun = self.fwd(nmm, region_pars, g)
+        else:
+            nmm = lambda x, xs, *args: self.dfun(x, xs, *args)
+            tvb_dfun = self.fwd(nmm, region_pars, g)
+
+        # nmm = lambda x, xs, *args: self.dfun(self.dfun_pars, x, xs, *args) if mlp else self.dfun.__call__
+        # tvb_dfun = self.fwd(nmm, region_pars, g)
+
+        module = self.integrator(tvb_dfun, self.step, self.adhoc, self.dt, nh=int(self.tvb_p['dh'].max_lag))
+        run_chunk = nn.scan(self.chunk.__call__)
+        run_sim = nn.scan(run_chunk)
+
+        buf = self.initialize_buffer(key, initial_cond)
+
+        chunksize = int(self.stimulus.shape[0]/sim_len) if jnp.any(self.stimulus) else 1000
+        stimulus = self.stimulus.reshape((sim_len, int(chunksize*self.dt), -1)) if jnp.any(self.stimulus) else jnp.zeros((sim_len, int(chunksize*self.dt), 1))
+        # jax.debug.print('buf {x}', x=buf.shape)
+        buf, rv = run_sim(module, buf, stimulus, jax.random.split(key, (sim_len, int(chunksize*self.dt))))
+
+        # jax.debug.print('rv {x}', x=rv[0].shape)
+        # jax.debug.print('rv {x}', x=rv[1].shape)
+        # dummy_adhoc_bold = lambda x: x
+        # bold_dfun_p = lambda sfvq, x: bold_dfun(sfvq, x, bold_default_theta)
+        # module = self.integrator(bold_dfun_p, Heun_step, dummy_adhoc_bold, self.dt/10000, nh=int(self.tvb_p['dh'].max_lag), p=1)
+        # run_bold = nn.scan(self.bold_monitor.__call__)
+
+        # bold_buf = jnp.ones((4, self.tvb_p['dh'].n_from, 1))
+        # bold_buf = bold_buf.at[0].set(1.)
+
+        # bold_buf, bold = run_bold(module, bold_buf, rv[...,0].reshape((-1, int(20000/self.dt), self.tvb_p['dh'].n_from, 1)))
+        return rv
+        return rv.reshape(-1, self.tvb_p['dh'].n_from, self.nst_vars+self.n_pars)  #, bold
+
+
+class TVB_ODE(nn.Module):
+    out_dim: int
+    n_hiddens: Sequence[int]
+    act_fn: Callable
+
+    def setup(self):
+        self.Nodes = nn.vmap(
+            Simple_MLP,
+            in_axes=0, out_axes=0,
+            variable_axes={'params': 0},
+            split_rngs={'params': True},
+            methods=["__call__"])(out_dim=self.out_dim, n_hiddens=self.n_hiddens, act_fn=self.act_fn, coupled=True)
+
+    def __call__(self, x, xs, *args):
+        y = self.Nodes(x, xs, *args)
+        return y
+
+
+class Simple_MLP(nn.Module):
+    out_dim: int
+    n_hiddens: Sequence[int]
+    act_fn: Callable
+    # kernel_init: Callable = jax.nn.initializers.normal(1e-3)
+    kernel_init: Callable = jax.nn.initializers.he_normal()
+    coupled: bool = False
+    n_pars: int = 0
+    scaling_factor: float = .01
+
+    def setup(self):
+        self.layers = [nn.Dense(feat, kernel_init=self.kernel_init, bias_init=nn.initializers.zeros) for feat in self.n_hiddens]
+        self.output = nn.Dense(self.out_dim, kernel_init=self.kernel_init, bias_init=nn.initializers.zeros)
+
+    @nn.compact
+    def __call__(self, x, xs, *args):
+        # c = args[0]
+        # jax.debug.print('x[0] {x} xs[0] {y} args[0] {z}', x=x.shape, y=xs.shape, z=c.shape)
+        # jax.debug.print('x[0] {x} xs[0] {y}', x=x.shape, y=xs.shape)
+        # jax.debug.print('kernel {x}', x=self.layers[0])
+        x = jnp.c_[x, xs]
+        x = jnp.c_[x, args[0]] if self.coupled else x
+        for layer in self.layers:
+            x = layer(x)
+            x = self.act_fn(x)
+        x = self.output(x)
+        return x*self.scaling_factor
+
+
+class Simple_MLP_additive_c(nn.Module):
+    out_dim: int
+    n_hiddens: Sequence[int]
+    act_fn: Callable
+    kernel_init: Callable = jax.nn.initializers.he_normal()
+    coupled: bool = True
+
+    def setup(self):
+        self.layers = [nn.Dense(feat, kernel_init=self.kernel_init, bias_init=nn.initializers.zeros) for feat in self.n_hiddens]
+        self.output = nn.Dense(self.out_dim, kernel_init=self.kernel_init, bias_init=nn.initializers.zeros)
+
+    @nn.compact
+    def __call__(self, x, xs, *args, scaling_factor=.01):
+        c = args[0]
+        jax.debug.print('x[0] {x} xs[0] {y} args[0] {z}', x=x[0], y=xs[0], z=c[0])
+        x = jnp.c_[x, xs]
+        for layer in self.layers:
+            x = layer(x)
+            x = self.act_fn(x)
+        x = self.output(x)
+
+        x = x*scaling_factor
+        # jax.debug.print('x before {x}', x=x[0])
+        if self.coupled:
+            x = x + jnp.c_[jnp.zeros(args[0].shape), args[0]]
+        # jax.debug.print('x after {x}', x=x[0])
+        return x
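Reviewer note: a sketch of Simple_MLP used as an uncoupled learnable drift, mirroring how TVB_ODE vmaps it over nodes; the node count, state size and parameter count below are invented.

import jax
import jax.numpy as jnp
from flax import linen as nn
from vbjax.ml_models import Simple_MLP

key = jax.random.PRNGKey(0)
net = Simple_MLP(out_dim=2, n_hiddens=[32, 32], act_fn=nn.tanh)

x = jnp.ones((10, 2))                 # (nodes, state_vars)
xs = jnp.ones((10, 3))                # (nodes, per-node parameters), concatenated to the state
params = net.init(key, x, xs)
dx = net.apply(params, x, xs)         # (10, 2) derivatives, scaled by scaling_factor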
+
+
+class MontBrio(nn.Module):
+    dfun_pars: Optional[defaultdict] = mpr_default_theta
+    coupled: bool = False
+    scaling_factor: float = 1.
+
+    def setup(self):
+        self.eta = self.dfun_pars.eta
+        self.Delta = self.dfun_pars.Delta
+        self.tau = self.dfun_pars.tau
+        self.I = self.dfun_pars.I
+        self.J = self.dfun_pars.J
+        self.cr = self.dfun_pars.cr
+        self.cv = self.dfun_pars.cv
+
+    @nn.compact
+    def __call__(self, x, xs, *args):
+        # xs contains region parameters, not implemented yet
+        c = args[0] if self.coupled else jnp.zeros(x.shape)
+        # jax.debug.print('x[0] {x} c[0] {z}', x=x[0], z=c[0])
+        r, V = x[:, :1], x[:, 1:]
+        I_c = self.cr * c[:, :1]
+        r_dot = (1 / self.tau) * (self.Delta / (jnp.pi * self.tau) + 2 * r * V)
+        v_dot = (1 / self.tau) * (V ** 2 + self.eta + self.J * self.tau * r + self.I + I_c - (jnp.pi ** 2) * (r ** 2) * (self.tau ** 2))
+        return jnp.c_[r_dot, v_dot]*self.scaling_factor
+
+
+class NeuralOdeWrapper(nn.Module):
+    out_dim: int
+    n_hiddens: Sequence[int]
+    act_fn: Callable
+    extra_p: int
+    dt: Optional[float] = 1.
+    step: Optional[Callable] = Heun_step
+    integrator: Optional[Callable] = Integrator
+    dfun: Optional[Callable] = None
+    integrate: Optional[bool] = True
+    coupled: Optional[bool] = False
+    i_ext: Optional[bool] = False
+    stvar: Optional[int] = 0
+    adhoc: Optional[Callable] = lambda x: x
+
+    @nn.compact
+    def __call__(self, inputs):
+        (x, i_ext) = inputs if self.coupled else (inputs, None)
+        # dfun = self.dfun(self.out_dim, self.n_hiddens, self.act_fn, coupled=self.coupled)
+
+        if not self.integrate:
+            deriv = self.dfun(inputs[0], inputs[1])
+            # jax.debug.print('deriv {x}', x=deriv[0])
+            return deriv
+
+        in_ax = (0,0,0) if self.coupled else (0,0)
+        integrate = self.integrator(self.dfun.__call__, self.step, self.adhoc, self.dt, in_ax=in_ax, p=True)
+
+        # xs = jnp.zeros_like(x[:,:,:int(self.extra_p)])  # initialize carry
+        p = x[:, :, -self.extra_p:]  # initialize carry param filled
+        # jax.debug.print('p {x}', x=p.shape)
+        # i_ext = self.prepare_stimulus(x, i_ext, self.stvar)
+        t_count = jnp.tile(jnp.arange(x.shape[0])[..., None, None], (x.shape[1], x.shape[2]))  # (length, train_samples, state_vars)
+
+        x = x[0, :, :self.out_dim]
+        return integrate(x, p, t_count, i_ext)[1]
+
+
+class Encoder(nn.Module):
+    in_dim: int
+    latent_dim: int
+    act_fn: Callable
+    n_hiddens: Sequence[int] = None
+
+    def setup(self):
+        n_hiddens = self.n_hiddens[::-1] if self.n_hiddens else [self.in_dim, 4*self.latent_dim, 2*self.latent_dim, self.latent_dim][::-1]
+        self.layers = [nn.Dense(feat) for feat in n_hiddens]
+
+    def __call__(self, inputs):
+        x = inputs
+        for layer in self.layers:
+            x = layer(x)
+            x = self.act_fn(x)
+        return x
+
+
+class Decoder(nn.Module):
+    in_dim: int
+    latent_dim: int
+    act_fn: Callable
+    n_hiddens: Sequence[int] = None
+
+    def setup(self):
+        n_hiddens = self.n_hiddens[::-1] if self.n_hiddens else [self.in_dim, 4*self.latent_dim, 2*self.latent_dim, self.latent_dim][::-1]
+        self.layers = [nn.Dense(feat) for feat in n_hiddens]
+
+    def __call__(self, inputs):
+        x = inputs
+        for layer in self.layers[:-1]:
+            x = layer(x)
+            x = self.act_fn(x)
+        x = self.layers[-1](x)
+        return x
+
+
+class Autoencoder(nn.Module):
+    latent_dim: int
+    encoder_act_fn: Callable
+    decoder_act_fn: Callable
+    ode_act_fn: Callable
+    ode: bool = False
+    n_hiddens: Sequence[int] = None
+    kernel_init: Callable = jax.nn.initializers.normal(10e-3)
+    step: Optional[Callable] = Heun_step
+    integrator: Optional[Callable] = Integrator
+    network: Optional[Callable] = Simple_MLP
+    i_ext: Optional[bool] = True
+    ode_n_hiddens: Optional[Sequence] = None
+
+    def integrate(self, encoded, L):
+        xs = jnp.ones((encoded.shape[0], encoded.shape[1], L))  # initialize carry
+        dfun = self.network(encoded.shape[1], self.ode_n_hiddens, self.ode_act_fn)
+        integrator = self.integrator(dfun, self.step)
+        return integrator(encoded, xs)[1]
+
+    @nn.compact
+    def __call__(self, inputs):
+        L = inputs.shape[-1]
+
+        encoder = Encoder(inputs.shape[1], self.latent_dim, self.encoder_act_fn, self.n_hiddens)
+        encoded = encoder(inputs[:, :, 0]) if self.ode else encoder(inputs)
+
+        decoder = Decoder(inputs.shape[1], self.latent_dim, self.decoder_act_fn)
+        y = decoder(encoded)
+        if self.ode:
+            y = self.integrate(encoded, L)
+
+        return y
+
diff --git a/vbjax/noise_generator.py b/vbjax/noise_generator.py
new file mode 100644
index 0000000..dcfc661
--- /dev/null
+++ b/vbjax/noise_generator.py
@@ -0,0 +1,62 @@
+import jax.numpy as np
+import matplotlib.pyplot as plt
+import jax
+
+
+def white_noise(f):
+    return np.ones_like(f)
+
+def blue_noise(f):
+    return np.sqrt(f)
+
+def violet_noise(f):
+    return f
+
+def brownian_noise(f):
+    return 1/np.where(f == 0, float('inf'), f)
+
+def pink_noise(f):
+    return 1/np.where(f == 0, float('inf'), np.sqrt(f))
+
+def spectral_exponent(f, exponent=1):
+    return 1/np.where(f == 0, float('inf'), f**exponent)
+
+
+def make_noise_generator(psd = lambda f: 1):
+    """Generate noise with a desired power spectrum.
+
+    Parameters (of the returned generator)
+    ==========
+    key : jax.random.PRNGKey
+        random key used to draw the underlying white noise
+    shape : tuple
+        (n_nodes, t_steps)
+    sigma (optional kwarg) : float
+        Standard deviation, defaults to 1
+    exponent (optional kwarg) : float
+        Spectral exponent for 1/f noise, defaults to 1
+
+    Returns
+    =======
+    gen : function
+        gen(key, shape, sigma=1, **psd_kwargs) returning a (n_nodes, t_steps) array
+
+    Notes
+    =====
+    Example usage for white noise and 1/f noise
+
+    >>> import vbjax as vb
+    >>> import jax
+    >>> key = jax.random.PRNGKey(seed)
+
+    """
+    def gen(key, shape, sigma=1, **kwargs):
+        # sigma = kwargs['sigma'] or 1
+        X_white = np.fft.rfft(jax.random.normal(key, shape)*sigma)
+        S = psd(np.fft.rfftfreq(shape[1]), **kwargs)
+        S = S / np.sqrt(np.mean(S**2))
+        X_shaped = X_white * S
+        return np.fft.irfft(X_shaped)
+    return gen
+
+
+
+
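Reviewer note: a short usage sketch for the new noise generator, covering the two cases the docstring promises (flat spectrum and 1/f**exponent); node and step counts are arbitrary.

import jax
import vbjax as vb

key = jax.random.PRNGKey(0)
gen_white = vb.make_noise_generator()                        # default flat spectrum
gen_pink = vb.make_noise_generator(vb.spectral_exponent)     # 1/f**exponent spectrum

white = gen_white(key, (4, 1024), sigma=1.0)                 # (n_nodes, t_steps)
pink = gen_pink(key, (4, 1024), sigma=1.0, exponent=1.0)     # extra kwargs go to the psd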
diff --git a/vbjax/train.py b/vbjax/train.py
new file mode 100644
index 0000000..235ff81
--- /dev/null
+++ b/vbjax/train.py
@@ -0,0 +1,109 @@
+from flax.training import train_state
+import jax.numpy as jnp
+import jax.random as random
+from flax import linen as nn
+from vbjax.ml_models import GaussianMADE, MAF
+import optax, jax
+import matplotlib.pyplot as plt
+import numpy as np
+from vbjax.train_utils import eval_model, grad_func, log_likelihood_MAF
+
+def train_step(state, batch, loss_f):
+    def loss_fn(params):
+        output = state.apply_fn(
+            {'params': params}, batch,
+        )
+        loss = loss_f(output, batch).mean()
+        return loss
+    grads = jax.grad(loss_fn)(state.params)
+    return state.apply_gradients(grads=grads)
+
+
+def train_and_evaluate(model, X, config):
+    rng = random.key(0)
+    rng, key = random.split(rng)
+
+    init_data = jnp.ones((config['batch_size'], config['in_dim']), jnp.float32)
+    params = model.init(key, init_data)['params']
+
+    state = train_state.TrainState.create(
+        apply_fn=model.apply,
+        params=params,
+        tx=optax.adam(config['learning_rate']),
+    )
+
+    # print(state.params['mlp']['hidden_0']['kernel'][0,:])
+    batch_size = config['batch_size']
+    BATCHES = np.split(np.random.choice(np.arange(len(X)), (len(X)//batch_size)*batch_size), (len(X)//batch_size))
+
+    for i, epoch in enumerate(range(config['num_epochs'])):
+        if i % 10 == 0:
+            u, logp = state.apply_fn({'params': state.params}, X)
+            loss, u_distro = eval_model(
+                model, state.params, X, key, log_likelihood_MAF, shape=(20000, 2),
+            )
+            print('eval epoch: {}, loss: {}'.format(i + 1, loss.mean()))
+            # U = jnp.exp(.5 * logp) * (X - ms)
+            # L = -0.5 * (X.shape[1] * jnp.log(2 * jnp.pi) + jnp.sum(U ** 2 - logp, axis=1))
+            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14, 4))
+            ax1.scatter(np.array(X[:, 0]), np.array(X[:, 1]))
+            cs = ax2.scatter(np.array(u[:, 0]), np.array(u[:, 1]), c=loss, vmin=loss.min(), vmax=loss.max())
+            his = ax3.hist2d(np.array(u_distro[:, 0]), np.array(u_distro[:, 1]), bins=30, density=True, vmax=.03)
+            fig.colorbar(his[3], ax=ax3)
+            ax3.set_title('logp')
+            plt.colorbar(cs, ax=ax2)
+            fig.tight_layout()
+            plt.show()
+
+        for j, batch_i in enumerate(BATCHES):
+            batch = X[batch_i]
+            # batch = X[i*config['batch_size']:(i+1)*config['batch_size']]
+            rng, key = random.split(rng)
+            state = train_step(state, batch, log_likelihood_MAF)
+            loss, u_distro = eval_model(
+                model, state.params, X, key, log_likelihood_MAF, shape=(20000, 2),
+            )
+            print('eval epoch: {}, loss: {}'.format(i + 1, loss.mean()))
+            # print(jnp.linalg.norm(grad_func(state, batch)['mlp']['hidden_0']['kernel']))
+
+    return state, model
+
+
+config = {}
+config['learning_rate'] = .003
+config['in_dim'] = 2
+config['batch_size'] = 256
+config['num_epochs'] = 60
+config['n_hiddens'] = [5, 5]
+
+
+key1, key2 = random.split(random.key(0), 2)
+x2 = 4 * random.normal(key1, (config['batch_size']*100,))
+x1 = (.25*x2**2) + random.normal(key2, (config['batch_size']*100,))
+X = jnp.vstack([x2, x1]).T
+
+# key1, key2 = random.split(random.key(0), 2)
+# x2 = 3 * random.normal(key1, (config['batch_size']*100,))
+# x1 = x2 + random.normal(key2, (config['batch_size']*100,))
+# X = jnp.vstack([x2, x1]).T
+
+
+model = MAF(random.PRNGKey(42), 2, config['n_hiddens'], act_fn=nn.relu, n_mades=4)
+state_f, mdl = train_and_evaluate(model, X, config)
+
+
+def nnet_fn(X):
+    u, logdet = mdl.apply({'params': state_f.params}, X)
+    L = -0.5 * (X.shape[1] * jnp.log(2 * jnp.pi) + jnp.sum(u ** 2, axis=1)) + logdet
+    return u, L
+
+
+x = y = jnp.linspace(-30, 30, 500)
+xx, yy = jnp.meshgrid(x, y)
+X_test = jnp.vstack([xx.ravel(), yy.ravel()]).T
+u, L = nnet_fn(X_test)
+L = L.reshape(500, 500)
+plt.imshow(np.exp(L))
+plt.show()
\ No newline at end of file
diff --git a/vbjax/train_utils.py b/vbjax/train_utils.py
new file mode 100644
index 0000000..3fb732e
--- /dev/null
+++ b/vbjax/train_utils.py
@@ -0,0 +1,77 @@
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+
+def log_likelihood_MADE(ms, logp, x, *args):
+    u = jnp.exp(.5 * logp) * (x - ms)
+    return -(- 0.5 * (x.shape[1] * jnp.log(2 * jnp.pi) + jnp.sum(u ** 2 - logp, axis=1)))
+
+
+def log_likelihood_MAF(x, *arg):
+    u, logdet_dudx = x
+    return -(- 0.5 * u.shape[1] * jnp.log(2 * jnp.pi) - 0.5 * jnp.sum(u ** 2, axis=1) + logdet_dudx)
+
+def mse_ode(traj, x, *arg):
+    return jnp.mean((traj-x)**2)
+
+
+def eval_model_ode(model, params, batch, loss_fn=None, shape=None):
+    batch, i_ext = batch
+    def eval_model(model):
+        output = model(batch, i_ext)
+        loss = loss_fn(output, batch)
+        return loss
+    return nn.apply(eval_model, model)({'params': params})
+
+
+def eval_loss(model, params, batch, loss_fn=None, shape=None):
+    batch, i_ext = batch
+    def eval_model(model):
+        output = model(batch, i_ext)
+        loss = loss_fn(output, batch)
+        return loss
+    return nn.apply(eval_model, model)({'params': params})
+
+def eval_model(model, params, batch, key, loss_fn, shape=None):
+    shape = shape if shape else batch.shape
+    def eval_model(model):
+        output = model(batch)
+        loss = loss_fn(output, batch)  #.mean()
+        u_sample = model.gen(key, shape)
+        return loss, u_sample
+    return nn.apply(eval_model, model)({'params': params})
+
+
+def train_step(state, batch, loss_f):
+    batch, p = batch
+    def loss_fn(params):
+        output = state.apply_fn(
+            {'params': params}, batch, p,
+        )
+        loss = loss_f(output, batch)
+        return loss
+    grads = jax.grad(loss_fn)(state.params)
+    return state.apply_gradients(grads=grads)
+
+
+def grad_func(state, batch, loss_fn):
+    def loss_wrapper(params):
+        output = state.apply_fn(
+            {'params': params}, batch,
+        )
+        loss = loss_fn(output, batch)
+        return loss
+    grads = jax.grad(loss_wrapper)(state.params)
+    return grads
+
+def loss_t(traj, X):
+    X, iext = X
+    loss_bias = jnp.var(X, axis=2)*10 + 1
+    squared_loss_vec = jnp.square(X - traj).mean(axis=(2))
+    return (loss_bias*squared_loss_vec).sum()
+
+def loss_t_unpack(traj, X):
+    X, iext = X
+    loss_bias = jnp.var(X, axis=2)*40 + 1
+    squared_loss_vec = jnp.square(X - traj).mean(axis=(2))
+    return (loss_bias*squared_loss_vec).sum(axis=1)
\ No newline at end of file
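Reviewer note: a small sketch of the conventions the trajectory losses expect; the (time, nodes, state_vars) shapes and the external-input placeholder below are illustrative only.

import jax.numpy as jnp
from vbjax.train_utils import loss_t, loss_t_unpack

traj = jnp.zeros((50, 8, 2))                # model output: (time, nodes, state_vars)
X = jnp.ones((50, 8, 2))                    # target trajectories, same layout
iext = jnp.zeros((50, 8, 1))                # external input, carried but unused by the loss
total = loss_t(traj, (X, iext))             # scalar, variance-weighted squared error
per_step = loss_t_unpack(traj, (X, iext))   # (50,) weighted error per time step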