diff --git a/blackjax/types.py b/blackjax/types.py
index db6e7c76f..5a3b59f07 100644
--- a/blackjax/types.py
+++ b/blackjax/types.py
@@ -13,7 +13,6 @@
 from typing import Any, Iterable, Mapping, Union

 import jax
-from jax import Array
 from jax.typing import ArrayLike

 """
@@ -36,10 +35,11 @@ class WelfordAlgorithmState(NamedTuple):
 (until we introduce shape annotation).
 """
 #: JAX PyTrees
-ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[Any, "ArrayTree"]]
+Array = jax.Array
+ArrayTree = Union[jax.Array, Iterable["ArrayTree"], Mapping[Any, "ArrayTree"]]
 ArrayLikeTree = Union[
     ArrayLike, Iterable["ArrayLikeTree"], Mapping[Any, "ArrayLikeTree"]
 ]

 #: JAX PRNGKey
-PRNGKey = jax.random.PRNGKeyArray
+PRNGKey = jax.Array
diff --git a/pyproject.toml b/pyproject.toml
index ae630a9ab..0739361e2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,11 +31,11 @@ classifiers = [
     "Topic :: Scientific/Engineering :: Mathematics",
 ]
 dependencies = [
-    "fastprogress>=0.2.0",
-    "jax>=0.3.13",
-    "jaxlib>=0.3.10",
-    "jaxopt>=0.5.5",
-    "optax",
+    "fastprogress>=1.0.0",
+    "jax>=0.4.16",
+    "jaxlib>=0.4.16",
+    "jaxopt>=0.8",
+    "optax>=0.1.7",
     "typing-extensions>=4.4.0",
 ]
 dynamic = ["version"]
diff --git a/requirements-doc.txt b/requirements-doc.txt
index e5561ed08..34e1ad6ac 100644
--- a/requirements-doc.txt
+++ b/requirements-doc.txt
@@ -4,8 +4,8 @@ aesara>=2.8.8
 arviz
 flax
 ipython
-jax
-jaxlib
+jax>=0.4.16
+jaxlib>=0.4.16
 jaxopt
 jupytext
 myst_nb
diff --git a/requirements.txt b/requirements.txt
index f67ab4d8e..4cdf22942 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 -e ./
-chex
+chex>=0.1.83
 pre-commit
 pytest
 pytest-benchmark
diff --git a/tests/adaptation/test_step_size.py b/tests/adaptation/test_step_size.py
index bee57c9c8..5c9dc4dbf 100644
--- a/tests/adaptation/test_step_size.py
+++ b/tests/adaptation/test_step_size.py
@@ -16,7 +16,7 @@ def test_reasonable_step_size(self):
         def logdensity_fn(x):
             return -jnp.sum(0.5 * x)

-        rng_key = jax.random.PRNGKey(0)
+        rng_key = jax.random.key(0)
         run_key0, run_key1 = jax.random.split(rng_key, 2)
         init_position = jnp.array([3.0])
diff --git a/tests/mcmc/test_latent_gaussian.py b/tests/mcmc/test_latent_gaussian.py
index 32fe6cc12..9f46c9d63 100644
--- a/tests/mcmc/test_latent_gaussian.py
+++ b/tests/mcmc/test_latent_gaussian.py
@@ -14,7 +14,7 @@ class GaussianTest(chex.TestCase):
     def test_gaussian(self, seed, mean):
         n_samples = 500_000

-        key = jax.random.PRNGKey(seed)
+        key = jax.random.key(seed)
         key1, key2, key3, key4, key5 = jax.random.split(key, 5)
         D = 5
diff --git a/tests/mcmc/test_metrics.py b/tests/mcmc/test_metrics.py
index ef3c1c81d..3501ce0a8 100644
--- a/tests/mcmc/test_metrics.py
+++ b/tests/mcmc/test_metrics.py
@@ -11,7 +11,7 @@ class GaussianEuclideanMetricsTest(chex.TestCase):
     def setUp(self):
         super().setUp()
-        self.key = random.PRNGKey(0)
+        self.key = random.key(0)
         self.dtype = "float32"

     @parameterized.named_parameters(
diff --git a/tests/mcmc/test_proposal.py b/tests/mcmc/test_proposal.py
index 128f09892..3a0c3ac38 100644
--- a/tests/mcmc/test_proposal.py
+++ b/tests/mcmc/test_proposal.py
@@ -10,7 +10,7 @@ class TestNormalProposalDistribution(chex.TestCase):
     def setUp(self):
         super().setUp()
-        self.key = jax.random.PRNGKey(20220611)
+        self.key = jax.random.key(20220611)

     def test_normal_univariate(self):
         """
diff --git a/tests/mcmc/test_random_walk_without_chex.py b/tests/mcmc/test_random_walk_without_chex.py
index 7ae431fa4..6e4e7afe1 100644
--- a/tests/mcmc/test_random_walk_without_chex.py
+++ b/tests/mcmc/test_random_walk_without_chex.py
@@ -21,7 +21,7 @@ def test_one_step_addition(self):
         Since the density == 1, the proposal is accepted.
         The random step may depend on the previous position
         """
-        rng_key = jax.random.PRNGKey(0)
+        rng_key = jax.random.key(0)
         initial_position = jnp.array([50.0])

         def random_step(key, position):
@@ -51,7 +51,7 @@ def test_logdensity_accepts(position):

 class IRMHTest(unittest.TestCase):
     def test_proposal_is_independent_of_position(self):
         """New position does not depend on previous"""
-        rng_key = jax.random.PRNGKey(0)
+        rng_key = jax.random.key(0)
         initial_position = jnp.array([50.0])
         other_position = jnp.array([15000.0])
@@ -99,7 +99,7 @@ def test_generate_reject(self):
         and given that the sampling rule rejects,
         the prev_state is proposed again
         """
-        rng_key = jax.random.PRNGKey(0)
+        rng_key = jax.random.key(0)

         prev_state = RWState(jnp.array([30.0]), 15.0)
@@ -118,7 +118,7 @@ def test_generate_reject(self):
         np.testing.assert_allclose(sampled_proposal.state.position, jnp.array([30.0]))

     def test_generate_accept(self):
-        rng_key = jax.random.PRNGKey(0)
+        rng_key = jax.random.key(0)
         prev_state = RWState(jnp.array([30.0]), 15.0)

         generate = rmh_proposal(
diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py
index d48cdb386..8772d2a13 100644
--- a/tests/mcmc/test_sampling.py
+++ b/tests/mcmc/test_sampling.py
@@ -71,7 +71,7 @@ class LinearRegressionTest(chex.TestCase):
     def setUp(self):
         super().setUp()
-        self.key = jax.random.PRNGKey(19)
+        self.key = jax.random.key(19)

     def regression_logprob(self, log_scale, coefs, preds, x):
         """Linear regression"""
@@ -229,7 +229,7 @@ class SGMCMCTest(chex.TestCase):
     def setUp(self):
         super().setUp()
-        self.key = jax.random.PRNGKey(19)
+        self.key = jax.random.key(19)

     def logprior_fn(self, position):
         return -0.5 * jnp.dot(position, position) * 0.01
@@ -379,7 +379,7 @@ class LatentGaussianTest(chex.TestCase):
     def setUp(self):
         super().setUp()
-        self.key = jax.random.PRNGKey(19)
+        self.key = jax.random.key(19)
         self.C = 2.0 * np.eye(1)
         self.delta = 5.0
         self.sampling_steps = 25_000
@@ -494,7 +494,7 @@ class UnivariateNormalTest(chex.TestCase):
     def setUp(self):
         super().setUp()
-        self.key = jax.random.PRNGKey(12)
+        self.key = jax.random.key(12)

     def normal_logprob(self, x):
         return stats.norm.logpdf(x, loc=1.0, scale=2.0)
@@ -572,7 +572,7 @@ class MonteCarloStandardErrorTest(chex.TestCase):
     def setUp(self):
         super().setUp()
-        self.key = jax.random.PRNGKey(20220203)
+        self.key = jax.random.key(20220203)

     def generate_multivariate_target(self, rng=None):
         """Genrate a Multivariate Normal distribution as target."""
diff --git a/tests/mcmc/test_trajectory.py b/tests/mcmc/test_trajectory.py
index 75d436c09..0444ac846 100644
--- a/tests/mcmc/test_trajectory.py
+++ b/tests/mcmc/test_trajectory.py
@@ -22,7 +22,7 @@ class TrajectoryTest(chex.TestCase):
     def test_dynamic_progressive_integration_divergence(
         self, step_size, should_diverge
     ):
-        rng_key = jax.random.PRNGKey(0)
+        rng_key = jax.random.key(0)

         logdensity_fn = jax.scipy.stats.norm.logpdf
@@ -74,7 +74,7 @@ def test_dynamic_progressive_integration_divergence(
         assert is_diverging.item() is should_diverge

     def test_dynamic_progressive_equal_recursive(self):
-        rng_key = jax.random.PRNGKey(23132)
+        rng_key = jax.random.key(23132)

         def logdensity_fn(x):
             return -((1.0 - x[0]) ** 2) - 1.5 * (x[1] - x[0] ** 2) ** 2
@@ -202,7 +202,7 @@ def logdensity_fn(x):
     def test_dynamic_progressive_expansion(
         self, step_size, should_diverge, should_turn, expected_doublings
     ):
-        rng_key = jax.random.PRNGKey(0)
+        rng_key = jax.random.key(0)

         def logdensity_fn(x):
             return -0.5 * x**2
@@ -260,7 +260,7 @@ def logdensity_fn(x):
         assert is_turning == should_turn

     def test_static_integration_variable_num_steps(self):
-        rng_key = jax.random.PRNGKey(0)
+        rng_key = jax.random.key(0)
         logdensity_fn = jax.scipy.stats.norm.logpdf
         position = 1.0
diff --git a/tests/optimizers/test_optimizers.py b/tests/optimizers/test_optimizers.py
index edf68a92a..a715acc18 100644
--- a/tests/optimizers/test_optimizers.py
+++ b/tests/optimizers/test_optimizers.py
@@ -23,7 +23,7 @@ class OptimizerTest(chex.TestCase):
     def setUp(self):
         super().setUp()
-        self.key = jax.random.PRNGKey(1)
+        self.key = jax.random.key(1)

     @chex.all_variants(with_pmap=False)
     def test_dual_averaging(self):
diff --git a/tests/optimizers/test_pathfinder.py b/tests/optimizers/test_pathfinder.py
index 4d9a60c37..b9b9c69be 100644
--- a/tests/optimizers/test_pathfinder.py
+++ b/tests/optimizers/test_pathfinder.py
@@ -14,7 +14,7 @@ class PathfinderTest(chex.TestCase):
     def setUp(self):
         super().setUp()
-        self.key = jax.random.PRNGKey(1)
+        self.key = jax.random.key(1)

     @chex.all_variants(without_device=False, with_pmap=False)
     @parameterized.parameters(
diff --git a/tests/smc/test_kernel_compatibility.py b/tests/smc/test_kernel_compatibility.py
index 105899125..1ea3a62c0 100644
--- a/tests/smc/test_kernel_compatibility.py
+++ b/tests/smc/test_kernel_compatibility.py
@@ -17,7 +17,7 @@ class SMCAndMCMCIntegrationTest(unittest.TestCase):
     def setUp(self):
         super().setUp()
-        self.key = jax.random.PRNGKey(42)
+        self.key = jax.random.key(42)
         self.initial_particles = jax.random.multivariate_normal(
             self.key, jnp.zeros(2), jnp.eye(2), (3,)
         )
diff --git a/tests/smc/test_resampling.py b/tests/smc/test_resampling.py
index 6c38169c9..20cb0d813 100644
--- a/tests/smc/test_resampling.py
+++ b/tests/smc/test_resampling.py
@@ -41,7 +41,7 @@ def test_resampling_methods(self, num_samples, method_name):
         x = jnp.array(np.random.randn(N), dtype="float32")
         w = w / w.sum()

-        resampling_keys = jax.random.split(jax.random.PRNGKey(42), batch_size)
+        resampling_keys = jax.random.split(jax.random.key(42), batch_size)
         resampling_idx = jax.vmap(
             self.variant(resampling_methods[method_name], static_argnums=(2,)),
diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py
index c3ad75db5..242e11c55 100644
--- a/tests/smc/test_smc.py
+++ b/tests/smc/test_smc.py
@@ -24,7 +24,7 @@ def _weighted_avg_and_std(values, weights):
 class SMCTest(chex.TestCase):
     def setUp(self):
         super().setUp()
-        self.key = jax.random.PRNGKey(42)
+        self.key = jax.random.key(42)

     @chex.variants(with_jit=True)
     def test_smc(self):
diff --git a/tests/smc/test_tempered_smc.py b/tests/smc/test_tempered_smc.py
index 4ebec39dc..1edd1f723 100644
--- a/tests/smc/test_tempered_smc.py
+++ b/tests/smc/test_tempered_smc.py
@@ -37,7 +37,7 @@ class TemperedSMCTest(chex.TestCase):
     def setUp(self):
         super().setUp()
-        self.key = jax.random.PRNGKey(42)
+        self.key = jax.random.key(42)

     def logdensity_fn(self, log_scale, coefs, preds, x):
         """Linear regression"""
@@ -174,7 +174,7 @@ def test_normalizing_constant(self):
         num_particles = 200
         num_dim = 2

-        rng_key = jax.random.PRNGKey(2356)
+        rng_key = jax.random.key(2356)
         rng_key, cov_key = jax.random.split(rng_key, 2)
         chol_cov = jax.random.uniform(cov_key, shape=(num_dim, num_dim))
         iu = np.triu_indices(num_dim, 1)
diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py
index 8096c770e..d64efa4cd 100644
--- a/tests/test_benchmarks.py
+++ b/tests/test_benchmarks.py
@@ -38,7 +38,7 @@ def one_step(state, rng_key):


 def run_regression(algorithm, **parameters):
-    key = jax.random.PRNGKey(0)
+    key = jax.random.key(0)
     rng_key, init_key0, init_key1 = jax.random.split(key, 3)
     x_data = jax.random.normal(init_key0, shape=(100_000, 1))
     y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape)
diff --git a/tests/test_compilation.py b/tests/test_compilation.py
index d6bab3a07..e16f8ff3c 100644
--- a/tests/test_compilation.py
+++ b/tests/test_compilation.py
@@ -29,7 +29,7 @@ def logdensity_fn(x):

         chex.clear_trace_counter()

-        rng_key = jax.random.PRNGKey(0)
+        rng_key = jax.random.key(0)
         state = blackjax.hmc.init(1.0, logdensity_fn)
         kernel = blackjax.hmc(
@@ -58,7 +58,7 @@ def logdensity_fn(x):

         chex.clear_trace_counter()

-        rng_key = jax.random.PRNGKey(0)
+        rng_key = jax.random.key(0)
         state = blackjax.nuts.init(1.0, logdensity_fn)
         kernel = blackjax.nuts(
@@ -83,7 +83,7 @@ def logdensity_fn(x):

         chex.clear_trace_counter()

-        rng_key = jax.random.PRNGKey(0)
+        rng_key = jax.random.key(0)
         warmup = blackjax.window_adaptation(
             algorithm=blackjax.hmc,
@@ -111,7 +111,7 @@ def logdensity_fn(x):

         chex.clear_trace_counter()

-        rng_key = jax.random.PRNGKey(0)
+        rng_key = jax.random.key(0)
         warmup = blackjax.window_adaptation(
             algorithm=blackjax.nuts,
diff --git a/tests/test_diagnostics.py b/tests/test_diagnostics.py
index a7ac0235c..d41f1dd6a 100644
--- a/tests/test_diagnostics.py
+++ b/tests/test_diagnostics.py
@@ -50,7 +50,7 @@ def setUp(self):
         itertools.product(test_cases, [1, 2, 10], [(), (3,), (5, 7)])
     )
     def test_rhat_ess(self, case, num_chains, event_shape):
-        rng_key = jax.random.PRNGKey(self.test_seed)
+        rng_key = jax.random.key(self.test_seed)
         sample_shape = list(event_shape)
         if case["chain_axis"] < case["sample_axis"]:
             sample_shape = insert_list(sample_shape, case["chain_axis"], num_chains)
diff --git a/tests/vi/test_meanfield_vi.py b/tests/vi/test_meanfield_vi.py
index 2b9ee56f7..c5a8a0865 100644
--- a/tests/vi/test_meanfield_vi.py
+++ b/tests/vi/test_meanfield_vi.py
@@ -11,7 +11,7 @@ class MFVITest(chex.TestCase):
     def setUp(self):
         super().setUp()
-        self.key = jax.random.PRNGKey(42)
+        self.key = jax.random.key(42)

     def test_recover_posterior(self):
         ground_truth = [
diff --git a/tests/vi/test_svgd.py b/tests/vi/test_svgd.py
index a0222a863..ba935b53d 100644
--- a/tests/vi/test_svgd.py
+++ b/tests/vi/test_svgd.py
@@ -32,7 +32,7 @@ def svgd_training_loop(
 class SvgdTest(chex.TestCase):
     def setUp(self):
         super().setUp()
-        self.key = jax.random.PRNGKey(1)
+        self.key = jax.random.key(1)

     def test_recover_posterior(self):
         # TODO improve testing
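
A minimal sketch (illustrative only, not part of the patch) of the PRNG-key change the test edits above rely on: with jax>=0.4.16, jax.random.key returns a typed key array, and jax.Array is the annotation that replaces the removed jax.random.PRNGKeyArray alias.

    import jax

    # New-style typed PRNG key; replaces jax.random.PRNGKey(0) in the tests above.
    rng_key: jax.Array = jax.random.key(0)

    # Splitting and sampling work exactly as with the old uint32 keys.
    init_key, sample_key = jax.random.split(rng_key)
    x = jax.random.normal(sample_key, shape=(3,))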