From ecd6255a3804a697f2a0c0b6aca734ef4c7784cf Mon Sep 17 00:00:00 2001 From: Du Phan Date: Tue, 16 Mar 2021 12:47:32 -0500 Subject: [PATCH] Bump to 0.6.0 (#959) * sketch the plan * add a running implementation (not working yet) * remove unnecessary changes * sequential update * temp save * add a working implementation * add 24d example * fix lint * test for the order * merge master * fix bug at reset momentum * add various mh functions * add various discrete gibbs function method * change stay_prob to modified to avoid confusing users * expose more information for mixed hmc * sketch an implementation * temp save * temp save * finish the implementation * keep kinetic energy * add temperature experiment * add dual averaging * add various debug statements * fix bugs * clean up and separating out clock adapter; but target distribution is wrong due to a bug somewhere * clean up * add comments and an example * make sure forward mode work * add docs for new HMC fields * add tests for mixedhmc * fix step_size bug * use modified=False * tests pass with the fix * skip print summary * adjust trajectory length * port update_version script from Pyro * pin jax/jaxlib versions * run isort * fix some issues during collection notes * use result_type instead of canonicalize_dtype * fix lint * change get_dtype to jnp.result_type * add print summary * fix compiling issue for mcmc * also try to avoid compiling issue in other samplers * also fix compiling issue in barkermh * convert init params types to strong types * address comments * fix wrong docs * run isort --- README.md | 2 +- docs/source/distributions.rst | 8 +++++ docs/source/mcmc.rst | 2 ++ examples/annotation.py | 2 +- examples/baseball.py | 2 +- examples/bnn.py | 2 +- examples/covtype.py | 2 +- examples/funnel.py | 2 +- examples/gp.py | 2 +- examples/hmm.py | 2 +- examples/minipyro.py | 2 +- examples/neutra.py | 2 +- examples/ode.py | 2 +- examples/proportion_test.py | 2 +- examples/sparse_regression.py | 2 +- 
examples/stochastic_volatility.py | 2 +- examples/ucbadmit.py | 2 +- examples/vae.py | 2 +- notebooks/source/bayesian_regression.ipynb | 2 +- notebooks/source/logistic_regression.ipynb | 2 +- notebooks/source/ordinal_regression.ipynb | 2 +- numpyro/__init__.py | 2 +- numpyro/contrib/tfp/distributions.py | 3 +- numpyro/distributions/discrete.py | 15 +++++---- numpyro/distributions/transforms.py | 13 ++------ numpyro/distributions/util.py | 9 ++---- numpyro/infer/__init__.py | 2 +- numpyro/infer/barker.py | 11 +++---- numpyro/infer/hmc.py | 8 ++--- numpyro/infer/hmc_gibbs.py | 4 +-- numpyro/infer/hmc_util.py | 5 ++- numpyro/infer/mcmc.py | 1 + numpyro/infer/mixed_hmc.py | 5 ++- numpyro/infer/sa.py | 2 +- numpyro/infer/svi.py | 7 +++-- numpyro/util.py | 5 ++- numpyro/version.py | 2 +- scripts/update_version.py | 36 ++++++++++++++++++++++ setup.py | 4 +-- test/contrib/test_tfp.py | 2 +- test/test_util.py | 4 +-- 41 files changed, 109 insertions(+), 79 deletions(-) create mode 100644 scripts/update_version.py diff --git a/README.md b/README.md index 16606aa19..c62cb0b61 100644 --- a/README.md +++ b/README.md @@ -182,7 +182,7 @@ Pyro users will note that the API for model specification and inference is large ## Installation -> **Limited Windows Support:** Note that NumPyro is untested on Windows, and might require building jaxlib from source. See this [JAX issue](https://github.com/google/jax/issues/438) for more details. Alternatively, you can install [Windows Subsystem for Linux](https://docs.microsoft.com/en-us/windows/wsl/) and use NumPyro on it as on a Linux system. See also [CUDA on Windows Subsystem for Linux](https://developer.nvidia.com/cuda/wsl) if you want to use GPUs on Windows. +> **Limited Windows Support:** Note that NumPyro is untested on Windows, and might require building jaxlib from source. See this [JAX issue](https://github.com/google/jax/issues/438) for more details. 
Alternatively, you can install [Windows Subsystem for Linux](https://docs.microsoft.com/en-us/windows/wsl/) and use NumPyro on it as on a Linux system. See also [CUDA on Windows Subsystem for Linux](https://developer.nvidia.com/cuda/wsl) and [this forum post](https://forum.pyro.ai/t/numpyro-with-gpu-works-on-windows/2690) if you want to use GPUs on Windows. To install NumPyro with a CPU version of JAX, you can use pip: diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index b3c1ff69c..83bf83e63 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -574,6 +574,14 @@ real_vector ----------- .. autodata:: numpyro.distributions.constraints.real_vector +softplus_positive +----------------- +.. autodata:: numpyro.distributions.constraints.softplus_positive + +softplus_lower_cholesky +----------------------- +.. autodata:: numpyro.distributions.constraints.softplus_lower_cholesky + simplex ------- .. autodata:: numpyro.distributions.constraints.simplex diff --git a/docs/source/mcmc.rst b/docs/source/mcmc.rst index 337b5b792..416e0a3a6 100644 --- a/docs/source/mcmc.rst +++ b/docs/source/mcmc.rst @@ -70,6 +70,8 @@ MCMC Kernels .. autofunction:: numpyro.infer.hmc.hmc.sample_kernel +.. autofunction:: numpyro.infer.hmc_gibbs.taylor_proxy + .. autodata:: numpyro.infer.barker.BarkerMHState .. 
autodata:: numpyro.infer.hmc.HMCState diff --git a/examples/annotation.py b/examples/annotation.py index 7c9cd6ef4..e14e30eae 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -266,7 +266,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.5.0") + assert numpyro.__version__.startswith('0.6.0') parser = argparse.ArgumentParser(description="Bayesian Models of Annotation") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/baseball.py b/examples/baseball.py index 9be4206c6..45c5c47f0 100644 --- a/examples/baseball.py +++ b/examples/baseball.py @@ -196,7 +196,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.5.0') + assert numpyro.__version__.startswith('0.6.0') parser = argparse.ArgumentParser(description="Baseball batting average using MCMC") parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int) parser.add_argument("--num-warmup", nargs='?', default=1500, type=int) diff --git a/examples/bnn.py b/examples/bnn.py index 08b68407c..b72c75cfe 100644 --- a/examples/bnn.py +++ b/examples/bnn.py @@ -138,7 +138,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.5.0') + assert numpyro.__version__.startswith('0.6.0') parser = argparse.ArgumentParser(description="Bayesian neural network example") parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int) parser.add_argument("--num-warmup", nargs='?', default=1000, type=int) diff --git a/examples/covtype.py b/examples/covtype.py index 1ebe4220d..e81a62b70 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -139,7 +139,7 @@ def main(args): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.5.0') + assert numpyro.__version__.startswith('0.6.0') parser = argparse.ArgumentParser(description="parse args") 
parser.add_argument('-n', '--num-samples', default=1000, type=int, help='number of samples') parser.add_argument('--num-warmup', default=1000, type=int, help='number of warmup steps') diff --git a/examples/funnel.py b/examples/funnel.py index 16305aa7c..add3288f4 100644 --- a/examples/funnel.py +++ b/examples/funnel.py @@ -87,7 +87,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.5.0') + assert numpyro.__version__.startswith('0.6.0') parser = argparse.ArgumentParser(description="Non-centered reparameterization example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs='?', default=1000, type=int) diff --git a/examples/gp.py b/examples/gp.py index bccf22a3d..a64405d83 100644 --- a/examples/gp.py +++ b/examples/gp.py @@ -142,7 +142,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.5.0') + assert numpyro.__version__.startswith('0.6.0') parser = argparse.ArgumentParser(description="Gaussian Process example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs='?', default=1000, type=int) diff --git a/examples/hmm.py b/examples/hmm.py index c07e15fcf..7be0f19b5 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -191,7 +191,7 @@ def main(args): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.5.0') + assert numpyro.__version__.startswith('0.6.0') parser = argparse.ArgumentParser(description='Semi-supervised Hidden Markov Model') parser.add_argument('--num-categories', default=3, type=int) parser.add_argument('--num-words', default=10, type=int) diff --git a/examples/minipyro.py b/examples/minipyro.py index f59abd2c8..db5ca14e5 100644 --- a/examples/minipyro.py +++ b/examples/minipyro.py @@ -58,7 +58,7 @@ def body_fn(i, val): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.5.0') + assert 
numpyro.__version__.startswith('0.6.0') parser = argparse.ArgumentParser(description="Mini Pyro demo") parser.add_argument("-f", "--full-pyro", action="store_true", default=False) parser.add_argument("-n", "--num-steps", default=1001, type=int) diff --git a/examples/neutra.py b/examples/neutra.py index bee94edb0..848179aa2 100644 --- a/examples/neutra.py +++ b/examples/neutra.py @@ -146,7 +146,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.5.0') + assert numpyro.__version__.startswith('0.6.0') parser = argparse.ArgumentParser(description="NeuTra HMC") parser.add_argument('-n', '--num-samples', nargs='?', default=4000, type=int) parser.add_argument('--num-warmup', nargs='?', default=1000, type=int) diff --git a/examples/ode.py b/examples/ode.py index df06a84a7..a4c167a58 100644 --- a/examples/ode.py +++ b/examples/ode.py @@ -103,7 +103,7 @@ def main(args): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.5.0') + assert numpyro.__version__.startswith('0.6.0') parser = argparse.ArgumentParser(description='Predator-Prey Model') parser.add_argument('-n', '--num-samples', nargs='?', default=1000, type=int) parser.add_argument('--num-warmup', nargs='?', default=1000, type=int) diff --git a/examples/proportion_test.py b/examples/proportion_test.py index 3e6e0f5a7..359f9c17d 100644 --- a/examples/proportion_test.py +++ b/examples/proportion_test.py @@ -128,7 +128,7 @@ def main(args): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.5.0') + assert numpyro.__version__.startswith('0.6.0') parser = argparse.ArgumentParser(description='Testing whether ') parser.add_argument('-n', '--num-samples', nargs='?', default=500, type=int) parser.add_argument('--num-warmup', nargs='?', default=1500, type=int) diff --git a/examples/sparse_regression.py b/examples/sparse_regression.py index a91c27297..f95135feb 100644 --- a/examples/sparse_regression.py +++ b/examples/sparse_regression.py @@ -320,7 +320,7 
@@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.5.0') + assert numpyro.__version__.startswith('0.6.0') parser = argparse.ArgumentParser(description="Gaussian Process example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs='?', default=500, type=int) diff --git a/examples/stochastic_volatility.py b/examples/stochastic_volatility.py index fb480c325..3ce8bbe68 100644 --- a/examples/stochastic_volatility.py +++ b/examples/stochastic_volatility.py @@ -112,7 +112,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith('0.5.0') + assert numpyro.__version__.startswith('0.6.0') parser = argparse.ArgumentParser(description="Stochastic Volatility Model") parser.add_argument('-n', '--num-samples', nargs='?', default=600, type=int) parser.add_argument('--num-warmup', nargs='?', default=600, type=int) diff --git a/examples/ucbadmit.py b/examples/ucbadmit.py index a7fdff6ec..eef87f06d 100644 --- a/examples/ucbadmit.py +++ b/examples/ucbadmit.py @@ -131,7 +131,7 @@ def main(args): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.5.0') + assert numpyro.__version__.startswith('0.6.0') parser = argparse.ArgumentParser(description='UCBadmit gender discrimination using HMC') parser.add_argument('-n', '--num-samples', nargs='?', default=2000, type=int) parser.add_argument('--num-warmup', nargs='?', default=500, type=int) diff --git a/examples/vae.py b/examples/vae.py index 3935fbf3e..4d8ec0b2f 100644 --- a/examples/vae.py +++ b/examples/vae.py @@ -131,7 +131,7 @@ def reconstruct_img(epoch, rng_key): if __name__ == '__main__': - assert numpyro.__version__.startswith('0.5.0') + assert numpyro.__version__.startswith('0.6.0') parser = argparse.ArgumentParser(description="parse args") parser.add_argument('-n', '--num-epochs', default=15, type=int, help='number of training epochs') parser.add_argument('-lr', '--learning-rate', 
default=1.0e-3, type=float, help='learning rate') diff --git a/notebooks/source/bayesian_regression.ipynb b/notebooks/source/bayesian_regression.ipynb index 329fa6801..d73afefec 100644 --- a/notebooks/source/bayesian_regression.ipynb +++ b/notebooks/source/bayesian_regression.ipynb @@ -66,7 +66,7 @@ "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n", " set_matplotlib_formats('svg')\n", "\n", - "assert numpyro.__version__.startswith('0.5.0')" + "assert numpyro.__version__.startswith('0.6.0')" ] }, { diff --git a/notebooks/source/logistic_regression.ipynb b/notebooks/source/logistic_regression.ipynb index 901a0da06..a3902514c 100644 --- a/notebooks/source/logistic_regression.ipynb +++ b/notebooks/source/logistic_regression.ipynb @@ -31,7 +31,7 @@ "import numpyro.distributions as dist\n", "from numpyro.examples.datasets import COVTYPE, load_dataset\n", "from numpyro.infer import HMC, MCMC, NUTS\n", - "assert numpyro.__version__.startswith('0.5.0')\n", + "assert numpyro.__version__.startswith('0.6.0')\n", "\n", "# NB: replace gpu by cpu to run this notebook in cpu\n", "numpyro.set_platform(\"gpu\")" diff --git a/notebooks/source/ordinal_regression.ipynb b/notebooks/source/ordinal_regression.ipynb index 8b1ac5b74..14ff1a699 100644 --- a/notebooks/source/ordinal_regression.ipynb +++ b/notebooks/source/ordinal_regression.ipynb @@ -30,7 +30,7 @@ "from numpyro.infer import MCMC, NUTS\n", "import pandas as pd\n", "import seaborn as sns\n", - "assert numpyro.__version__.startswith('0.5.0')" + "assert numpyro.__version__.startswith('0.6.0')" ] }, { diff --git a/numpyro/__init__.py b/numpyro/__init__.py index 018b34cb5..91970b71b 100644 --- a/numpyro/__init__.py +++ b/numpyro/__init__.py @@ -14,7 +14,7 @@ plate_stack, prng_key, sample, - subsample, + subsample ) from numpyro.util import enable_x64, set_host_device_count, set_platform from numpyro.version import __version__ diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py index 
5714099a3..cb5ce0a0d 100644 --- a/numpyro/contrib/tfp/distributions.py +++ b/numpyro/contrib/tfp/distributions.py @@ -3,7 +3,6 @@ import numpy as np -from jax.dtypes import canonicalize_dtype import jax.numpy as jnp from tensorflow_probability.substrates.jax import bijectors as tfb from tensorflow_probability.substrates.jax import distributions as tfd @@ -162,7 +161,7 @@ class OneHotCategorical(tfd.OneHotCategorical, TFPDistributionMixin): def enumerate_support(self, expand=True): n = self.event_shape[-1] - values = jnp.identity(n, dtype=canonicalize_dtype(self.dtype)) + values = jnp.identity(n, dtype=jnp.result_type(self.dtype)) values = values.reshape((n,) + (1,) * len(self.batch_shape) + (n,)) if expand: values = jnp.broadcast_to(values, (n,) + self.batch_shape + (n,)) diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index 50790c14f..76bf5ad82 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -42,7 +42,6 @@ binomial, categorical, clamp_probs, - get_dtype, is_prng_key, lazy_property, multinomial, @@ -66,7 +65,7 @@ def _to_probs_multinom(logits): def _to_logits_multinom(probs): - minval = jnp.finfo(get_dtype(probs)).min + minval = jnp.finfo(jnp.result_type(probs)).min return jnp.clip(jnp.log(probs), a_min=minval) @@ -292,11 +291,11 @@ def logits(self): @property def mean(self): - return jnp.full(self.batch_shape, jnp.nan, dtype=get_dtype(self.probs)) + return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.probs)) @property def variance(self): - return jnp.full(self.batch_shape, jnp.nan, dtype=get_dtype(self.probs)) + return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.probs)) @constraints.dependent_property(is_discrete=True, event_dim=0) def support(self): @@ -340,11 +339,11 @@ def probs(self): @property def mean(self): - return jnp.full(self.batch_shape, jnp.nan, dtype=get_dtype(self.logits)) + return jnp.full(self.batch_shape, jnp.nan, 
dtype=jnp.result_type(self.logits)) @property def variance(self): - return jnp.full(self.batch_shape, jnp.nan, dtype=get_dtype(self.logits)) + return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.logits)) @constraints.dependent_property(is_discrete=True, event_dim=0) def support(self): @@ -609,7 +608,7 @@ def __init__(self, probs, validate_args=None): def sample(self, key, sample_shape=()): assert is_prng_key(key) probs = self.probs - dtype = get_dtype(probs) + dtype = jnp.result_type(probs) shape = sample_shape + self.batch_shape u = random.uniform(key, shape, dtype) return jnp.floor(jnp.log1p(-u) / jnp.log1p(-probs)) @@ -649,7 +648,7 @@ def probs(self): def sample(self, key, sample_shape=()): assert is_prng_key(key) logits = self.logits - dtype = get_dtype(logits) + dtype = jnp.result_type(logits) shape = sample_shape + self.batch_shape u = random.uniform(key, shape, dtype) return jnp.floor(jnp.log1p(-u) / -softplus(logits)) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 0f769b547..f71af1f58 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -8,7 +8,6 @@ import numpy as np from jax import lax, ops, tree_flatten, tree_map, vmap -from jax.dtypes import canonicalize_dtype from jax.flatten_util import ravel_pytree from jax.nn import softplus import jax.numpy as jnp @@ -16,13 +15,7 @@ from jax.scipy.special import expit, logit from numpyro.distributions import constraints -from numpyro.distributions.util import ( - get_dtype, - matrix_to_tril_vec, - signed_stick_breaking_tril, - sum_rightmost, - vec_to_tril_matrix -) +from numpyro.distributions.util import matrix_to_tril_vec, signed_stick_breaking_tril, sum_rightmost, vec_to_tril_matrix from numpyro.util import not_jax_tracer __all__ = [ @@ -51,7 +44,7 @@ def _clipped_expit(x): - finfo = jnp.finfo(get_dtype(x)) + finfo = jnp.finfo(jnp.result_type(x)) return jnp.clip(expit(x), a_min=finfo.tiny, a_max=1. 
- finfo.eps) @@ -654,7 +647,7 @@ def __call__(self, x): def _inverse(self, y): size = self.permutation.size - permutation_inv = ops.index_update(jnp.zeros(size, dtype=canonicalize_dtype(jnp.int64)), + permutation_inv = ops.index_update(jnp.zeros(size, dtype=jnp.result_type(int)), self.permutation, jnp.arange(size)) return y[..., permutation_inv] diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index 9601a77bb..79aaeab06 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -8,7 +8,6 @@ import numpy as np from jax import jit, lax, random, vmap -from jax.dtypes import canonicalize_dtype from jax.lib import xla_bridge import jax.numpy as jnp from jax.scipy.linalg import solve_triangular @@ -253,10 +252,6 @@ def promote_shapes(*args, shape=()): if len(s) < num_dims else arg for arg, s in zip(args, shapes)] -def get_dtype(x): - return canonicalize_dtype(lax.dtype(x)) - - def sum_rightmost(x, dim): """ Sum out ``dim`` many rightmost dimensions of a given tensor. @@ -351,7 +346,7 @@ def logmatmulexp(x, y): def clamp_probs(probs): - finfo = jnp.finfo(get_dtype(probs)) + finfo = jnp.finfo(jnp.result_type(probs)) return jnp.clip(probs, a_min=finfo.tiny, a_max=1. - finfo.eps) @@ -392,7 +387,7 @@ def von_mises_centered(key, concentration, shape=(), dtype=jnp.float64): :return: centered samples from von Mises """ shape = shape or jnp.shape(concentration) - dtype = canonicalize_dtype(dtype) + dtype = jnp.result_type(dtype) concentration = lax.convert_element_type(concentration, dtype) concentration = jnp.broadcast_to(concentration, shape) return _von_mises_centered(key, concentration, shape, dtype) diff --git a/numpyro/infer/__init__.py b/numpyro/infer/__init__.py index 665c43e87..66857a33a 100644 --- a/numpyro/infer/__init__.py +++ b/numpyro/infer/__init__.py @@ -1,6 +1,7 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 +from numpyro.infer.barker import BarkerMH from numpyro.infer.elbo import ELBO, RenyiELBO, Trace_ELBO, TraceMeanField_ELBO from numpyro.infer.hmc import HMC, NUTS from numpyro.infer.hmc_gibbs import HMCECS, DiscreteHMCGibbs, HMCGibbs @@ -11,7 +12,6 @@ init_to_uniform, init_to_value ) -from numpyro.infer.barker import BarkerMH from numpyro.infer.mcmc import MCMC from numpyro.infer.mixed_hmc import MixedHMC from numpyro.infer.sa import SA diff --git a/numpyro/infer/barker.py b/numpyro/infer/barker.py index 6e910971d..e4f79f91f 100644 --- a/numpyro/infer/barker.py +++ b/numpyro/infer/barker.py @@ -6,16 +6,15 @@ import jax from jax import random from jax.flatten_util import ravel_pytree +from jax.nn import softplus import jax.numpy as jnp from jax.scipy.special import expit -from jax.nn import softplus from numpyro.infer.hmc_util import warmup_adapter +from numpyro.infer.initialization import init_to_uniform +from numpyro.infer.mcmc import MCMCKernel from numpyro.infer.util import initialize_model from numpyro.util import identity -from numpyro.infer import init_to_uniform -from numpyro.infer.mcmc import MCMCKernel - BarkerMHState = namedtuple("BarkerMHState", [ "i", "z", "potential_energy", "z_grad", "accept_prob", "mean_accept_prob", "adapt_state", "rng_key"]) @@ -163,8 +162,8 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): size = len(ravel_pytree(init_params)[0]) wa_state = wa_init(None, rng_key_wa, self._step_size, mass_matrix_size=size) wa_state = wa_state._replace(rng_key=None) - init_state = BarkerMHState(jnp.array(0), init_params, pe, grad, jnp.array(0.), - jnp.array(0.), wa_state, rng_key) + init_state = BarkerMHState(jnp.array(0), init_params, pe, grad, jnp.zeros(()), + jnp.zeros(()), wa_state, rng_key) return jax.device_put(init_state) def postprocess_fn(self, args, kwargs): diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py index e81923768..985ed3720 100644 --- a/numpyro/infer/hmc.py 
+++ b/numpyro/infer/hmc.py @@ -6,7 +6,6 @@ import os from jax import device_put, lax, partial, random, vmap -from jax.dtypes import canonicalize_dtype from jax.flatten_util import ravel_pytree import jax.numpy as jnp @@ -64,7 +63,7 @@ def _get_num_steps(step_size, trajectory_length): num_steps = jnp.ceil(trajectory_length / step_size) # NB: casting to jnp.int64 does not take effect (returns jnp.int32 instead) # if jax_enable_x64 is False - return num_steps.astype(canonicalize_dtype(jnp.int64)) + return num_steps.astype(jnp.result_type(int)) def momentum_generator(prototype_r, mass_matrix_sqrt, rng_key): @@ -208,7 +207,8 @@ def init_kernel(init_params, randomness. """ - step_size = lax.convert_element_type(step_size, canonicalize_dtype(jnp.float64)) + step_size = lax.convert_element_type(step_size, jnp.result_type(float)) + trajectory_length = lax.convert_element_type(trajectory_length, jnp.result_type(float)) nonlocal wa_update, max_treedepth, vv_update, wa_steps, forward_mode_ad forward_mode_ad = forward_mode_differentiation wa_steps = num_warmup @@ -250,7 +250,7 @@ def init_kernel(init_params, energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) hmc_state = HMCState(jnp.array(0), vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, None, trajectory_length, - jnp.array(0), jnp.array(0.), jnp.array(0.), jnp.array(False), wa_state, rng_key_hmc) + jnp.array(0), jnp.zeros(()), jnp.zeros(()), jnp.array(False), wa_state, rng_key_hmc) return device_put(hmc_state) def _hmc_next(step_size, inverse_mass_matrix, vv_state, diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index e5bc9e05a..7e92fd7ab 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -478,7 +478,7 @@ class HMCECS(HMCGibbs): :param inner_kernel: One of :class:`~numpyro.infer.hmc.HMC` or :class:`~numpyro.infer.hmc.NUTS`. :param int num_blocks: Number of blocks to partition subsample into. 
- :param proxy: Either :function `~numpyro.infer.hmc_gibbs.taylor_proxy` for likelihood estimation, + :param proxy: Either :func:`~numpyro.infer.hmc_gibbs.taylor_proxy` for likelihood estimation, or, None for naive (in-between trajectory) subsampling as outlined in [4]. **Example** @@ -554,7 +554,7 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): model_kwargs["_gibbs_state"] = gibbs_state state = super().init(rng_key, num_warmup, init_params, model_args, model_kwargs) - return HMCECSState(state.z, state.hmc_state, state.rng_key, gibbs_state, jnp.array(0.)) + return HMCECSState(state.z, state.hmc_state, state.rng_key, gibbs_state, jnp.zeros(())) def sample(self, state, model_args, model_kwargs): model_kwargs = {} if model_kwargs is None else model_kwargs.copy() diff --git a/numpyro/infer/hmc_util.py b/numpyro/infer/hmc_util.py index bed612478..09370f7a2 100644 --- a/numpyro/infer/hmc_util.py +++ b/numpyro/infer/hmc_util.py @@ -12,7 +12,6 @@ from jax.tree_util import tree_flatten, tree_map, tree_multimap import numpyro.distributions as dist -from numpyro.distributions.util import get_dtype from numpyro.util import cond, identity, while_loop AdaptWindow = namedtuple('AdaptWindow', ['start', 'end']) @@ -265,7 +264,7 @@ def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator, z, _, potential_energy, z_grad = z_info if potential_energy is None or z_grad is None: potential_energy, z_grad = value_and_grad(potential_fn)(z) - finfo = jnp.finfo(get_dtype(init_step_size)) + finfo = jnp.finfo(jnp.result_type(init_step_size)) def _body_fn(state): step_size, _, direction, rng_key = state @@ -459,7 +458,7 @@ def update_fn(t, accept_prob, z_info, state): jnp.exp(log_step_size_avg), jnp.exp(log_step_size)) # account the the case log_step_size is an extreme number - finfo = jnp.finfo(get_dtype(step_size)) + finfo = jnp.finfo(jnp.result_type(step_size)) step_size = jnp.clip(step_size, a_min=finfo.tiny, a_max=finfo.max) # update mass 
matrix state diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 8535551ac..a992ee0ba 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -470,6 +470,7 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): See https://jax.readthedocs.io/en/latest/async_dispatch.html and https://jax.readthedocs.io/en/latest/profiling.html for pointers on profiling jax programs. """ + init_params = tree_map(lambda x: lax.convert_element_type(x, jnp.result_type(x)), init_params) self._args = args self._kwargs = kwargs init_state = self._get_cached_init_state(rng_key, args, kwargs) diff --git a/numpyro/infer/mixed_hmc.py b/numpyro/infer/mixed_hmc.py index 282372f28..57a173a89 100644 --- a/numpyro/infer/mixed_hmc.py +++ b/numpyro/infer/mixed_hmc.py @@ -4,7 +4,7 @@ from collections import namedtuple from functools import partial -from jax import jacfwd, grad, lax, ops, random +from jax import grad, jacfwd, lax, ops, random import jax.numpy as jnp from numpyro.infer.hmc import momentum_generator @@ -12,7 +12,6 @@ from numpyro.infer.hmc_util import euclidean_kinetic_energy, warmup_adapter from numpyro.util import cond, fori_loop, identity, ravel_pytree - MixedHMCState = namedtuple("MixedHMCState", "z, hmc_state, rng_key, accept_prob") @@ -94,7 +93,7 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): # In HMC, when `hmc_state.r` is not None, we will skip drawing a random momemtum at the # beginning of an HMC step. The reason is we need to maintain `r` between each sub-trajectories. 
r = momentum_generator(state.hmc_state.z, state.hmc_state.adapt_state.mass_matrix_sqrt, rng_r) - return MixedHMCState(state.z, state.hmc_state._replace(r=r), state.rng_key, jnp.array(0.)) + return MixedHMCState(state.z, state.hmc_state._replace(r=r), state.rng_key, jnp.zeros(())) def sample(self, state, model_args, model_kwargs): model_kwargs = {} if model_kwargs is None else model_kwargs diff --git a/numpyro/infer/sa.py b/numpyro/infer/sa.py index 1bb6eb389..4b8525289 100644 --- a/numpyro/infer/sa.py +++ b/numpyro/infer/sa.py @@ -135,7 +135,7 @@ def init_kernel(init_params, k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0])) z = unravel_fn(zs[k]) pe = pes[k] - sa_state = SAState(jnp.array(0), z, pe, jnp.array(0.), jnp.array(0.), jnp.array(False), + sa_state = SAState(jnp.array(0), z, pe, jnp.zeros(()), jnp.zeros(()), jnp.array(False), adapt_state, rng_key_sa) return device_put(sa_state) diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py index c8f1972b3..fbfe80cbd 100644 --- a/numpyro/infer/svi.py +++ b/numpyro/infer/svi.py @@ -5,7 +5,7 @@ import tqdm -from jax import jit, lax, random +from jax import jit, lax, random, tree_map import jax.numpy as jnp from numpyro.distributions import constraints @@ -116,6 +116,9 @@ def init(self, rng_key, *args, **kwargs): params[site['name']] = transform.inv(site['value']) self.constrain_fn = partial(transform_fn, inv_transforms) + # we convert weak types like float to float32/float64 + # to avoid recompiling body_fn in svi.run + params = tree_map(lambda x: lax.convert_element_type(x, jnp.result_type(x)), params) return SVIState(self.optim.init(params), rng_key) def get_params(self, svi_state): @@ -188,7 +191,7 @@ def run(self, rng_key, num_steps, *args, progress_bar=True, stable_update=False, and `losses` is the collected loss during the process. 
:rtype: SVIRunResult """ - def body_fn(svi_state, carry): + def body_fn(svi_state, _): if stable_update: svi_state, loss = self.stable_update(svi_state, *args, **kwargs) else: diff --git a/numpyro/util.py b/numpyro/util.py index 150000054..8b247e0b6 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -14,10 +14,9 @@ import jax from jax import device_put, jit, lax, ops, vmap from jax.core import Tracer -from jax.dtypes import canonicalize_dtype +from jax.experimental import host_callback import jax.numpy as jnp from jax.tree_util import tree_flatten, tree_map, tree_unflatten -from jax.experimental import host_callback _DISABLE_CONTROL_FLOW_PRIM = False @@ -325,7 +324,7 @@ def _body_fn(i, vals): def _ravel_list(*leaves): leaves_metadata = tree_map(lambda l: pytree_metadata( - jnp.ravel(l), jnp.shape(l), jnp.size(l), canonicalize_dtype(lax.dtype(l))), leaves) + jnp.ravel(l), jnp.shape(l), jnp.size(l), jnp.result_type(l)), leaves) leaves_idx = jnp.cumsum(jnp.array((0,) + tuple(d.size for d in leaves_metadata))) def unravel_list(arr): diff --git a/numpyro/version.py b/numpyro/version.py index f86a79a48..48baadee9 100644 --- a/numpyro/version.py +++ b/numpyro/version.py @@ -1,4 +1,4 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -__version__ = '0.5.0' +__version__ = '0.6.0' diff --git a/scripts/update_version.py b/scripts/update_version.py new file mode 100644 index 000000000..5cb590545 --- /dev/null +++ b/scripts/update_version.py @@ -0,0 +1,36 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import glob +import os +import re + +root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +# Get new version. +with open(os.path.join(root, "numpyro", "version.py")) as f: + for line in f: + if line.startswith("__version__ ="): + new_version = line.strip().split()[-1] + +# Collect potential files. 
+filenames = [] +for path in ["examples", "notebooks/source"]: + for ext in ["*.py", "*.ipynb"]: + filenames.extend(glob.glob(os.path.join(root, path, "**", ext), + recursive=True)) +filenames.sort() + +# Update version string. +pattern1 = re.compile("assert numpyro.__version__.startswith\\(\"[^\"]*\"\\)") +pattern2 = re.compile("assert numpyro.__version__.startswith\\('[^']*'\\)") +text = f"assert numpyro.__version__.startswith({new_version})" +for filename in filenames: + with open(filename) as f: + old_text = f.read() + new_text = pattern1.sub(text, old_text) + new_text = pattern2.sub(text, new_text) + if new_text != old_text: + print("updating {}".format(filename)) + with open(filename, "w") as f: + f.write(new_text) diff --git a/setup.py b/setup.py index 4ad1c9adf..452d2eee6 100644 --- a/setup.py +++ b/setup.py @@ -33,9 +33,9 @@ author='Uber AI Labs', install_requires=[ # TODO: pin to a specific version for the release (until JAX's API becomes stable) - 'jax>=0.2.8', + 'jax==0.2.10', # check min version here: https://github.com/google/jax/blob/master/jax/lib/__init__.py#L26 - 'jaxlib>=0.1.59', + 'jaxlib==0.1.62', 'tqdm', ], extras_require={ diff --git a/test/contrib/test_tfp.py b/test/contrib/test_tfp.py index 545068cd7..09677d25d 100644 --- a/test/contrib/test_tfp.py +++ b/test/contrib/test_tfp.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import inspect -from numpyro.distributions.transforms import AffineTransform import os from numpy.testing import assert_allclose @@ -13,6 +12,7 @@ import numpyro import numpyro.distributions as dist +from numpyro.distributions.transforms import AffineTransform from numpyro.infer import MCMC, NUTS from numpyro.infer.reparam import TransformReparam diff --git a/test/test_util.py b/test/test_util.py index 3bf38a832..2d2bf75d2 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -4,8 +4,6 @@ from numpy.testing import assert_allclose import pytest -from jax import lax -from jax.dtypes import canonicalize_dtype 
import jax.numpy as jnp from jax.test_util import check_eq from jax.tree_util import tree_flatten, tree_multimap @@ -73,7 +71,7 @@ def test_ravel_pytree(pytree): unravel = unravel_fn(flat) tree_flatten(tree_multimap(lambda x, y: assert_allclose(x, y), unravel, pytree)) assert all(tree_flatten(tree_multimap(lambda x, y: - canonicalize_dtype(lax.dtype(x)) == canonicalize_dtype(lax.dtype(y)), + jnp.result_type(x) == jnp.result_type(y), unravel, pytree))[0])