Conditional normalizing flows in Jax

Implementations of some common normalizing flow models that accept a conditioning context, built with Jax, Flax, and Distrax. The following are currently implemented:

Examples

Basic usage

```python
import jax
import jax.numpy as jnp
from models.maf import MaskedAutoregressiveFlow
from models.nsf import NeuralSplineFlow

n_dim = 2  # Feature dim
n_context = 1  # Context dim

## Define flow model
# model = MaskedAutoregressiveFlow(n_dim=n_dim, n_context=n_context, hidden_dims=[128,128], n_transforms=12, activation="tanh", use_random_permutations=False)
model = NeuralSplineFlow(n_dim=n_dim, n_context=n_context, hidden_dims=[128,128], n_transforms=8, activation="gelu", n_bins=4)

## Initialize model and params
key = jax.random.PRNGKey(42)
# Split the key so that data, context, init, and sampling use independent randomness
key_x, key_context, key_init, key_sample = jax.random.split(key, 4)
x_test = jax.random.uniform(key=key_x, shape=(64, n_dim))
context = jax.random.uniform(key=key_context, shape=(64, n_context))
params = model.init(key_init, x_test, context)

## Log-prob and sampling
# Evaluate log p(x | context) with a constant context of ones
log_prob = model.apply(params, x_test, jnp.ones((x_test.shape[0], n_context)))
# Draw n_samples from the flow conditioned on the same constant context
n_samples = 1000  # Number of samples to draw (illustrative value)
samples = model.apply(params, n_samples, key_sample, jnp.ones((n_samples, n_context)), method=model.sample)
```
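To illustrate what a conditional flow's log-prob and sampling calls compute, here is a minimal pure-JAX sketch of a single conditional affine transform with a standard normal base distribution. It is not this repository's implementation. The names `init_params`, `conditioner`, `log_prob`, and `sample`, and the linear conditioner, are hypothetical and chosen only to make the change-of-variables formula log p(x | c) = log N(z; 0, I) + log|det ∂z/∂x| visible, with z = (x - shift(c)) * exp(-log_scale(c)).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical single-layer conditional affine flow (illustration only):
#   z = (x - shift(c)) * exp(-log_scale(c)),  z ~ N(0, I)
# where shift and log_scale are linear functions of the context c.


def init_params(key, n_dim, n_context):
    """Small random weights for the (hypothetical) linear conditioner."""
    k_shift, k_scale = jax.random.split(key)
    return {
        "W_shift": 0.1 * jax.random.normal(k_shift, (n_context, n_dim)),
        "W_scale": 0.1 * jax.random.normal(k_scale, (n_context, n_dim)),
        "b_shift": jnp.zeros(n_dim),
        "b_scale": jnp.zeros(n_dim),
    }


def conditioner(params, context):
    """Map context of shape (batch, n_context) to per-dim shift and log-scale."""
    shift = context @ params["W_shift"] + params["b_shift"]
    log_scale = context @ params["W_scale"] + params["b_scale"]
    return shift, log_scale


def log_prob(params, x, context):
    """log p(x | context) via the change-of-variables formula."""
    shift, log_scale = conditioner(params, context)
    z = (x - shift) * jnp.exp(-log_scale)            # forward: data -> base
    log_base = norm.logpdf(z).sum(axis=-1)           # log N(z; 0, I)
    log_det = -log_scale.sum(axis=-1)                # log|det dz/dx|
    return log_base + log_det


def sample(params, key, context):
    """Draw one sample per context row by inverting the transform."""
    shift, log_scale = conditioner(params, context)
    z = jax.random.normal(key, shift.shape)          # base samples
    return z * jnp.exp(log_scale) + shift            # inverse: base -> data


def nll(params, x, context):
    """Mean negative log-likelihood, the usual training objective for flows."""
    return -log_prob(params, x, context).mean()
```

Usage: `jax.grad(nll)(params, x, context)` returns gradients with the same dict structure as `params`, which can be passed to any optimizer. Flows such as MAF and neural spline flows stack many transforms with more expressive autoregressive or spline-based conditioners. This sketch keeps a single linear conditioner so that each term of the log-density stays easy to check.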