Skip to content

Commit

Permalink
Update dependencies and use new jax random key (#569)
Browse files Browse the repository at this point in the history
* Update dependency

* chex version typo

* fix typing

* fix typing 2

* Revert "fix typing 2"

This reverts commit 154e6b3.

* jax.random.PRNGKey(...) -> jax.random.key(...)
  • Loading branch information
junpenglao authored Sep 20, 2023
1 parent 46d874d commit 8efb158
Show file tree
Hide file tree
Showing 22 changed files with 43 additions and 43 deletions.
6 changes: 3 additions & 3 deletions blackjax/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from typing import Any, Iterable, Mapping, Union

import jax
from jax import Array
from jax.typing import ArrayLike

"""
Expand All @@ -36,10 +35,11 @@ class WelfordAlgorithmState(NamedTuple):
(until we introduce shape annotation).
"""
#: JAX PyTrees
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[Any, "ArrayTree"]]
Array = jax.Array
ArrayTree = Union[jax.Array, Iterable["ArrayTree"], Mapping[Any, "ArrayTree"]]
ArrayLikeTree = Union[
ArrayLike, Iterable["ArrayLikeTree"], Mapping[Any, "ArrayLikeTree"]
]

#: JAX PRNGKey
PRNGKey = jax.random.PRNGKeyArray
PRNGKey = jax.Array
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
dependencies = [
"fastprogress>=0.2.0",
"jax>=0.3.13",
"jaxlib>=0.3.10",
"jaxopt>=0.5.5",
"optax",
"fastprogress>=1.0.0",
"jax>=0.4.16",
"jaxlib>=0.4.16",
"jaxopt>=0.8",
"optax>=0.1.7",
"typing-extensions>=4.4.0",
]
dynamic = ["version"]
Expand Down
4 changes: 2 additions & 2 deletions requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ aesara>=2.8.8
arviz
flax
ipython
jax
jaxlib
jax>=0.4.16
jaxlib>=0.4.16
jaxopt
jupytext
myst_nb
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-e ./
chex
chex>=0.1.83
pre-commit
pytest
pytest-benchmark
Expand Down
2 changes: 1 addition & 1 deletion tests/adaptation/test_step_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_reasonable_step_size(self):
def logdensity_fn(x):
return -jnp.sum(0.5 * x)

rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
run_key0, run_key1 = jax.random.split(rng_key, 2)

init_position = jnp.array([3.0])
Expand Down
2 changes: 1 addition & 1 deletion tests/mcmc/test_latent_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class GaussianTest(chex.TestCase):
def test_gaussian(self, seed, mean):
n_samples = 500_000

key = jax.random.PRNGKey(seed)
key = jax.random.key(seed)
key1, key2, key3, key4, key5 = jax.random.split(key, 5)

D = 5
Expand Down
2 changes: 1 addition & 1 deletion tests/mcmc/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class GaussianEuclideanMetricsTest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = random.PRNGKey(0)
self.key = random.key(0)
self.dtype = "float32"

