diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bcd250abe..f8a55e98c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,10 +26,10 @@ jobs: run: | sudo apt install -y pandoc gsfonts python -m pip install --upgrade pip - pip install https://github.com/pyro-ppl/funsor/archive/master.zip pip install jaxlib pip install jax pip install .[doc,test] + pip install https://github.com/pyro-ppl/funsor/archive/master.zip pip install -r docs/requirements.txt pip freeze - name: Lint with flake8 @@ -64,10 +64,10 @@ jobs: python -m pip install --upgrade pip # Keep track of pyro-api master branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install https://github.com/pyro-ppl/funsor/archive/master.zip pip install jaxlib pip install jax pip install .[dev,test] + pip install https://github.com/pyro-ppl/funsor/archive/master.zip pip freeze - name: Test with pytest run: | @@ -93,10 +93,10 @@ jobs: python -m pip install --upgrade pip # Keep track of pyro-api master branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install https://github.com/pyro-ppl/funsor/archive/master.zip pip install jaxlib pip install jax pip install .[dev,test] + pip install https://github.com/pyro-ppl/funsor/archive/master.zip pip freeze - name: Test with pytest run: | @@ -129,10 +129,10 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install https://github.com/pyro-ppl/funsor/archive/master.zip pip install jaxlib pip install jax pip install .[dev,examples,test] + pip install https://github.com/pyro-ppl/funsor/archive/master.zip pip freeze - name: Test with pytest run: | diff --git a/docs/source/optimizers.rst b/docs/source/optimizers.rst index 68e2bb0ab..45d792256 100644 --- a/docs/source/optimizers.rst +++ b/docs/source/optimizers.rst @@ -68,4 +68,4 @@ SM3 Optax support ------------- -.. autofunction:: numpyro.contrib.optim.optax_to_numpyro +.. autofunction:: numpyro.optim.optax_to_numpyro diff --git a/examples/annotation.py b/examples/annotation.py index 6dad619db..b8a900bb3 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -351,7 +351,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="Bayesian Models of Annotation") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/ar2.py b/examples/ar2.py index 0e61d46fb..a99832a11 100644 --- a/examples/ar2.py +++ b/examples/ar2.py @@ -116,7 +116,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="AR2 example") parser.add_argument("--num-data", nargs="?", default=142, type=int) parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) diff --git a/examples/baseball.py b/examples/baseball.py index cc5768e4c..e9a36899e 100644 --- a/examples/baseball.py +++ b/examples/baseball.py @@ -210,7 +210,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="Baseball batting average using MCMC") parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1500, type=int) diff --git a/examples/bnn.py b/examples/bnn.py index b67aefb3d..63b39c34c 100644 --- a/examples/bnn.py +++ b/examples/bnn.py @@ -160,7 +160,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="Bayesian neural network example") parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/capture_recapture.py b/examples/capture_recapture.py index 6d4382500..87d78e214 100644 --- a/examples/capture_recapture.py +++ b/examples/capture_recapture.py @@ -72,7 +72,11 @@ def transition_fn(carry, y): with handlers.mask(mask=first_capture_mask): mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask) # NumPyro exactly sums out the discrete states z_t. - z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t))) + z = numpyro.sample( + "z", + dist.Bernoulli(dist.util.clamp_probs(mu_z_t)), + infer={"enumerate": "parallel"}, + ) mu_y_t = rho * z numpyro.sample( "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y @@ -112,7 +116,11 @@ def transition_fn(carry, y): with handlers.mask(mask=first_capture_mask): mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask) # NumPyro exactly sums out the discrete states z_t. - z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t))) + z = numpyro.sample( + "z", + dist.Bernoulli(dist.util.clamp_probs(mu_z_t)), + infer={"enumerate": "parallel"}, + ) mu_y_t = rho * z numpyro.sample( "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y @@ -160,7 +168,11 @@ def transition_fn(carry, y): with handlers.mask(mask=first_capture_mask): mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask) # NumPyro exactly sums out the discrete states z_t. - z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t))) + z = numpyro.sample( + "z", + dist.Bernoulli(dist.util.clamp_probs(mu_z_t)), + infer={"enumerate": "parallel"}, + ) mu_y_t = rho * z numpyro.sample( "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y @@ -202,7 +214,11 @@ def transition_fn(carry, y): with handlers.mask(mask=first_capture_mask): mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask) # NumPyro exactly sums out the discrete states z_t. - z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t))) + z = numpyro.sample( + "z", + dist.Bernoulli(dist.util.clamp_probs(mu_z_t)), + infer={"enumerate": "parallel"}, + ) mu_y_t = rho * z numpyro.sample( "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y @@ -249,7 +265,11 @@ def transition_fn(carry, y): with handlers.mask(mask=first_capture_mask): mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask) # NumPyro exactly sums out the discrete states z_t. - z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t))) + z = numpyro.sample( + "z", + dist.Bernoulli(dist.util.clamp_probs(mu_z_t)), + infer={"enumerate": "parallel"}, + ) mu_y_t = rho * z numpyro.sample( "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y diff --git a/examples/covtype.py b/examples/covtype.py index db8d8dab1..e79d40d2c 100644 --- a/examples/covtype.py +++ b/examples/covtype.py @@ -206,7 +206,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="parse args") parser.add_argument( "-n", "--num-samples", default=1000, type=int, help="number of samples" diff --git a/examples/funnel.py b/examples/funnel.py index b350964c9..00fda3990 100644 --- a/examples/funnel.py +++ b/examples/funnel.py @@ -139,7 +139,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser( description="Non-centered reparameterization example" ) diff --git a/examples/gaussian_shells.py b/examples/gaussian_shells.py index 973412a53..0155e7e4f 100644 --- a/examples/gaussian_shells.py +++ b/examples/gaussian_shells.py @@ -81,7 +81,7 @@ def run_inference(args, data): num_warmup=args.num_warmup, num_samples=args.num_samples, ) - mcmc.run(random.PRNGKey(2), **data) + mcmc.run(random.PRNGKey(2), **data, enum=args.enum) mcmc.print_summary() mcmc_samples = mcmc.get_samples() @@ -123,7 +123,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="Nested sampler for Gaussian shells") parser.add_argument("-n", "--num-samples", nargs="?", default=10000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/gp.py b/examples/gp.py index fbd3e6e74..731cf1107 100644 --- a/examples/gp.py +++ b/examples/gp.py @@ -170,7 +170,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="Gaussian Process example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/hmm.py b/examples/hmm.py index b0490b001..472a0d70c 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -263,7 +263,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="Semi-supervised Hidden Markov Model") parser.add_argument("--num-categories", default=3, type=int) parser.add_argument("--num-words", default=10, type=int) diff --git a/examples/hmm_enum.py b/examples/hmm_enum.py index 62c6f39bb..6ec0e3f69 100644 --- a/examples/hmm_enum.py +++ b/examples/hmm_enum.py @@ -94,7 +94,11 @@ def transition_fn(carry, y): x_prev, t = carry with numpyro.plate("sequences", num_sequences, dim=-2): with mask(mask=(t < lengths)[..., None]): - x = numpyro.sample("x", dist.Categorical(probs_x[x_prev])) + x = numpyro.sample( + "x", + dist.Categorical(probs_x[x_prev]), + infer={"enumerate": "parallel"}, + ) with numpyro.plate("tones", data_dim, dim=-1): numpyro.sample("y", dist.Bernoulli(probs_y[x.squeeze(-1)]), obs=y) return (x, t + 1), None @@ -127,7 +131,11 @@ def transition_fn(carry, y): x_prev, y_prev, t = carry with numpyro.plate("sequences", num_sequences, dim=-2): with mask(mask=(t < lengths)[..., None]): - x = numpyro.sample("x", dist.Categorical(probs_x[x_prev])) + x = numpyro.sample( + "x", + dist.Categorical(probs_x[x_prev]), + infer={"enumerate": "parallel"}, + ) # Note the broadcasting tricks here: to index probs_y on tensors x and y, # we also need a final tensor for the tones dimension. This is conveniently # provided by the plate associated with that dimension. @@ -175,8 +183,16 @@ def transition_fn(carry, y): w_prev, x_prev, t = carry with numpyro.plate("sequences", num_sequences, dim=-2): with mask(mask=(t < lengths)[..., None]): - w = numpyro.sample("w", dist.Categorical(probs_w[w_prev])) - x = numpyro.sample("x", dist.Categorical(probs_x[x_prev])) + w = numpyro.sample( + "w", + dist.Categorical(probs_w[w_prev]), + infer={"enumerate": "parallel"}, + ) + x = numpyro.sample( + "x", + dist.Categorical(probs_x[x_prev]), + infer={"enumerate": "parallel"}, + ) # Note the broadcasting tricks here: to index probs_y on tensors x and y, # we also need a final tensor for the tones dimension. This is conveniently # provided by the plate associated with that dimension. @@ -224,8 +240,16 @@ def transition_fn(carry, y): w_prev, x_prev, t = carry with numpyro.plate("sequences", num_sequences, dim=-2): with mask(mask=(t < lengths)[..., None]): - w = numpyro.sample("w", dist.Categorical(probs_w[w_prev])) - x = numpyro.sample("x", dist.Categorical(Vindex(probs_x)[w, x_prev])) + w = numpyro.sample( + "w", + dist.Categorical(probs_w[w_prev]), + infer={"enumerate": "parallel"}, + ) + x = numpyro.sample( + "x", + dist.Categorical(Vindex(probs_x)[w, x_prev]), + infer={"enumerate": "parallel"}, + ) with numpyro.plate("tones", data_dim, dim=-1) as tones: numpyro.sample("y", dist.Bernoulli(probs_y[w, x, tones]), obs=y) return (w, x, t + 1), None @@ -275,7 +299,7 @@ def transition_fn(carry, y): with mask(mask=(t < lengths)[..., None]): probs_x_t = Vindex(probs_x)[x_prev, x_curr] x_prev, x_curr = x_curr, numpyro.sample( - "x", dist.Categorical(probs_x_t) + "x", dist.Categorical(probs_x_t), infer={"enumerate": "parallel"} ) with numpyro.plate("tones", data_dim, dim=-1): probs_y_t = probs_y[x_curr.squeeze(-1)] diff --git a/examples/holt_winters.py b/examples/holt_winters.py index 49f78eebd..ddc8ea2db 100644 --- a/examples/holt_winters.py +++ b/examples/holt_winters.py @@ -180,7 +180,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="Holt-Winters") parser.add_argument("--T", nargs="?", default=6, type=int) parser.add_argument("--future", nargs="?", default=1, type=int) diff --git a/examples/horseshoe_regression.py b/examples/horseshoe_regression.py index 3b5daa581..8bc8852d5 100644 --- a/examples/horseshoe_regression.py +++ b/examples/horseshoe_regression.py @@ -162,7 +162,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="Horseshoe regression example") parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/hsgp.py b/examples/hsgp.py index cf8454010..bda4447fe 100644 --- a/examples/hsgp.py +++ b/examples/hsgp.py @@ -108,7 +108,7 @@ def get_thanksgiving_days(dates): def get_floating_days_indicators(dates): def encode(x): - return jnp.array(x.values, dtype=jnp.int64) + return jnp.array(x.values, dtype=jnp.result_type(int)) return { "labour_days_indicator": encode(get_labour_days(dates)), diff --git a/examples/minipyro.py b/examples/minipyro.py index 7ec79aec6..41bcf802d 100644 --- a/examples/minipyro.py +++ b/examples/minipyro.py @@ -58,7 +58,7 @@ def body_fn(i, val): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="Mini Pyro demo") parser.add_argument("-f", "--full-pyro", action="store_true", default=False) parser.add_argument("-n", "--num-steps", default=1001, type=int) diff --git a/examples/neutra.py b/examples/neutra.py index ba96e410d..af50e54ad 100644 --- a/examples/neutra.py +++ b/examples/neutra.py @@ -197,7 +197,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="NeuTra HMC") parser.add_argument("-n", "--num-samples", nargs="?", default=4000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/ode.py b/examples/ode.py index 46b40d2d9..252567762 100644 --- a/examples/ode.py +++ b/examples/ode.py @@ -117,7 +117,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="Predator-Prey Model") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=1000, type=int) diff --git a/examples/prodlda.py b/examples/prodlda.py index f232e74a2..24ad34eb8 100644 --- a/examples/prodlda.py +++ b/examples/prodlda.py @@ -314,7 +314,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser( description="Probabilistic topic modelling with Flax and Haiku" ) diff --git a/examples/proportion_test.py b/examples/proportion_test.py index 305b757da..f5e0fdf35 100644 --- a/examples/proportion_test.py +++ b/examples/proportion_test.py @@ -160,7 +160,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="Testing whether ") parser.add_argument("-n", "--num-samples", nargs="?", default=500, type=int) parser.add_argument("--num-warmup", nargs="?", default=1500, type=int) diff --git a/examples/sparse_regression.py b/examples/sparse_regression.py index 3317365a1..ab735e6a1 100644 --- a/examples/sparse_regression.py +++ b/examples/sparse_regression.py @@ -384,7 +384,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="Gaussian Process example") parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int) parser.add_argument("--num-warmup", nargs="?", default=500, type=int) diff --git a/examples/stochastic_volatility.py b/examples/stochastic_volatility.py index 4e68a8377..058750b03 100644 --- a/examples/stochastic_volatility.py +++ b/examples/stochastic_volatility.py @@ -122,7 +122,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="Stochastic Volatility Model") parser.add_argument("-n", "--num-samples", nargs="?", default=600, type=int) parser.add_argument("--num-warmup", nargs="?", default=600, type=int) diff --git a/examples/thompson_sampling.py b/examples/thompson_sampling.py index 7389e2a1c..16c972a12 100644 --- a/examples/thompson_sampling.py +++ b/examples/thompson_sampling.py @@ -294,7 +294,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="Thompson sampling example") parser.add_argument( "--num-random", nargs="?", default=2, type=int, help="number of random draws" diff --git a/examples/ucbadmit.py b/examples/ucbadmit.py index d575bde74..cf4094781 100644 --- a/examples/ucbadmit.py +++ b/examples/ucbadmit.py @@ -151,7 +151,7 @@ def main(args): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser( description="UCBadmit gender discrimination using HMC" ) diff --git a/examples/vae.py b/examples/vae.py index 27c92c82b..3d37e2a60 100644 --- a/examples/vae.py +++ b/examples/vae.py @@ -160,7 +160,7 @@ def reconstruct_img(epoch, rng_key): if __name__ == "__main__": - assert numpyro.__version__.startswith("0.8.0") + assert numpyro.__version__.startswith("0.9.0") parser = argparse.ArgumentParser(description="parse args") parser.add_argument( "-n", "--num-epochs", default=15, type=int, help="number of training epochs" diff --git a/notebooks/source/bad_posterior_geometry.ipynb b/notebooks/source/bad_posterior_geometry.ipynb index 4c970988b..0594083e4 100644 --- a/notebooks/source/bad_posterior_geometry.ipynb +++ b/notebooks/source/bad_posterior_geometry.ipynb @@ -50,7 +50,7 @@ "\n", "from numpyro.infer import MCMC, NUTS\n", "\n", - "assert numpyro.__version__.startswith(\"0.8.0\")\n", + "assert numpyro.__version__.startswith(\"0.9.0\")\n", "\n", "# NB: replace cpu by gpu to run this notebook on gpu\n", "numpyro.set_platform(\"cpu\")" diff --git a/notebooks/source/bayesian_hierarchical_linear_regression.ipynb b/notebooks/source/bayesian_hierarchical_linear_regression.ipynb index cbb52e0b1..afa7a5970 100644 --- a/notebooks/source/bayesian_hierarchical_linear_regression.ipynb +++ b/notebooks/source/bayesian_hierarchical_linear_regression.ipynb @@ -244,7 +244,7 @@ "import numpyro.distributions as dist\n", "from jax import random\n", "\n", - "assert numpyro.__version__.startswith(\"0.8.0\")" + "assert numpyro.__version__.startswith(\"0.9.0\")" ] }, { diff --git a/notebooks/source/bayesian_hierarchical_stacking.ipynb b/notebooks/source/bayesian_hierarchical_stacking.ipynb index 25bf9153f..c3da0b956 100644 --- a/notebooks/source/bayesian_hierarchical_stacking.ipynb +++ b/notebooks/source/bayesian_hierarchical_stacking.ipynb @@ -96,7 +96,7 @@ " set_matplotlib_formats(\"svg\")\n", "\n", "numpyro.set_host_device_count(4)\n", - "assert numpyro.__version__.startswith(\"0.8.0\")" + "assert numpyro.__version__.startswith(\"0.9.0\")" ] }, { diff --git a/notebooks/source/bayesian_imputation.ipynb b/notebooks/source/bayesian_imputation.ipynb index 718b2749a..0a6971b03 100644 --- a/notebooks/source/bayesian_imputation.ipynb +++ b/notebooks/source/bayesian_imputation.ipynb @@ -55,7 +55,7 @@ "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n", " set_matplotlib_formats(\"svg\")\n", "\n", - "assert numpyro.__version__.startswith(\"0.8.0\")" + "assert numpyro.__version__.startswith(\"0.9.0\")" ] }, { diff --git a/notebooks/source/bayesian_regression.ipynb b/notebooks/source/bayesian_regression.ipynb index 4aedf2d0d..147c51333 100644 --- a/notebooks/source/bayesian_regression.ipynb +++ b/notebooks/source/bayesian_regression.ipynb @@ -95,7 +95,7 @@ "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n", " set_matplotlib_formats(\"svg\")\n", "\n", - "assert numpyro.__version__.startswith(\"0.8.0\")" + "assert numpyro.__version__.startswith(\"0.9.0\")" ], "execution_count": 2, "outputs": [] diff --git a/notebooks/source/logistic_regression.ipynb b/notebooks/source/logistic_regression.ipynb index 8bef90e16..872d31765 100644 --- a/notebooks/source/logistic_regression.ipynb +++ b/notebooks/source/logistic_regression.ipynb @@ -41,7 +41,7 @@ "from numpyro.examples.datasets import COVTYPE, load_dataset\n", "from numpyro.infer import HMC, MCMC, NUTS\n", "\n", - "assert numpyro.__version__.startswith(\"0.8.0\")\n", + "assert numpyro.__version__.startswith(\"0.9.0\")\n", "\n", "# NB: replace gpu by cpu to run this notebook in cpu\n", "numpyro.set_platform(\"gpu\")" diff --git a/notebooks/source/model_rendering.ipynb b/notebooks/source/model_rendering.ipynb index d74a1d365..2ecf75b01 100644 --- a/notebooks/source/model_rendering.ipynb +++ b/notebooks/source/model_rendering.ipynb @@ -35,7 +35,7 @@ "import numpyro\n", "import numpyro.distributions as dist\n", "\n", - "assert numpyro.__version__.startswith(\"0.8.0\")" + "assert numpyro.__version__.startswith(\"0.9.0\")" ] }, { diff --git a/notebooks/source/ordinal_regression.ipynb b/notebooks/source/ordinal_regression.ipynb index be2d2aecb..a3137924e 100644 --- a/notebooks/source/ordinal_regression.ipynb +++ b/notebooks/source/ordinal_regression.ipynb @@ -53,7 +53,7 @@ "import pandas as pd\n", "import seaborn as sns\n", "\n", - "assert numpyro.__version__.startswith(\"0.8.0\")" + "assert numpyro.__version__.startswith(\"0.9.0\")" ] }, { diff --git a/notebooks/source/time_series_forecasting.ipynb b/notebooks/source/time_series_forecasting.ipynb index 670713b5a..3e67a19cc 100644 --- a/notebooks/source/time_series_forecasting.ipynb +++ b/notebooks/source/time_series_forecasting.ipynb @@ -48,7 +48,7 @@ " set_matplotlib_formats(\"svg\")\n", "\n", "numpyro.set_host_device_count(4)\n", - "assert numpyro.__version__.startswith(\"0.8.0\")" + "assert numpyro.__version__.startswith(\"0.9.0\")" ] }, { diff --git a/numpyro/contrib/optim.py b/numpyro/contrib/optim.py deleted file mode 100644 index 575d6b696..000000000 --- a/numpyro/contrib/optim.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright Contributors to the Pyro project. -# SPDX-License-Identifier: Apache-2.0 - -""" -This module provides a wrapper for Optax optimizers so that they can be used with -NumPyro inference algorithms. -""" - -from typing import Tuple, TypeVar - -import optax - -from numpyro.optim import _NumPyroOptim - -_Params = TypeVar("_Params") -_State = Tuple[_Params, optax.OptState] - - -def optax_to_numpyro(transformation: optax.GradientTransformation) -> _NumPyroOptim: - """ - This function produces a ``numpyro.optim._NumPyroOptim`` instance from an - ``optax.GradientTransformation`` so that it can be used with - ``numpyro.infer.svi.SVI``. It is a lightweight wrapper that recreates the - ``(init_fn, update_fn, get_params_fn)`` interface defined by - :mod:`jax.example_libraries.optimizers`. - - :param transformation: An ``optax.GradientTransformation`` instance to wrap. - :return: An instance of ``numpyro.optim._NumPyroOptim`` wrapping the supplied - Optax optimizer. - """ - - def init_fn(params: _Params) -> _State: - opt_state = transformation.init(params) - return params, opt_state - - def update_fn(step, grads: _Params, state: _State) -> _State: - params, opt_state = state - updates, opt_state = transformation.update(grads, opt_state, params) - updated_params = optax.apply_updates(params, updates) - return updated_params, opt_state - - def get_params_fn(state: _State) -> _Params: - params, _ = state - return params - - return _NumPyroOptim(lambda x, y, z: (x, y, z), init_fn, update_fn, get_params_fn) diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py index 2238d549b..6702d9079 100644 --- a/numpyro/infer/svi.py +++ b/numpyro/infer/svi.py @@ -23,7 +23,7 @@ from numpyro.distributions.transforms import biject_to from numpyro.handlers import replay, seed, trace from numpyro.infer.util import helpful_support_errors, transform_fn -from numpyro.optim import _NumPyroOptim +from numpyro.optim import _NumPyroOptim, optax_to_numpyro SVIState = namedtuple("SVIState", ["optim_state", "mutable_state", "rng_key"]) """ @@ -120,7 +120,7 @@ class SVI(object): :param optim: An instance of :class:`~numpyro.optim._NumpyroOptim`, a ``jax.example_libraries.optimizers.Optimizer`` or an Optax ``GradientTransformation``. If you pass an Optax optimizer it will - automatically be wrapped using :func:`numpyro.contrib.optim.optax_to_numpyro`. + automatically be wrapped using :func:`numpyro.optim.optax_to_numpyro`. >>> from optax import adam, chain, clip >>> svi = SVI(model, guide, chain(clip(10.0), adam(1e-3)), loss=Trace_ELBO()) @@ -145,8 +145,6 @@ def __init__(self, model, guide, optim, loss, **static_kwargs): else: try: import optax - - from numpyro.contrib.optim import optax_to_numpyro except ImportError: raise ImportError( "It looks like you tried to use an optimizer that isn't an " diff --git a/numpyro/optim.py b/numpyro/optim.py index 55816e4c1..b7bd54efd 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -291,3 +291,34 @@ def loss_fn(x): flat_params, out = results.x, results.fun state = (i + 1, _MinimizeState(flat_params, unravel_fn)) return (out, None), state + + +def optax_to_numpyro(transformation) -> _NumPyroOptim: + """ + This function produces a ``numpyro.optim._NumPyroOptim`` instance from an + ``optax.GradientTransformation`` so that it can be used with + ``numpyro.infer.svi.SVI``. It is a lightweight wrapper that recreates the + ``(init_fn, update_fn, get_params_fn)`` interface defined by + :mod:`jax.example_libraries.optimizers`. + + :param transformation: An ``optax.GradientTransformation`` instance to wrap. + :return: An instance of ``numpyro.optim._NumPyroOptim`` wrapping the supplied + Optax optimizer. + """ + import optax + + def init_fn(params): + opt_state = transformation.init(params) + return params, opt_state + + def update_fn(step, grads, state): + params, opt_state = state + updates, opt_state = transformation.update(grads, opt_state, params) + updated_params = optax.apply_updates(params, updates) + return updated_params, opt_state + + def get_params_fn(state): + params, _ = state + return params + + return _NumPyroOptim(lambda x, y, z: (x, y, z), init_fn, update_fn, get_params_fn) diff --git a/numpyro/version.py b/numpyro/version.py index d514afbbf..f43d7c762 100644 --- a/numpyro/version.py +++ b/numpyro/version.py @@ -1,4 +1,4 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -__version__ = "0.8.0" +__version__ = "0.9.0" diff --git a/setup.py b/setup.py index 21bc8bb07..03f32abe6 100644 --- a/setup.py +++ b/setup.py @@ -59,7 +59,7 @@ "dev": [ "dm-haiku", "flax", - "funsor==0.4.1", + "funsor>=0.4.1", "graphviz", "jaxns==0.0.7", "optax>=0.0.6", diff --git a/test/contrib/test_optim.py b/test/contrib/test_optim.py deleted file mode 100644 index 7fb094dcd..000000000 --- a/test/contrib/test_optim.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright Contributors to the Pyro project. -# SPDX-License-Identifier: Apache-2.0 - -from functools import partial - -from numpy.testing import assert_allclose -import pytest - -from jax import grad, jit, random -from jax.lax import fori_loop -import jax.numpy as jnp -from jax.test_util import check_close - -import numpyro -import numpyro.distributions as dist -from numpyro.distributions import constraints -from numpyro.infer import SVI, RenyiELBO, Trace_ELBO - -try: - import optax - - from numpyro.contrib.optim import optax_to_numpyro - - # the optimizer test is parameterized by different optax optimizers, but we have - # to define them here to ensure that `optax` is defined. pytest.mark.parameterize - # decorators are run even if tests are skipped at the top of the file. - optimizers = [ - (optax.adam, (1e-2,), {}), - # clipped adam - (optax.chain, (optax.clip(10.0), optax.adam(1e-2)), {}), - (optax.adagrad, (1e-1,), {}), - # SGD with momentum - (optax.sgd, (1e-2,), {"momentum": 0.9}), - (optax.rmsprop, (1e-2,), {"decay": 0.95}), - # RMSProp with momentum - (optax.rmsprop, (1e-4,), {"decay": 0.9, "momentum": 0.9}), - (optax.sgd, (1e-2,), {}), - ] -except ImportError: - pytestmark = pytest.mark.skip(reason="optax is not installed") - optimizers = [] - - -def loss(params): - return jnp.sum(params["x"] ** 2 + params["y"] ** 2) - - -@partial(jit, static_argnums=(1,)) -def step(opt_state, optim): - params = optim.get_params(opt_state) - g = grad(loss)(params) - return optim.update(g, opt_state) - - -@pytest.mark.parametrize("optim_class, args, kwargs", optimizers) -def test_optim_multi_params(optim_class, args, kwargs): - params = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([-1, -1.0, -1.0])} - opt = optax_to_numpyro(optim_class(*args, **kwargs)) - opt_state = opt.init(params) - for i in range(2000): - opt_state = step(opt_state, opt) - for _, param in opt.get_params(opt_state).items(): - assert jnp.allclose(param, jnp.zeros(3)) - - -@pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)]) -def test_beta_bernoulli(elbo): - data = jnp.array([1.0] * 8 + [0.0] * 2) - - def model(data): - f = numpyro.sample("beta", dist.Beta(1.0, 1.0)) - with numpyro.plate("N", len(data)): - numpyro.sample("obs", dist.Bernoulli(f), obs=data) - - def guide(data): - alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive) - beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive) - numpyro.sample("beta", dist.Beta(alpha_q, beta_q)) - - adam = optax.adam(0.05) - svi = SVI(model, guide, adam, elbo) - svi_state = svi.init(random.PRNGKey(1), data) - assert_allclose(svi.optim.get_params(svi_state.optim_state)["alpha_q"], 0.0) - - def body_fn(i, val): - svi_state, _ = svi.update(val, data) - return svi_state - - svi_state = fori_loop(0, 2000, body_fn, svi_state) - params = svi.get_params(svi_state) - assert_allclose( - params["alpha_q"] / (params["alpha_q"] + params["beta_q"]), - 0.8, - atol=0.05, - rtol=0.05, - ) - - -def test_jitted_update_fn(): - data = jnp.array([1.0] * 8 + [0.0] * 2) - - def model(data): - f = numpyro.sample("beta", dist.Beta(1.0, 1.0)) - with numpyro.plate("N", len(data)): - numpyro.sample("obs", dist.Bernoulli(f), obs=data) - - def guide(data): - alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive) - beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive) - numpyro.sample("beta", dist.Beta(alpha_q, beta_q)) - - adam = optax.adam(0.05) - svi = SVI(model, guide, adam, Trace_ELBO()) - svi_state = svi.init(random.PRNGKey(1), data) - expected = svi.get_params(svi.update(svi_state, data)[0]) - - actual = svi.get_params(jit(svi.update)(svi_state, data=data)[0]) - check_close(actual, expected, atol=1e-5) diff --git a/test/test_optimizers.py b/test/test_optimizers.py index cee01e415..f28885759 100644 --- a/test/test_optimizers.py +++ b/test/test_optimizers.py @@ -10,6 +10,28 @@ from numpyro import optim +try: + import optax + + # the optimizer test is parameterized by different optax optimizers, but we have + # to define them here to ensure that `optax` is defined. pytest.mark.parameterize + # decorators are run even if tests are skipped at the top of the file. + optax_optimizers = [ + (optax.adam, (1e-2,), {}), + # clipped adam + (optax.chain, (optax.clip(10.0), optax.adam(1e-2)), {}), + (optax.adagrad, (1e-1,), {}), + # SGD with momentum + (optax.sgd, (1e-2,), {"momentum": 0.9}), + (optax.rmsprop, (1e-2,), {"decay": 0.95}), + # RMSProp with momentum + (optax.rmsprop, (1e-4,), {"decay": 0.9, "momentum": 0.9}), + (optax.sgd, (1e-2,), {}), + ] +except ImportError: + pytestmark = pytest.mark.skip(reason="optax is not installed") + optax_optimizers = [] + def loss(params): return jnp.sum(params["x"] ** 2 + params["y"] ** 2) @@ -23,20 +45,23 @@ def step(opt_state, optim): @pytest.mark.parametrize( - "optim_class, args", + "optim_class, args, kwargs", [ - (optim.Adam, (1e-2,)), - (optim.ClippedAdam, (1e-2,)), - (optim.Adagrad, (1e-1,)), - (optim.Momentum, (1e-2, 0.5)), - (optim.RMSProp, (1e-2, 0.95)), - (optim.RMSPropMomentum, (1e-4,)), - (optim.SGD, (1e-2,)), - ], + (optim.Adam, (1e-2,), {}), + (optim.ClippedAdam, (1e-2,), {}), + (optim.Adagrad, (1e-1,), {}), + (optim.Momentum, (1e-2, 0.5), {}), + (optim.RMSProp, (1e-2, 0.95), {}), + (optim.RMSPropMomentum, (1e-4,), {}), + (optim.SGD, (1e-2,), {}), + ] + + optax_optimizers, ) -def test_optim_multi_params(optim_class, args): +def test_optim_multi_params(optim_class, args, kwargs): params = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([-1, -1.0, -1.0])} - opt = optim_class(*args) + opt = optim_class(*args, **kwargs) + if not isinstance(opt, optim._NumPyroOptim): + opt = optim.optax_to_numpyro(opt) opt_state = opt.init(params) for i in range(2000): opt_state = step(opt_state, opt) @@ -47,20 +72,23 @@ def test_optim_multi_params(optim_class, args): # note: this is somewhat of a bruteforce test. testing directly from # _NumpyroOptim would probably be better @pytest.mark.parametrize( - "optim_class, args", + "optim_class, args, kwargs", [ - (optim.Adam, (1e-2,)), - (optim.ClippedAdam, (1e-2,)), - (optim.Adagrad, (1e-1,)), - (optim.Momentum, (1e-2, 0.5)), - (optim.RMSProp, (1e-2, 0.95)), - (optim.RMSPropMomentum, (1e-4,)), - (optim.SGD, (1e-2,)), - ], + (optim.Adam, (1e-2,), {}), + (optim.ClippedAdam, (1e-2,), {}), + (optim.Adagrad, (1e-1,), {}), + (optim.Momentum, (1e-2, 0.5), {}), + (optim.RMSProp, (1e-2, 0.95), {}), + (optim.RMSPropMomentum, (1e-4,), {}), + (optim.SGD, (1e-2,), {}), + ] + + optax_optimizers, ) -def test_numpyrooptim_no_double_jit(optim_class, args): +def test_numpyrooptim_no_double_jit(optim_class, args, kwargs): - opt = optim_class(*args) + opt = optim_class(*args, **kwargs) + if not isinstance(opt, optim._NumPyroOptim): + opt = optim.optax_to_numpyro(opt) state = opt.init(jnp.zeros(10)) my_fn_calls = 0