diff --git a/README.md b/README.md index 4dced1df7..2fd8a2a7b 100644 --- a/README.md +++ b/README.md @@ -15,9 +15,9 @@ NumPyro is a lightweight probabilistic programming library that provides a NumPy NumPyro is designed to be *lightweight* and focuses on providing a flexible substrate that users can build on: - - **Pyro Primitives:** NumPyro programs can contain regular Python and NumPy code, in addition to [Pyro primitives](http://pyro.ai/examples/intro_part_i.html) like `sample` and `param`. The model code should look very similar to Pyro except for some minor differences between PyTorch and Numpy's API. See the [example](https://github.com/pyro-ppl/numpyro#a-simple-example---8-schools) below. - - **Inference algorithms:** NumPyro supports a number of inference algorithms, with a particular focus on MCMC algorithms like Hamiltonian Monte Carlo, including an implementation of the No U-Turn Sampler. Additional MCMC algorithms include [MixedHMC](http://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.mixed_hmc.MixedHMC) (which can accommodate discrete latent variables) as well as [HMCECS](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.HMCECS) (which only computes the likelihood for subsets of the data in each iteration). One of the motivations for NumPyro was to speed up Hamiltonian Monte Carlo by JIT compiling the verlet integrator that includes multiple gradient computations. With JAX, we can compose `jit` and `grad` to compile the entire integration step into an XLA optimized kernel. We also eliminate Python overhead by JIT compiling the entire tree building stage in NUTS (this is possible using [Iterative NUTS](https://github.com/pyro-ppl/numpyro/wiki/Iterative-NUTS)). There is also a basic Variational Inference implementation together with many flexible (auto)guides for Automatic Differentiation Variational Inference (ADVI). The variational inference implementation supports a number of features, including support for models with discrete latent variables (see [TraceGraph_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceGraph_ELBO) and [TraceEnum_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceEnum_ELBO)). - - **Distributions:** The [numpyro.distributions](https://numpyro.readthedocs.io/en/latest/distributions.html) module provides distribution classes, constraints and bijective transforms. The distribution classes wrap over samplers implemented to work with JAX's [functional pseudo-random number generator](https://github.com/google/jax#random-numbers-are-different). The design of the distributions module largely follows from [PyTorch](https://pytorch.org/docs/stable/distributions.html). A major subset of the API is implemented, and it contains most of the common distributions that exist in PyTorch. As a result, Pyro and PyTorch users can rely on the same API and batching semantics as in `torch.distributions`. In addition to distributions, `constraints` and `transforms` are very useful when operating on distribution classes with bounded support. Finally, distributions from TensorFlow Probability ([TFP](http://num.pyro.ai/en/latest/distributions.html?highlight=tfp#numpyro.contrib.tfp.distributions.TFPDistribution)) can directly be used in NumPyro models. + - **Pyro Primitives:** NumPyro programs can contain regular Python and NumPy code, in addition to [Pyro primitives](https://pyro.ai/examples/intro_part_i.html) like `sample` and `param`. The model code should look very similar to Pyro except for some minor differences between PyTorch and Numpy's API. See the [example](https://github.com/pyro-ppl/numpyro#a-simple-example---8-schools) below. + - **Inference algorithms:** NumPyro supports a number of inference algorithms, with a particular focus on MCMC algorithms like Hamiltonian Monte Carlo, including an implementation of the No U-Turn Sampler. Additional MCMC algorithms include [MixedHMC](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.mixed_hmc.MixedHMC) (which can accommodate discrete latent variables) as well as [HMCECS](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.HMCECS) (which only computes the likelihood for subsets of the data in each iteration). One of the motivations for NumPyro was to speed up Hamiltonian Monte Carlo by JIT compiling the verlet integrator that includes multiple gradient computations. With JAX, we can compose `jit` and `grad` to compile the entire integration step into an XLA optimized kernel. We also eliminate Python overhead by JIT compiling the entire tree building stage in NUTS (this is possible using [Iterative NUTS](https://github.com/pyro-ppl/numpyro/wiki/Iterative-NUTS)). There is also a basic Variational Inference implementation together with many flexible (auto)guides for Automatic Differentiation Variational Inference (ADVI). The variational inference implementation supports a number of features, including support for models with discrete latent variables (see [TraceGraph_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceGraph_ELBO) and [TraceEnum_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceEnum_ELBO)). + - **Distributions:** The [numpyro.distributions](https://numpyro.readthedocs.io/en/latest/distributions.html) module provides distribution classes, constraints and bijective transforms. The distribution classes wrap over samplers implemented to work with JAX's [functional pseudo-random number generator](https://github.com/google/jax#random-numbers-are-different). The design of the distributions module largely follows from [PyTorch](https://pytorch.org/docs/stable/distributions.html). A major subset of the API is implemented, and it contains most of the common distributions that exist in PyTorch. As a result, Pyro and PyTorch users can rely on the same API and batching semantics as in `torch.distributions`. In addition to distributions, `constraints` and `transforms` are very useful when operating on distribution classes with bounded support. Finally, distributions from TensorFlow Probability ([TFP](https://num.pyro.ai/en/latest/distributions.html?highlight=tfp#numpyro.contrib.tfp.distributions.TFPDistribution)) can directly be used in NumPyro models. - **Effect handlers:** Like Pyro, primitives like `sample` and `param` can be provided nonstandard interpretations using effect-handlers from the [numpyro.handlers](https://numpyro.readthedocs.io/en/latest/handlers.html) module, and these can be easily extended to implement custom inference algorithms and inference utilities. ## A Simple Example - 8 Schools @@ -50,7 +50,7 @@ The data is given by: ``` -Let us infer the values of the unknown parameters in our model by running MCMC using the No-U-Turn Sampler (NUTS). Note the usage of the `extra_fields` argument in [MCMC.run](http://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.mcmc.MCMC.run). By default, we only collect samples from the target (posterior) distribution when we run inference using `MCMC`. However, collecting additional fields like potential energy or the acceptance probability of a sample can be easily achieved by using the `extra_fields` argument. For a list of possible fields that can be collected, see the [HMCState](http://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.mcmc.HMCState) object. In this example, we will additionally collect the `potential_energy` for each sample. +Let us infer the values of the unknown parameters in our model by running MCMC using the No-U-Turn Sampler (NUTS). Note the usage of the `extra_fields` argument in [MCMC.run](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.mcmc.MCMC.run). By default, we only collect samples from the target (posterior) distribution when we run inference using `MCMC`. However, collecting additional fields like potential energy or the acceptance probability of a sample can be easily achieved by using the `extra_fields` argument. For a list of possible fields that can be collected, see the [HMCState](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc.HMCState) object. In this example, we will additionally collect the `potential_energy` for each sample. ```python >>> from jax import random @@ -88,7 +88,7 @@ Expected log joint density: -54.55 ``` -The values above 1 for the split Gelman Rubin diagnostic (`r_hat`) indicates that the chain has not fully converged. The low value for the effective sample size (`n_eff`), particularly for `tau`, and the number of divergent transitions looks problematic. Fortunately, this is a common pathology that can be rectified by using a [non-centered paramaterization](https://mc-stan.org/docs/2_18/stan-users-guide/reparameterization-section.html) for `tau` in our model. This is straightforward to do in NumPyro by using a [TransformedDistribution](http://num.pyro.ai/en/latest/distributions.html#transformeddistribution) instance together with a [reparameterization](http://num.pyro.ai/en/latest/handlers.html#reparam) effect handler. Let us rewrite the same model but instead of sampling `theta` from a `Normal(mu, tau)`, we will instead sample it from a base `Normal(0, 1)` distribution that is transformed using an [AffineTransform](http://num.pyro.ai/en/latest/distributions.html#affinetransform). Note that by doing so, NumPyro runs HMC by generating samples `theta_base` for the base `Normal(0, 1)` distribution instead. We see that the resulting chain does not suffer from the same pathology — the Gelman Rubin diagnostic is 1 for all the parameters and the effective sample size looks quite good! +The values above 1 for the split Gelman Rubin diagnostic (`r_hat`) indicates that the chain has not fully converged. The low value for the effective sample size (`n_eff`), particularly for `tau`, and the number of divergent transitions looks problematic. Fortunately, this is a common pathology that can be rectified by using a [non-centered paramaterization](https://mc-stan.org/docs/2_18/stan-users-guide/reparameterization-section.html) for `tau` in our model. This is straightforward to do in NumPyro by using a [TransformedDistribution](https://num.pyro.ai/en/latest/distributions.html#transformeddistribution) instance together with a [reparameterization](https://num.pyro.ai/en/latest/handlers.html#reparam) effect handler. Let us rewrite the same model but instead of sampling `theta` from a `Normal(mu, tau)`, we will instead sample it from a base `Normal(0, 1)` distribution that is transformed using an [AffineTransform](https://num.pyro.ai/en/latest/distributions.html#affinetransform). Note that by doing so, NumPyro runs HMC by generating samples `theta_base` for the base `Normal(0, 1)` distribution instead. We see that the resulting chain does not suffer from the same pathology — the Gelman Rubin diagnostic is 1 for all the parameters and the effective sample size looks quite good! ```python >>> from numpyro.infer.reparam import TransformReparam @@ -140,12 +140,12 @@ Expected log joint density: -46.09 ``` -Note that for the class of distributions with `loc,scale` parameters such as `Normal`, `Cauchy`, `StudentT`, we also provide a [LocScaleReparam](http://num.pyro.ai/en/latest/reparam.html#loc-scale-decentering) reparameterizer to achieve the same purpose. The corresponding code will be +Note that for the class of distributions with `loc,scale` parameters such as `Normal`, `Cauchy`, `StudentT`, we also provide a [LocScaleReparam](https://num.pyro.ai/en/latest/reparam.html#loc-scale-decentering) reparameterizer to achieve the same purpose. The corresponding code will be with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}): theta = numpyro.sample('theta', dist.Normal(mu, tau)) -Now, let us assume that we have a new school for which we have not observed any test scores, but we would like to generate predictions. NumPyro provides a [Predictive](http://num.pyro.ai/en/latest/utilities.html#numpyro.infer.util.Predictive) class for such a purpose. Note that in the absence of any observed data, we simply use the population-level parameters to generate predictions. The `Predictive` utility conditions the unobserved `mu` and `tau` sites to values drawn from the posterior distribution from our last MCMC run, and runs the model forward to generate predictions. +Now, let us assume that we have a new school for which we have not observed any test scores, but we would like to generate predictions. NumPyro provides a [Predictive](https://num.pyro.ai/en/latest/utilities.html#numpyro.infer.util.Predictive) class for such a purpose. Note that in the absence of any observed data, we simply use the population-level parameters to generate predictions. The `Predictive` utility conditions the unobserved `mu` and `tau` sites to values drawn from the posterior distribution from our last MCMC run, and runs the model forward to generate predictions. ```python >>> from numpyro.infer import Predictive @@ -188,14 +188,14 @@ We provide an overview of most of the inference algorithms supported by NumPyro ### MCMC -- [NUTS](https://num.pyro.ai/en/latest/mcmc.html#nuts), which is an adaptive variant of [HMC](https://num.pyro.ai/en/latest/mcmc.html#hmc), is probably the most commonly used inference algorithm in NumPyro. Note that NUTS and HMC are not directly applicable to models with discrete latent variables, but in cases where the discrete variables have finite support and summing them out (i.e. enumeration) is tractable, NumPyro will automatically sum out discrete latent variables and perform NUTS/HMC on the remaining continuous latent variables. +- [NUTS](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc.NUTS), which is an adaptive variant of [HMC](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc.HMC), is probably the most commonly used inference algorithm in NumPyro. Note that NUTS and HMC are not directly applicable to models with discrete latent variables, but in cases where the discrete variables have finite support and summing them out (i.e. enumeration) is tractable, NumPyro will automatically sum out discrete latent variables and perform NUTS/HMC on the remaining continuous latent variables. As discussed above, model [reparameterization](https://num.pyro.ai/en/latest/reparam.html#module-numpyro.infer.reparam) may be important in some cases to get good performance. Note that, generally speaking, we expect inference to be harder as the dimension of the latent space increases. See the [bad geometry](https://num.pyro.ai/en/latest/tutorials/bad_posterior_geometry.html) tutorial for additional tips and tricks. -- [MixedHMC](http://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.mixed_hmc.MixedHMC) can be an effective inference strategy for models that contain both continuous and discrete latent variables. -- [HMCECS](http://num.pyro.ai/en/latest/mcmc.html#hmcecs) can be an effective inference strategy for models with a large number of data points. It is applicable to models with continuous latent variables. See [here](https://num.pyro.ai/en/latest/examples/covtype.html) for an example. -- [BarkerMH](https://num.pyro.ai/en/latest/mcmc.html#barkermh) is a gradient-based MCMC method that may be competitive with HMC and NUTS for some models. It is applicable to models with continuous latent variables. -- [HMCGibbs](https://num.pyro.ai/en/latest/mcmc.html#hmcgibbs) combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user. -- [DiscreteHMCGibbs](https://num.pyro.ai/en/latest/mcmc.html#discretehmcgibbs) combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically. -- [SA](https://num.pyro.ai/en/latest/mcmc.html#sa) is the only MCMC method in NumPyro that does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast. +- [MixedHMC](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.mixed_hmc.MixedHMC) can be an effective inference strategy for models that contain both continuous and discrete latent variables. +- [HMCECS](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.HMCECS) can be an effective inference strategy for models with a large number of data points. It is applicable to models with continuous latent variables. See [here](https://num.pyro.ai/en/latest/examples/covtype.html) for an example. +- [BarkerMH](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.barker.BarkerMH) is a gradient-based MCMC method that may be competitive with HMC and NUTS for some models. It is applicable to models with continuous latent variables. +- [HMCGibbs](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.HMCGibbs) combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user. +- [DiscreteHMCGibbs](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.DiscreteHMCGibbs) combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically. +- [SA](https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.sa.SA) is the only MCMC method in NumPyro that does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast. Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables if possible (see [restrictions](https://pyro.ai/examples/enumeration.html#Restriction-1:-conditional-independence)). Enumerated sites need to be marked with `infer={'enumerate': 'parallel'}` like in the [annotation example](https://num.pyro.ai/en/stable/examples/annotation.html). @@ -209,13 +209,13 @@ Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete l - [TraceGraph_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceGraph_ELBO) offers variance reduction strategies for models with discrete latent variables. Generally speaking, this ELBO should always be used for models with discrete latent variables. - [TraceEnum_ELBO](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.elbo.TraceEnum_ELBO) offers variable enumeration strategies for models with discrete latent variables. Generally speaking, this ELBO should always be used for models with discrete latent variables when enumeration is possible. - Automatic guides (appropriate for models with continuous latent variables) - - [AutoNormal](https://num.pyro.ai/en/latest/autoguide.html#autonormal) and [AutoDiagonalNormal](https://num.pyro.ai/en/latest/autoguide.html#autodiagonalnormal) are our basic mean-field guides. If the latent space is non-euclidean (due to e.g. a positivity constraint on one of the sample sites) an appropriate bijective transformation is automatically used under the hood to map between the unconstrained space (where the Normal variational distribution is defined) to the corresponding constrained space (note this is true for all automatic guides). These guides are a great place to start when trying to get variational inference to work on a model you are developing. - - [AutoMultivariateNormal](https://num.pyro.ai/en/latest/autoguide.html#automultivariatenormal) and [AutoLowRankMultivariateNormal](https://num.pyro.ai/en/latest/autoguide.html#autolowrankmultivariatenormal) also construct Normal variational distributions but offer more flexibility, as they can capture correlations in the posterior. Note that these guides may be difficult to fit in the high-dimensional setting. - - [AutoDelta](https://num.pyro.ai/en/latest/autoguide.html#autodelta) is used for computing point estimates via MAP (maximum a posteriori estimation). See [here](https://github.com/pyro-ppl/numpyro/blob/bbe1f879eede79eebfdd16dfc49c77c4d1fc727c/examples/zero_inflated_poisson.py#L101) for example usage. - - [AutoBNAFNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoBNAFNormal) and [AutoIAFNormal](https://num.pyro.ai/en/latest/autoguide.html#autoiafnormal) offer flexible variational distributions parameterized by normalizing flows. - - [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#autodais) is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model. - - [AutoSurrogateLikelihoodDAIS](https://num.pyro.ai/en/latest/autoguide.html#autosurrogatelikelihooddais) is a powerful variational inference algorithm that leverages HMC and that supports data subsampling. - - [AutoSemiDAIS](https://num.pyro.ai/en/latest/autoguide.html#autosemidais) constructs a posterior approximation like [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#autodais) for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables. + - [AutoNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoNormal) and [AutoDiagonalNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoDiagonalNormal) are our basic mean-field guides. If the latent space is non-euclidean (due to e.g. a positivity constraint on one of the sample sites) an appropriate bijective transformation is automatically used under the hood to map between the unconstrained space (where the Normal variational distribution is defined) to the corresponding constrained space (note this is true for all automatic guides). These guides are a great place to start when trying to get variational inference to work on a model you are developing. + - [AutoMultivariateNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoMultivariateNormal) and [AutoLowRankMultivariateNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoLowRankMultivariateNormal) also construct Normal variational distributions but offer more flexibility, as they can capture correlations in the posterior. Note that these guides may be difficult to fit in the high-dimensional setting. + - [AutoDelta](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoDelta) is used for computing point estimates via MAP (maximum a posteriori estimation). See [here](https://github.com/pyro-ppl/numpyro/blob/bbe1f879eede79eebfdd16dfc49c77c4d1fc727c/examples/zero_inflated_poisson.py#L101) for example usage. + - [AutoBNAFNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoBNAFNormal) and [AutoIAFNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoIAFNormal) offer flexible variational distributions parameterized by normalizing flows. + - [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoDAIS) is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model. + - [AutoSurrogateLikelihoodDAIS](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoSurrogateLikelihoodDAIS) is a powerful variational inference algorithm that leverages HMC and that supports data subsampling. + - [AutoSemiDAIS](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoSemiDAIS) constructs a posterior approximation like [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoDAIS) for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables. - [AutoLaplaceApproximation](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoLaplaceApproximation) can be used to compute a Laplace approximation. ### Stein Variational Inference @@ -251,7 +251,7 @@ For Cloud TPU VM, you need to setup the TPU backend as detailed in the [Cloud TP After you have verified that the TPU backend is properly set up, you can install NumPyro using the `pip install numpyro` command. -> **Default Platform:** JAX will use GPU by default if CUDA-supported `jaxlib` package is installed. You can use [set_platform](http://num.pyro.ai/en/stable/utilities.html#set-platform) utility `numpyro.set_platform("cpu")` to switch to CPU at the beginning of your program. +> **Default Platform:** JAX will use GPU by default if CUDA-supported `jaxlib` package is installed. You can use [set_platform](https://num.pyro.ai/en/stable/utilities.html#set-platform) utility `numpyro.set_platform("cpu")` to switch to CPU at the beginning of your program. You can also install NumPyro from source: @@ -271,7 +271,7 @@ conda install -c conda-forge numpyro 1. Unlike in Pyro, `numpyro.sample('x', dist.Normal(0, 1))` does not work. Why? - You are most likely using a `numpyro.sample` statement outside an inference context. JAX does not have a global random state, and as such, distribution samplers need an explicit random number generator key ([PRNGKey](https://jax.readthedocs.io/en/latest/jax.random.html#jax.random.PRNGKey)) to generate samples from. NumPyro's inference algorithms use the [seed](http://num.pyro.ai/en/latest/handlers.html#seed) handler to thread in a random number generator key, behind the scenes. + You are most likely using a `numpyro.sample` statement outside an inference context. JAX does not have a global random state, and as such, distribution samplers need an explicit random number generator key ([PRNGKey](https://jax.readthedocs.io/en/latest/jax.random.html#jax.random.PRNGKey)) to generate samples from. NumPyro's inference algorithms use the [seed](https://num.pyro.ai/en/latest/handlers.html#seed) handler to thread in a random number generator key, behind the scenes. Your options are: @@ -302,7 +302,7 @@ conda install -c conda-forge numpyro - Any `torch` operation in your model will need to be written in terms of the corresponding `jax.numpy` operation. Additionally, not all `torch` operations have a `numpy` counterpart (and vice-versa), and sometimes there are minor differences in the API. - `pyro.sample` statements outside an inference context will need to be wrapped in a `seed` handler, as mentioned above. - - There is no global parameter store, and as such using `numpyro.param` outside an inference context will have no effect. To retrieve the optimized parameter values from SVI, use the [SVI.get_params](http://num.pyro.ai/en/latest/svi.html#numpyro.infer.svi.SVI.get_params) method. Note that you can still use `param` statements inside a model and NumPyro will use the [substitute](http://num.pyro.ai/en/latest/handlers.html#substitute) effect handler internally to substitute values from the optimizer when running the model in SVI. + - There is no global parameter store, and as such using `numpyro.param` outside an inference context will have no effect. To retrieve the optimized parameter values from SVI, use the [SVI.get_params](https://num.pyro.ai/en/latest/svi.html#numpyro.infer.svi.SVI.get_params) method. Note that you can still use `param` statements inside a model and NumPyro will use the [substitute](https://num.pyro.ai/en/latest/handlers.html#substitute) effect handler internally to substitute values from the optimizer when running the model in SVI. - PyTorch neural network modules will need to rewritten as [stax](https://github.com/google/jax#neural-net-building-with-stax) neural networks. See the [VAE](#examples) example for differences in syntax between the two backends. - JAX works best with functional code, particularly if we would like to leverage JIT compilation, which NumPyro does internally for many inference subroutines. As such, if your model has side-effects that are not visible to the JAX tracer, it may need to rewritten in a more functional style. diff --git a/docs/source/autoguide.rst b/docs/source/autoguide.rst index fbb289f0e..11cb5d3be 100644 --- a/docs/source/autoguide.rst +++ b/docs/source/autoguide.rst @@ -3,13 +3,13 @@ Automatic Guide Generation We provide a brief overview of the automatically generated guides available in NumPyro: -* `AutoNormal `_ and `AutoDiagonalNormal `_ are our basic mean-field guides. If the latent space is non-euclidean (due to e.g. a positivity constraint on one of the sample sites) an appropriate bijective transformation is automatically used under the hood to map between the unconstrained space (where the Normal variational distribution is defined) to the corresponding constrained space (note this is true for all automatic guides). These guides are a great place to start when trying to get variational inference to work on a model you are developing. -* `AutoMultivariateNormal `_ and `AutoLowRankMultivariateNormal `_ also construct Normal variational distributions but offer more flexibility, as they can capture correlations in the posterior. Note that these guides may be difficult to fit in the high-dimensional setting. -* `AutoDelta `_ is used for computing point estimates via MAP (maximum a posteriori estimation). See `here `_ for example usage. -* `AutoBNAFNormal `_ and `AutoIAFNormal `_ offer flexible variational distributions parameterized by normalizing flows. -* `AutoDAIS `_ is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model. -* `AutoSurrogateLikelihoodDAIS `_ is a powerful variational inference algorithm that leverages HMC and that supports data subsampling. -* `AutoSemiDAIS `_ constructs a posterior approximation like `AutoDAIS `_ for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables. +* `AutoNormal `_ and `AutoDiagonalNormal `_ are our basic mean-field guides. If the latent space is non-euclidean (due to e.g. a positivity constraint on one of the sample sites) an appropriate bijective transformation is automatically used under the hood to map between the unconstrained space (where the Normal variational distribution is defined) to the corresponding constrained space (note this is true for all automatic guides). These guides are a great place to start when trying to get variational inference to work on a model you are developing. +* `AutoMultivariateNormal `_ and `AutoLowRankMultivariateNormal `_ also construct Normal variational distributions but offer more flexibility, as they can capture correlations in the posterior. Note that these guides may be difficult to fit in the high-dimensional setting. +* `AutoDelta `_ is used for computing point estimates via MAP (maximum a posteriori estimation). See `here `_ for example usage. +* `AutoBNAFNormal `_ and `AutoIAFNormal `_ offer flexible variational distributions parameterized by normalizing flows. +* `AutoDAIS `_ is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model. +* `AutoSurrogateLikelihoodDAIS `_ is a powerful variational inference algorithm that leverages HMC and that supports data subsampling. +* `AutoSemiDAIS `_ constructs a posterior approximation like `AutoDAIS `_ for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables. * `AutoLaplaceApproximation `_ can be used to compute a Laplace approximation. .. automodule:: numpyro.infer.autoguide diff --git a/docs/source/mcmc.rst b/docs/source/mcmc.rst index 09f53f159..edfda92d2 100644 --- a/docs/source/mcmc.rst +++ b/docs/source/mcmc.rst @@ -4,8 +4,8 @@ Markov Chain Monte Carlo (MCMC) We provide a high-level overview of the MCMC algorithms in NumPyro: * `NUTS `_, which is an adaptive variant of `HMC `_, is probably the most commonly used MCMC algorithm in NumPyro. Note that NUTS and HMC are not directly applicable to models with discrete latent variables, but in cases where the discrete variables have finite support and summing them out (i.e. enumeration) is tractable, NumPyro will automatically sum out discrete latent variables and perform NUTS/HMC on the remaining continuous latent variables. As discussed above, model `reparameterization `_ may be important in some cases to get good performance. Note that, generally speaking, we expect inference to be harder as the dimension of the latent space increases. See the `bad geometry `_ tutorial for additional tips and tricks. -* `MixedHMC `_ can be an effective inference strategy for models that contain both continuous and discrete latent variables. -* `HMCECS `_ can be an effective inference strategy for models with a large number of data points. It is applicable to models with continuous latent variables. See `this example `_ for detailed usage. +* `MixedHMC `_ can be an effective inference strategy for models that contain both continuous and discrete latent variables. +* `HMCECS `_ can be an effective inference strategy for models with a large number of data points. It is applicable to models with continuous latent variables. See `this example `_ for detailed usage. * `BarkerMH `_ is a gradient-based MCMC method that may be competitive with HMC and NUTS for some models. It is applicable to models with continuous latent variables. * `HMCGibbs `_ combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user. * `DiscreteHMCGibbs `_ combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically. diff --git a/setup.py b/setup.py index bba87d335..be293001e 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,7 @@ "isort>=5.0", "pytest>=4.1", "pyro-api>=0.1.1", - "scipy>=1.6,<1.7", + "scipy>=1.9", ], "dev": [ "dm-haiku", @@ -69,7 +69,7 @@ "pylab-sdk", # jaxns dependency "pyyaml", # flax dependency "requests", # pylab dependency - "tensorflow_probability>=0.17.0", + "tensorflow_probability>=0.18.0", ], "examples": [ "arviz", diff --git a/test/test_distributions.py b/test/test_distributions.py index 7cfbfdaed..6a87ef4f0 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1448,10 +1448,10 @@ def test_log_prob_LKJCholesky(dimension, concentration): def test_zero_inflated_logits_probs_agree(): - concentration = np.exp(np.random.normal(100)) - rate = np.exp(np.random.normal(100)) + concentration = np.exp(np.random.normal(1)) + rate = np.exp(np.random.normal(1)) d = dist.GammaPoisson(concentration, rate) - gate_logits = np.random.normal(100) + gate_logits = np.random.normal(0) gate_probs = expit(gate_logits) zi_logits = dist.ZeroInflatedDistribution(d, gate_logits=gate_logits) zi_probs = dist.ZeroInflatedDistribution(d, gate=gate_probs)