@parameterized.named_parameters(
Expand Down
2 changes: 1 addition & 1 deletion tests/mcmc/test_proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class TestNormalProposalDistribution(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(20220611)
self.key = jax.random.key(20220611)

def test_normal_univariate(self):
"""
Expand Down
8 changes: 4 additions & 4 deletions tests/mcmc/test_random_walk_without_chex.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_one_step_addition(self):
Since the density == 1, the proposal is accepted.
The random step may depend on the previous position
"""
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
initial_position = jnp.array([50.0])

def random_step(key, position):
Expand Down Expand Up @@ -51,7 +51,7 @@ def test_logdensity_accepts(position):
class IRMHTest(unittest.TestCase):
def test_proposal_is_independent_of_position(self):
"""New position does not depend on previous"""
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
initial_position = jnp.array([50.0])
other_position = jnp.array([15000.0])

Expand Down Expand Up @@ -99,7 +99,7 @@ def test_generate_reject(self):
and given that the sampling rule rejects,
the prev_state is proposed again
"""
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)

prev_state = RWState(jnp.array([30.0]), 15.0)

Expand All @@ -118,7 +118,7 @@ def test_generate_reject(self):
np.testing.assert_allclose(sampled_proposal.state.position, jnp.array([30.0]))

def test_generate_accept(self):
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
prev_state = RWState(jnp.array([30.0]), 15.0)

generate = rmh_proposal(
Expand Down
10 changes: 5 additions & 5 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class LinearRegressionTest(chex.TestCase):

def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(19)
self.key = jax.random.key(19)

def regression_logprob(self, log_scale, coefs, preds, x):
"""Linear regression"""
Expand Down Expand Up @@ -229,7 +229,7 @@ class SGMCMCTest(chex.TestCase):

def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(19)
self.key = jax.random.key(19)

def logprior_fn(self, position):
return -0.5 * jnp.dot(position, position) * 0.01
Expand Down Expand Up @@ -379,7 +379,7 @@ class LatentGaussianTest(chex.TestCase):

def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(19)
self.key = jax.random.key(19)
self.C = 2.0 * np.eye(1)
self.delta = 5.0
self.sampling_steps = 25_000
Expand Down Expand Up @@ -494,7 +494,7 @@ class UnivariateNormalTest(chex.TestCase):

def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(12)
self.key = jax.random.key(12)

def normal_logprob(self, x):
return stats.norm.logpdf(x, loc=1.0, scale=2.0)
Expand Down Expand Up @@ -572,7 +572,7 @@ class MonteCarloStandardErrorTest(chex.TestCase):

def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(20220203)
self.key = jax.random.key(20220203)

def generate_multivariate_target(self, rng=None):
"""Genrate a Multivariate Normal distribution as target."""
Expand Down
8 changes: 4 additions & 4 deletions tests/mcmc/test_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TrajectoryTest(chex.TestCase):
def test_dynamic_progressive_integration_divergence(
self, step_size, should_diverge
):
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)

logdensity_fn = jax.scipy.stats.norm.logpdf

Expand Down Expand Up @@ -74,7 +74,7 @@ def test_dynamic_progressive_integration_divergence(
assert is_diverging.item() is should_diverge

def test_dynamic_progressive_equal_recursive(self):
rng_key = jax.random.PRNGKey(23132)
rng_key = jax.random.key(23132)

def logdensity_fn(x):
return -((1.0 - x[0]) ** 2) - 1.5 * (x[1] - x[0] ** 2) ** 2
Expand Down Expand Up @@ -202,7 +202,7 @@ def logdensity_fn(x):
def test_dynamic_progressive_expansion(
self, step_size, should_diverge, should_turn, expected_doublings
):
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)

def logdensity_fn(x):
return -0.5 * x**2
Expand Down Expand Up @@ -260,7 +260,7 @@ def logdensity_fn(x):
assert is_turning == should_turn

def test_static_integration_variable_num_steps(self):
rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)

logdensity_fn = jax.scipy.stats.norm.logpdf
position = 1.0
Expand Down
2 changes: 1 addition & 1 deletion tests/optimizers/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class OptimizerTest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(1)
self.key = jax.random.key(1)

@chex.all_variants(with_pmap=False)
def test_dual_averaging(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/optimizers/test_pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class PathfinderTest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(1)
self.key = jax.random.key(1)

@chex.all_variants(without_device=False, with_pmap=False)
@parameterized.parameters(
Expand Down
2 changes: 1 addition & 1 deletion tests/smc/test_kernel_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class SMCAndMCMCIntegrationTest(unittest.TestCase):

def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)
self.key = jax.random.key(42)
self.initial_particles = jax.random.multivariate_normal(
self.key, jnp.zeros(2), jnp.eye(2), (3,)
)
Expand Down
2 changes: 1 addition & 1 deletion tests/smc/test_resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_resampling_methods(self, num_samples, method_name):
x = jnp.array(np.random.randn(N), dtype="float32")
w = w / w.sum()

resampling_keys = jax.random.split(jax.random.PRNGKey(42), batch_size)
resampling_keys = jax.random.split(jax.random.key(42), batch_size)

resampling_idx = jax.vmap(
self.variant(resampling_methods[method_name], static_argnums=(2,)),
Expand Down
2 changes: 1 addition & 1 deletion tests/smc/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _weighted_avg_and_std(values, weights):
class SMCTest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)
self.key = jax.random.key(42)

@chex.variants(with_jit=True)
def test_smc(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/smc/test_tempered_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class TemperedSMCTest(chex.TestCase):

def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)
self.key = jax.random.key(42)

def logdensity_fn(self, log_scale, coefs, preds, x):
"""Linear regression"""
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_normalizing_constant(self):
num_particles = 200
num_dim = 2

rng_key = jax.random.PRNGKey(2356)
rng_key = jax.random.key(2356)
rng_key, cov_key = jax.random.split(rng_key, 2)
chol_cov = jax.random.uniform(cov_key, shape=(num_dim, num_dim))
iu = np.triu_indices(num_dim, 1)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def one_step(state, rng_key):


def run_regression(algorithm, **parameters):
key = jax.random.PRNGKey(0)
key = jax.random.key(0)
rng_key, init_key0, init_key1 = jax.random.split(key, 3)
x_data = jax.random.normal(init_key0, shape=(100_000, 1))
y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def logdensity_fn(x):

chex.clear_trace_counter()

rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
state = blackjax.hmc.init(1.0, logdensity_fn)

kernel = blackjax.hmc(
Expand Down Expand Up @@ -58,7 +58,7 @@ def logdensity_fn(x):

chex.clear_trace_counter()

rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)
state = blackjax.nuts.init(1.0, logdensity_fn)

kernel = blackjax.nuts(
Expand All @@ -83,7 +83,7 @@ def logdensity_fn(x):

chex.clear_trace_counter()

rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)

warmup = blackjax.window_adaptation(
algorithm=blackjax.hmc,
Expand Down Expand Up @@ -111,7 +111,7 @@ def logdensity_fn(x):

chex.clear_trace_counter()

rng_key = jax.random.PRNGKey(0)
rng_key = jax.random.key(0)

warmup = blackjax.window_adaptation(
algorithm=blackjax.nuts,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def setUp(self):
itertools.product(test_cases, [1, 2, 10], [(), (3,), (5, 7)])
)
def test_rhat_ess(self, case, num_chains, event_shape):
rng_key = jax.random.PRNGKey(self.test_seed)
rng_key = jax.random.key(self.test_seed)
sample_shape = list(event_shape)
if case["chain_axis"] < case["sample_axis"]:
sample_shape = insert_list(sample_shape, case["chain_axis"], num_chains)
Expand Down
2 changes: 1 addition & 1 deletion tests/vi/test_meanfield_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class MFVITest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(42)
self.key = jax.random.key(42)

def test_recover_posterior(self):
ground_truth = [
Expand Down
2 changes: 1 addition & 1 deletion tests/vi/test_svgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def svgd_training_loop(
class SvgdTest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(1)
self.key = jax.random.key(1)

def test_recover_posterior(self):
# TODO improve testing
Expand Down

0 comments on commit 8efb158

Please sign in to comment.