diff --git a/vbjax/__init__.py b/vbjax/__init__.py index 48bb135..32b2c01 100644 --- a/vbjax/__init__.py +++ b/vbjax/__init__.py @@ -20,7 +20,8 @@ def _use_many_cores(): cores = _use_many_cores() # import stuff -from .loops import make_sde, make_ode, make_dde, make_sdde, heun_step, make_continuation +from .custom_loops import make_sde, make_ode, make_dde, make_sdde, heun_step, make_continuation +from .noise_generator import make_noise_generator, spectral_exponent from .shtlc import make_shtdiff from .neural_mass import ( JRState, JRTheta, jr_dfun, jr_default_theta, @@ -38,10 +39,15 @@ def _use_many_cores(): from .sparse import make_spmv, csr_to_jax_bcoo, make_sg_spmv from .monitor import ( make_timeavg, make_bold, make_gain, make_offline, make_cov, make_fc) -from .layers import make_dense_layers +from .layers import (make_dense_layers, create_degrees, create_masks, + MaskedLayer, MaskedMLP, OutputLayer) +from .ml_models import GaussianMADE, MAF from .diagnostics import shrinkage_zscore from .embed import embed_neural_flow, embed_polynomial, embed_gradient, embed_autoregress from .util import to_jax, to_np, tuple_meshgrid, tuple_ravel, tuple_shard +from .train_utils import (eval_model, train_step, log_likelihood_MADE, + log_likelihood_MAF, grad_func) + from ._version import __version__ # some random setup for convenience diff --git a/vbjax/layers.py b/vbjax/layers.py index 5de0cf5..4fa1a47 100644 --- a/vbjax/layers.py +++ b/vbjax/layers.py @@ -1,4 +1,9 @@ import jax +import jax.numpy as jnp +from flax import linen as nn +from typing import Callable, Sequence +from jaxlib.xla_extension import ArrayImpl +import jax.random as random def make_dense_layers(in_dim, latent_dims=[10], out_dim=None, init_scl=0.1, extra_in=0, @@ -23,4 +28,133 @@ def fwd(params, x): x = act_fn(weights[i] @ x + biases[i]) return weights[-1] @ x + biases[-1] - return (weights, biases), fwd \ No newline at end of file + return (weights, biases), fwd + +def create_degrees(key, n_inputs, n_hiddens, input_order, mode): + """ + Generates a degree for each hidden and input unit. A unit with degree d can only receive input from units with + degree less than d. 
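+    :param key: a jax.random.PRNGKey, used when input_order or mode is 'random'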
+    :param n_inputs: the number of inputs
+    :param n_hiddens: a list with the number of hidden units
+    :param input_order: the order of the inputs; can be 'random', 'sequential', or an array of an explicit order
+    :param mode: the strategy for assigning degrees to hidden nodes: can be 'random' or 'sequential'
+    :return: list of degrees
+    """
+
+    degrees = []
+
+    # create degrees for inputs
+    if isinstance(input_order, str):
+
+        if input_order == 'random':
+            degrees_0 = jnp.arange(1, n_inputs + 1)
+            degrees_0 = jax.random.permutation(key, degrees_0)
+
+        elif input_order == 'sequential':
+            degrees_0 = jnp.arange(1, n_inputs + 1)
+
+        else:
+            raise ValueError('invalid input order')
+
+    else:
+        input_order = jnp.array(input_order)
+        assert jnp.all(jnp.sort(input_order) == jnp.arange(1, n_inputs + 1)), 'invalid input order'
+        degrees_0 = input_order
+    degrees.append(degrees_0)
+
+    # create degrees for hiddens
+    if mode == 'random':
+        for N in n_hiddens:
+            key, subkey = jax.random.split(key)  # fresh subkey per layer so degrees differ across layers
+            min_prev_degree = min(jnp.min(degrees[-1]), n_inputs - 1)
+            degrees_l = jax.random.randint(subkey, shape=(N,), minval=min_prev_degree, maxval=n_inputs)
+            degrees.append(degrees_l)
+
+    elif mode == 'sequential':
+        for N in n_hiddens:
+            degrees_l = jnp.arange(N) % max(1, n_inputs - 1) + min(1, n_inputs - 1)
+            degrees.append(degrees_l)
+
+    else:
+        raise ValueError('invalid mode')
+
+    return degrees
+
+
+def create_masks(degrees):
+    """
+    Creates the binary masks that make the connectivity autoregressive.
+    :param degrees: a list of degrees for every layer
+    :return: (hidden masks, output mask) as boolean jax arrays
+    """
+
+    Ms = []
+
+    for d0, d1 in zip(degrees[:-1], degrees[1:]):
+        M = d0[:, jnp.newaxis] <= d1
+        Ms.append(M)
+
+    Mmp = degrees[-1][:, jnp.newaxis] < degrees[0]
+
+    return Ms, Mmp
+
+
+class MaskedLayer(nn.Module):
+    features: int
+    mask: ArrayImpl
+    kernel_init: Callable = lambda key, shape, mask: random.normal(key, shape=shape)*1e-3*mask
+    bias_init: Callable = nn.initializers.zeros_init()
+
+    @nn.compact
+    def __call__(self, inputs):
+        kernel = self.param('kernel',
+                            self.kernel_init,  # Initialization function
+                            (inputs.shape[-1], self.features), self.mask)  # shape info.
+        y = jnp.dot(inputs, kernel*self.mask)
+        bias = self.param('bias', self.bias_init, (self.features,))
+        y = y + bias
+        return y
+
+
+class OutputLayer(nn.Module):
+    out_dim: int
+    out_mask: ArrayImpl
+    kernel_init: Callable = lambda key, shape, out_mask: jax.random.normal(key, shape=shape)/jnp.sqrt(shape[0])*out_mask
+    bias_init: Callable = nn.initializers.zeros_init()
+
+    @nn.compact
+    def __call__(self, inputs):
+        kernel_m = self.param('kernel_m',
+                              self.kernel_init,  # Initialization function
+                              (inputs.shape[-1], self.out_dim), self.out_mask)  # shape info.
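+        # second masked head with the same shape and mask, parameterizing the conditional log-precision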
+        kernel_logp = self.param('kernel_logp',
+                                 self.kernel_init,  # Initialization function
+                                 (inputs.shape[-1], self.out_dim), self.out_mask)
+        m = jnp.dot(inputs, kernel_m*self.out_mask)
+        logp = jnp.dot(inputs, kernel_logp*self.out_mask)
+        bias_m = self.param('bias_m', self.bias_init, (self.out_dim,))
+        bias_logp = self.param('bias_logp', self.bias_init, (self.out_dim,))
+        m = m + bias_m
+        logp = logp + bias_logp
+        return m, logp
+
+
+class MaskedMLP(nn.Module):
+    n_hiddens: Sequence[int]
+    act_fn: Callable
+    masks: Sequence[ArrayImpl]
+
+    def setup(self):
+        self.hidden = [MaskedLayer(mask.shape[1], mask) for mask in self.masks]
+
+    def __call__(self, inputs):
+        x = inputs
+        for layer in self.hidden:
+            x = layer(x)
+            x = self.act_fn(x)
+        return x
diff --git a/vbjax/ml_models.py b/vbjax/ml_models.py
new file mode 100644
index 0000000..52ec6e4
--- /dev/null
+++ b/vbjax/ml_models.py
@@ -0,0 +1,288 @@
+import jax
+import jax.numpy as jnp
+import jax.random as random
+from flax import linen as nn
+from typing import Callable, Sequence, Optional
+from vbjax.layers import MaskedMLP, OutputLayer, create_degrees, create_masks
+
+class GaussianMADE(nn.Module):
+    key: jax.Array
+    in_dim: int
+    n_hiddens: Sequence[int]
+    act_fn: Callable
+    input_order: str = 'sequential'
+    mode: str = 'sequential'
+
+    def setup(self):
+        self.degrees = create_degrees(self.key, self.in_dim, self.n_hiddens, input_order=self.input_order, mode=self.mode)
+        self.masks, self.out_mask = create_masks(self.degrees)
+        self.mlp = MaskedMLP(self.n_hiddens, self.act_fn, self.masks)
+        self.output_layer = OutputLayer(self.in_dim, self.out_mask)
+
+    def __call__(self, inputs):
+        h = self.mlp(inputs)
+        m, logp = self.output_layer(h)
+        return m, logp
+
+    def gen(self, key, shape, u=None):
+        x = jnp.zeros(shape)
+        u = random.normal(key, shape) if u is None else u
+
+        # sample one input dimension at a time, following the autoregressive order
+        for i in range(1, shape[1] + 1):
+            h = self.mlp(x)
+            m, logp = self.output_layer(h)
+            idx = jnp.argwhere(self.degrees[0] == i)[0, 0]
+            x = x.at[:, idx].set(m[:, idx] + jnp.exp(jnp.minimum(-0.5 * logp[:, idx], 10.0)) * u[:, idx])
+        return x
+
+
+class MAF(nn.Module):
+    key: jax.Array
+    in_dim: int
+    n_hiddens: Sequence[int]
+    act_fn: Callable
+    n_mades: int
+    input_order: Optional[Sequence] = None
+    mode: str = 'sequential'
+
+    def setup(self):
+        input_order = jnp.arange(1, self.in_dim + 1) if self.input_order is None else jnp.array(self.input_order)
+        keys = random.split(self.key, self.n_mades)
+        # alternate the input order between successive MADEs
+        self.mades = [GaussianMADE(keys[i], self.in_dim, self.n_hiddens, self.act_fn,
+                                   input_order=input_order[::((-1)**(i % 2))], mode=self.mode)
+                      for i in range(self.n_mades)]
+
+    def __call__(self, inputs):
+        u = inputs
+        logdet_dudx = 0
+        for made in self.mades:
+            ms, logp = made(u)
+            u = jnp.exp(0.5 * logp) * (u - ms)
+            logdet_dudx += 0.5 * jnp.sum(logp, axis=1)
+        return u, logdet_dudx
+
+    def gen(self, key, shape, u=None):
+        x = random.normal(key, shape) if u is None else u
+
+        for made in self.mades[::-1]:
+            x = made.gen(key, shape, x)
+        return x
+
+
+class Heun_step(nn.Module):
+    dfun: Callable
+    stvar: Optional[int] = 0
+    external_i: Optional[bool] = False
+    adhoc: Optional[Callable] = None
+
+    @nn.compact
+    def __call__(self, x, xs, p, i_ext):
+        tmap = jax.tree_util.tree_map
+        dt = 1.
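+        # NOTE: the step size is hard-coded to dt = 1, i.e. the learned dynamics are per integration step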
+ d1 = self.dfun((x, p), i_ext) + xi = tmap(lambda x,d: x + dt*d, x, d1) + # xi = tmap(lambda x,d,a: x + dt*d + a, x, d1, stimulus) + + d2 = self.dfun((xi, p), i_ext) + nx = tmap(lambda x, d1,d2: x + dt*0.5*(d1 + d2), x, d1, d2) + # nx = tmap(lambda x, d1,d2,a: x + dt*0.5*(d1 + d2) + a, x, d1, d2, stimulus) + return nx, x + + +class Integrator(nn.Module): + dfun: Callable + step: Callable + stvar: Optional[int] = 0 + adhoc: Optional[Callable] = None + + @nn.compact + def __call__(self, c, xs, p=None, external_i=None): + STEP = nn.scan(self.step, + variable_broadcast="params", + split_rngs={"params": False}, + in_axes=(2, 2, 2), + out_axes=2 + ) + return STEP(self.dfun, self.stvar, self.adhoc)(c, xs, p, external_i) + + +class MLP_Ode(nn.Module): + out_dim: int + n_hiddens: Sequence[int] + act_fn: Callable + step: Callable + kernel_init: Callable = jax.nn.initializers.normal(1e-6) + bias_init: Callable = jax.nn.initializers.normal(1e-6) + integrate: Optional[bool] = True + i_ext: Optional[bool] = True + stvar: Optional[int] = 0 + p_mix: Optional[bool] = False + + def setup(self): + self.p_layers = [nn.Dense(feat, kernel_init=self.kernel_init) for feat in self.n_hiddens[0]] if self.p_mix else None + dims = self.n_hiddens[1:] if self.p_mix else self.n_hiddens + self.layers = [nn.Dense(feat, kernel_init=self.kernel_init, bias_init=self.bias_init) for feat in dims] + self.output = nn.Dense(self.out_dim, kernel_init=self.kernel_init, bias_init=self.bias_init) + + + def fwd(self, x, i_ext): + x, p = x + if self.p_mix: + for layer in self.p_layers[:-1]: + p = layer(p) + p = self.act_fn(p) + p = self.p_layers[-1](p) + x = jnp.c_[x, p, i_ext] if self.i_ext else x + + for layer in self.layers: + x = layer(x) + x = self.act_fn(x) + x = self.output(x) + return x + + def prepare_stimulus(self, x, external_i, stvar): + stimulus = jnp.zeros(x.shape) + # stimulus = stimulus.at[:,stvar,:].set(external_i) if isinstance(external_i, jnp.ndarray) else stimulus + return stimulus + + @nn.compact + def __call__(self, inputs): + if not self.integrate: + (x, p), i_ext = inputs + deriv = self.fwd((x, p), i_ext) + return deriv + + (x, p), i_ext = inputs if self.i_ext else (inputs, None) + + integrate = Integrator(self.fwd, self.step) + # initialize carry + xs = jnp.zeros_like(x) + # stimulus = self.prepare_stimulus(x, i_ext, self.stvar) + x = x[...,0] + traj = integrate(x, xs, p, i_ext) + + return traj[1] + + + +class Simple_MLP(nn.Module): + out_dim: int + n_hiddens: Sequence[int] + act_fn: Callable + kernel_init: Callable = jax.random.normal + extra_p: bool = False + + @nn.compact + def __call__(self, x, i_ext): + layers = [nn.Dense(feat, kernel_init=self.kernel_init*1e-6, bias_init=self.kernel_init*1e-6) for feat in self.n_hiddens] + output = nn.Dense(self.out_dim, kernel_init=self.kernel_init*1e-6, bias_init=self.kernel_init*1e-6) + x = jnp.c_[x[0], x[1]] if self.extra_p else x[0] + for layer in layers: + x = layer(x) + x = self.act_fn(x) + x = output(x) + return x + + +class NeuralOdeWrapper(nn.Module): + out_dim: int + n_hiddens: Sequence[int] + act_fn: Callable + extra_p: int + step: Optional[Callable] = Heun_step + integrator: Optional[Callable] = Integrator + network: Optional[Callable] = Simple_MLP + kernel_init: Callable = jax.nn.initializers.normal(10e-3) + integrate: Optional[bool] = True + i_ext: Optional[bool] = True + stvar: Optional[int] = 0 + + @nn.compact + def __call__(self, inputs): + x, p, i_ext = inputs + dfun = self.network(inputs, self.n_hiddens, self.act_fn, i_ext, extra_p=self.extra_p) + if 
not self.integrate: + deriv = dfun(inputs, i_ext) + return deriv + + integrate = self.integrator(dfun, self.step) + xs = jnp.zeros_like(x) # initialize carry + # i_ext = self.prepare_stimulus(x, i_ext, self.stvar) + x = x[...,0] + return integrate(x, xs, p, i_ext)[1] + + + +class Encoder(nn.Module): + in_dim: int + latent_dim: int + act_fn: Callable + n_hiddens: Sequence[int] = None + + def setup(self, n_hiddens: Optional[Sequence] = None): + n_hiddens = n_hiddens[::-1] if n_hiddens else [self.in_dim, 4*self.latent_dim, 2*self.latent_dim, self.latent_dim][::-1] + self.layers = [nn.Dense(feat) for feat in n_hiddens] + + def __call__(self, inputs): + x = inputs + for layer in self.layers: + x = layer(x) + x = self.act_fn(x) + return x + + +class Decoder(nn.Module): + in_dim: int + latent_dim: int + act_fn: Callable + n_hiddens: Sequence[int] = None + + def setup(self, n_hiddens: Optional[Sequence] = None): + n_hiddens = n_hiddens[::-1] if n_hiddens else [self.in_dim, 4*self.latent_dim, 2*self.latent_dim, self.latent_dim][::-1] + self.layers = [nn.Dense(feat) for feat in n_hiddens] + + def __call__(self, inputs): + x = inputs + for layer in self.layers[:-1]: + x = layer(x) + x = self.act_fn(x) + x = self.layers[-1](x) + return x + +class Autoencoder(nn.Module): + latent_dim: int + encoder_act_fn: Callable + decoder_act_fn: Callable + ode_act_fn: Callable + ode: bool = False + n_hiddens: Sequence[int] = None + kernel_init: Callable = jax.nn.initializers.normal(10e-3) + step: Optional[Callable] = Heun_step + integrator: Optional[Callable] = Integrator + network: Optional[Callable] = Simple_MLP + i_ext: Optional[bool] = True + ode_n_hiddens: Optional[Sequence] = None + + def integrate(self, encoded, L): + xs = jnp.ones((encoded.shape[0], encoded.shape[1], L)) # initialize carry + dfun = self.network(encoded.shape[1], self.ode_n_hiddens, self.ode_act_fn) + integrator = self.integrator(dfun, self.step) + return integrator(encoded, xs)[1] + + @nn.compact + def __call__(self, inputs): + L = inputs.shape[-1] + + encoder = Encoder(inputs.shape[1], self.latent_dim, self.encoder_act_fn, self.n_hiddens) + encoded = encoder(inputs[:,:,0]) if self.ode else encoder(inputs) # (N, ) + + decoder = Decoder(inputs.shape[1], self.latent_dim, self.decoder_act_fn) + y = decoder(encoded) + if self.ode: + y = self.integrate(encoded, L) + + return y + diff --git a/vbjax/noise_generator.py b/vbjax/noise_generator.py new file mode 100644 index 0000000..dcfc661 --- /dev/null +++ b/vbjax/noise_generator.py @@ -0,0 +1,62 @@ +import jax.numpy as np +import matplotlib.pyplot as plt +import jax + + +def white_noise(f): + return f + +def blue_noise(f): + return np.sqrt(f) + +def violet_noise(f): + return f + +def brownian_noise(f): + return 1/np.where(f == 0, float('inf'), f) + +def pink_noise(f): + return 1/np.where(f == 0, float('inf'), np.sqrt(f)) + +def spectral_exponent(f, exponent=1): + return 1/np.where(f == 0, float('inf'), f**exponent) + + +def make_noise_generator(psd = lambda f: 1): + """Generate noise given desired power spectrum + + Parameters + ========== + shape : tuple + (n_nodes, t_steps) + key : function + jax.random.PRNGKey() + sigma (optional kwarg): float + Standard deviation, defaults to 1 + exponent (optional kwarg) : float + Spectral exponent for 1/f noise, defaults to 1 + + Returns + ======= + noise_stream : (n_nodes, t_steps) array + + Notes + ===== + Example usage for white noise and 1/f noise + + >>> import vbjax as vb, import jax + >>> key = jax.random.PRNGKey(seed) + + """ + def gen(key, shape, 
sigma=1, **kwargs): + # sigma = kwargs['sigma'] or 1 + X_white = np.fft.rfft(jax.random.normal(key, shape)*sigma) + S = psd(np.fft.rfftfreq(shape[1]), **kwargs) + S = S / np.sqrt(np.mean(S**2)) + X_shaped = X_white * S + return np.fft.irfft(X_shaped) + return gen + + + + diff --git a/vbjax/train.py b/vbjax/train.py new file mode 100644 index 0000000..b98e65e --- /dev/null +++ b/vbjax/train.py @@ -0,0 +1,82 @@ +from flax.training import train_state +import jax.numpy as jnp +import jax.random as random +from flax import linen as nn +from vbjax.ml_models import GaussianMADE, MAF +import optax +import matplotlib.pyplot as plt +import numpy as np +from vbjax.train_utils import eval_model, train_step, grad_func, log_likelihood_MADE + +def train_and_evaluate(model, X, config): + rng = random.key(0) + rng, key = random.split(rng) + + init_data = jnp.ones((config['batch_size'], config['in_dim']), jnp.float32) + params = model.init(key, init_data)['params'] + + state = train_state.TrainState.create( + apply_fn=model.apply, + params=params, + tx=optax.adam(config['learning_rate']), + ) + + print(state.params['mlp']['hidden_0']['kernel'][0,:]) + + for i, epoch in enumerate(range(config['num_epochs'])): + if i%10==0: + ms, logp = model.apply({'params': state.params}, X) + loss, u_distro = eval_model( + model, state.params, X, key, log_likelihood_MADE, shape=(20000,2), + ) + print('eval epoch: {}, loss: {}'.format(i + 1, loss)) + U = jnp.exp(.5 * logp) * (X - ms) + L = -0.5 * (X.shape[0] * jnp.log(2 * jnp.pi) + jnp.sum(U ** 2 - logp, axis=1)) + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14,4)) + ax1.scatter(np.array(X[:,0]), np.array(X[:,1])) + cs = ax2.scatter(np.array(U[:,0]), np.array(U[:,1]), c=L, vmin=L.min(), vmax=L.max()) + his = ax3.hist2d(np.array(u_distro[:,0]), np.array(u_distro[:,1]), bins=30, density=True, vmax=.03) + fig.colorbar(his[3], ax=ax3) + ax3.set_title('logp') + plt.colorbar(cs, ax=ax2) + fig.tight_layout() + plt.show() + for _ in range(100): + batch = X[i*config['batch_size']:(i+1)*config['batch_size']] + rng, key = random.split(rng) + state = train_step(state, batch, log_likelihood_MADE) + + ms, logp = model.apply({'params': state.params}, batch) + loss, u_distro = eval_model( + model, state.params, batch, key, log_likelihood_MADE + ) + print('eval epoch: {}, loss: {}'.format(i + 1, loss)) + # print(jnp.linalg.norm(grad_func(state, batch)['mlp']['hidden_0']['kernel'])) + + return state, model + + + +config = {} +config['learning_rate'] = .01 +config['in_dim'] = 2 +config['batch_size'] = 256 +config['num_epochs'] = 120 +config['n_hiddens'] = [5, 5] + + +key1, key2 = random.split(random.key(0), 2) +x2 = 4 * random.normal(key1, (config['batch_size']*100,)) +x1 = (.25*x2**2) + random.normal(key2, (config['batch_size']*100,)) +X = jnp.vstack([x2, x1]).T + +# key1, key2 = random.split(random.key(0), 2) +# x2 = 3 * random.normal(key1, (config['batch_size']*100,)) +# x1 = x2 + random.normal(key2, (config['batch_size']*100,)) +# X = jnp.vstack([x2, x1]).T + + +model=GaussianMADE(random.PRNGKey(42), 2, config['n_hiddens'], act_fn=nn.relu) +state_f, mdl = train_and_evaluate(model, X, config) + + diff --git a/vbjax/train_utils.py b/vbjax/train_utils.py new file mode 100644 index 0000000..52655ae --- /dev/null +++ b/vbjax/train_utils.py @@ -0,0 +1,76 @@ +import jax +import jax.numpy as jnp +from flax import linen as nn + +def log_likelihood_MADE(ms, logp, x, *args): + u = jnp.exp(.5 * logp) * (x - ms) + return -(- 0.5 * (x.shape[1] * jnp.log(2 * jnp.pi) + jnp.sum(u ** 2 - logp, 
axis=1)))
+
+
+def log_likelihood_MAF(u, logdet_dudx, *arg):
+    return -(- 0.5 * u.shape[1] * jnp.log(2 * jnp.pi) - 0.5 * jnp.sum(u ** 2, axis=1) + logdet_dudx)
+
+def mse_ode(traj, x, *arg):
+    return jnp.mean((traj-x)**2)
+
+
+def eval_model_ode(model, params, batch, loss_fn=None, shape=None):
+    batch, i_ext = batch
+    def _eval(model):
+        output = model(batch, i_ext)
+        loss = loss_fn(output, batch)
+        return loss
+    return nn.apply(_eval, model)({'params': params})
+
+
+def eval_loss(model, params, batch, loss_fn=None, shape=None):
+    batch, i_ext = batch
+    def _eval(model):
+        output = model(batch, i_ext)
+        loss = loss_fn(output, batch)
+        return loss
+    return nn.apply(_eval, model)({'params': params})
+
+
+def eval_model(model, params, batch, key, loss_fn, shape=None):
+    shape = shape if shape else batch.shape
+    def _eval(model):
+        output = model(batch)  # (ms, logp) for a MADE, (u, logdet) for a MAF
+        loss = loss_fn(*output, batch).mean()
+        u_sample = model.gen(key, shape)
+        return loss, u_sample
+    return nn.apply(_eval, model)({'params': params})
+
+
+def train_step(state, batch, loss_f):
+    # expects batch to be a (data, parameters/stimulus) tuple
+    batch, p = batch
+    def loss_fn(params):
+        output = state.apply_fn(
+            {'params': params}, batch, p,
+        )
+        loss = loss_f(output, batch)
+        return loss
+    grads = jax.grad(loss_fn)(state.params)
+    return state.apply_gradients(grads=grads)
+
+
+def grad_func(state, batch, loss_fn):
+    def _loss(params):
+        output = state.apply_fn(
+            {'params': params}, batch,
+        )
+        return loss_fn(output, batch)
+    return jax.grad(_loss)(state.params)
+
+
+def loss_t(traj, X):
+    X, iext = X
+    loss_bias = jnp.var(X, axis=2)*10 + 1
+    squared_loss_vec = jnp.square(X - traj).mean(axis=2)
+    return (loss_bias*squared_loss_vec).sum()
+
+
+def loss_t_unpack(traj, X):
+    X, iext = X
+    loss_bias = jnp.var(X, axis=2)*40 + 1
+    squared_loss_vec = jnp.square(X - traj).mean(axis=2)
+    return (loss_bias*squared_loss_vec).sum(axis=1)
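+
+
+# Usage sketch (illustrative only, mirroring vbjax/train.py; `model`, `params`, `X` and the
+# optax/train_state imports are assumed from that script, not defined here):
+#
+#   state = train_state.TrainState.create(apply_fn=model.apply, params=params,
+#                                          tx=optax.adam(1e-3))
+#   loss, samples = eval_model(model, state.params, X, jax.random.PRNGKey(0),
+#                              log_likelihood_MADE, shape=(1000, X.shape[1]))
+#   grads = grad_func(state, X, lambda out, x: log_likelihood_MADE(*out, x).mean())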