Update API usage
junpenglao committed Sep 20, 2023
1 parent 4f49d02 · commit c36bcd2
Showing 30 changed files with 48 additions and 48 deletions.
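The change is mechanical throughout: every call to the legacy `jax.random.PRNGKey` constructor becomes a call to `jax.random.key`, which returns JAX's newer typed (scalar) key array. A minimal sketch of the two styles, assuming a JAX release recent enough to provide `jax.random.key` (variable names here are illustrative):

```python
import jax

# Legacy style: a raw uint32 array of shape (2,) holding the PRNG state.
legacy_key = jax.random.PRNGKey(0)

# New style: a scalar, typed key array (the form adopted in this commit).
typed_key = jax.random.key(0)

# Downstream usage is the same with either form: split and sample as usual.
step_key, sample_key = jax.random.split(typed_key)
sample = jax.random.normal(sample_key, shape=(3,))
print(sample)
```

Because splitting and sampling work identically with either key form, the substitution is one-for-one in the docs and tests below.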
2 changes: 1 addition & 1 deletion README.md
@@ -77,7 +77,7 @@ initial_position = {"loc": 1., "scale": 2.}
state = nuts.init(initial_position)

# Iterate
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
for _ in range(100):
rng_key, nuts_key = jax.random.split(rng_key)
state, _ = nuts.step(nuts_key, state)
2 changes: 1 addition & 1 deletion docs/examples/howto_custom_gradients.md
@@ -184,7 +184,7 @@ def logdensity_fn(y):
hmc = blackjax.hmc(logdensity_fn,1e-3, jnp.ones(1), 10)
state = hmc.init(1.)
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
new_state, info = hmc.step(rng_key, state)
```

4 changes: 2 additions & 2 deletions docs/examples/howto_metropolis_within_gibbs.md
@@ -200,7 +200,7 @@ def sampling_loop(rng_key, initial_state, parameters, num_samples):

```{code-cell} ipython3
%%time
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
positions = sampling_loop(rng_key, initial_state, parameters, 10_000)
```

@@ -305,7 +305,7 @@ def sampling_loop_general(rng_key, initial_state, logdensity_fn, step_fn, init,

```{code-cell} ipython3
%%time
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
positions_general = sampling_loop_general(
rng_key=rng_key,
initial_state=initial_state,
2 changes: 1 addition & 1 deletion docs/examples/howto_other_frameworks.md
@@ -138,7 +138,7 @@ step_size=1e-3
nuts = blackjax.nuts(numba_logpdf, step_size, inverse_mass_matrix)
init = nuts.init(0.)
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
state, info = nuts.step(rng_key, init)
for _ in range(10):
2 changes: 1 addition & 1 deletion docs/examples/howto_sample_multiple_chains.md
@@ -96,7 +96,7 @@ And finally, to put `jax.vmap` and `jax.pmap` on an equal foot we sample as many
import multiprocessing
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
num_chains = multiprocessing.cpu_count()
```

2 changes: 1 addition & 1 deletion docs/examples/howto_use_aesara.md
@@ -136,7 +136,7 @@ def init_param_fn(seed):
"thetas": jax.random.uniform(seed, (n_rat_tumors,), "float64", minval=0, maxval=1),
}
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
init_position = init_param_fn(rng_key)
```

2 changes: 1 addition & 1 deletion docs/examples/howto_use_numpyro.md
@@ -67,7 +67,7 @@ import jax
from numpyro.infer.util import initialize_model
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
init_params, potential_fn_gen, *_ = initialize_model(
rng_key,
eight_schools_noncentered,
6 changes: 3 additions & 3 deletions docs/examples/howto_use_oryx.md
@@ -109,7 +109,7 @@ from oryx.core.ppl import joint_sample
bnn = mlp([50, 50], num_classes)
-initial_weights = joint_sample(bnn)(jax.random.PRNGKey(0), jnp.ones(num_features))
+initial_weights = joint_sample(bnn)(jax.random.key(0), jnp.ones(num_features))
print(initial_weights.keys())
```
@@ -136,7 +136,7 @@ We can now run the window adaptation to get good values for the parameters of th
%%time
import blackjax
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(last_state, parameters), _ = adapt.run(rng_key, initial_weights, 100)
kernel = blackjax.nuts(logdensity_fn, **parameters).step
@@ -173,7 +173,7 @@ posterior_weights = states.position
output_logits = jax.vmap(
lambda weights: jax.vmap(lambda x: intervene(bnn, **weights)(
-jax.random.PRNGKey(0), x)
+jax.random.key(0), x)
)(features)
)(posterior_weights)
2 changes: 1 addition & 1 deletion docs/examples/howto_use_pymc.md
@@ -67,7 +67,7 @@ import jax
init_position_dict = model.initial_point()
init_position = [init_position_dict[rv] for rv in rvs]
-rng_key = jax.random.PRNGKey(1234)
+rng_key = jax.random.key(1234)
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(last_state, parameters), _ = adapt.run(rng_key, init_position, 1000)
2 changes: 1 addition & 1 deletion docs/examples/howto_use_tfp.md
@@ -87,7 +87,7 @@ initial_position = {
}
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
adapt = blackjax.window_adaptation(
blackjax.hmc, logdensity_fn, num_integration_steps=3
)
4 changes: 2 additions & 2 deletions docs/examples/quickstart.md
@@ -97,7 +97,7 @@ def inference_loop(rng_key, kernel, initial_state, num_samples):

```{code-cell} python
%%time
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 10_000)
loc_samples = states.position["loc"].block_until_ready()
@@ -136,7 +136,7 @@ initial_state

```{code-cell} python
%%time
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
states = inference_loop(rng_key, nuts.step, initial_state, 4_000)
loc_samples = states.position["loc"].block_until_ready()
2 changes: 1 addition & 1 deletion docs/index.md
@@ -37,7 +37,7 @@ initial_position = {"loc": 1., "scale": 2.}
state = nuts.init(initial_position)
# Iterate
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
step = jax.jit(nuts.step)
for _ in range(1_000):
rng_key, nuts_key = jax.random.split(rng_key)
2 changes: 1 addition & 1 deletion tests/adaptation/test_step_size.py
@@ -16,7 +16,7 @@ def test_reasonable_step_size(self):
def logdensity_fn(x):
return -jnp.sum(0.5 * x)

-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
run_key0, run_key1 = jax.random.split(rng_key, 2)

init_position = jnp.array([3.0])
2 changes: 1 addition & 1 deletion tests/mcmc/test_latent_gaussian.py
@@ -14,7 +14,7 @@ class GaussianTest(chex.TestCase):
def test_gaussian(self, seed, mean):
n_samples = 500_000

-key = jax.random.PRNGKey(seed)
+key = jax.random.key(seed)
key1, key2, key3, key4, key5 = jax.random.split(key, 5)

D = 5
2 changes: 1 addition & 1 deletion tests/mcmc/test_metrics.py
@@ -11,7 +11,7 @@
class GaussianEuclideanMetricsTest(chex.TestCase):
def setUp(self):
super().setUp()
-self.key = random.PRNGKey(0)
+self.key = random.key(0)
self.dtype = "float32"

@parameterized.named_parameters(
2 changes: 1 addition & 1 deletion tests/mcmc/test_proposal.py
@@ -10,7 +10,7 @@
class TestNormalProposalDistribution(chex.TestCase):
def setUp(self):
super().setUp()
-self.key = jax.random.PRNGKey(20220611)
+self.key = jax.random.key(20220611)

def test_normal_univariate(self):
"""
8 changes: 4 additions & 4 deletions tests/mcmc/test_random_walk_without_chex.py
@@ -21,7 +21,7 @@ def test_one_step_addition(self):
Since the density == 1, the proposal is accepted.
The random step may depend on the previous position
"""
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
initial_position = jnp.array([50.0])

def random_step(key, position):
@@ -51,7 +51,7 @@ def test_logdensity_accepts(position):
class IRMHTest(unittest.TestCase):
def test_proposal_is_independent_of_position(self):
"""New position does not depend on previous"""
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
initial_position = jnp.array([50.0])
other_position = jnp.array([15000.0])

@@ -99,7 +99,7 @@ def test_generate_reject(self):
and given that the sampling rule rejects,
the prev_state is proposed again
"""
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)

prev_state = RWState(jnp.array([30.0]), 15.0)

@@ -118,7 +118,7 @@ def test_generate_reject(self):
np.testing.assert_allclose(sampled_proposal.state.position, jnp.array([30.0]))

def test_generate_accept(self):
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)
prev_state = RWState(jnp.array([30.0]), 15.0)

generate = rmh_proposal(
10 changes: 5 additions & 5 deletions tests/mcmc/test_sampling.py
@@ -71,7 +71,7 @@ class LinearRegressionTest(chex.TestCase):

def setUp(self):
super().setUp()
-self.key = jax.random.PRNGKey(19)
+self.key = jax.random.key(19)

def regression_logprob(self, log_scale, coefs, preds, x):
"""Linear regression"""
@@ -229,7 +229,7 @@ class SGMCMCTest(chex.TestCase):

def setUp(self):
super().setUp()
-self.key = jax.random.PRNGKey(19)
+self.key = jax.random.key(19)

def logprior_fn(self, position):
return -0.5 * jnp.dot(position, position) * 0.01
@@ -379,7 +379,7 @@ class LatentGaussianTest(chex.TestCase):

def setUp(self):
super().setUp()
-self.key = jax.random.PRNGKey(19)
+self.key = jax.random.key(19)
self.C = 2.0 * np.eye(1)
self.delta = 5.0
self.sampling_steps = 25_000
@@ -494,7 +494,7 @@ class UnivariateNormalTest(chex.TestCase):

def setUp(self):
super().setUp()
-self.key = jax.random.PRNGKey(12)
+self.key = jax.random.key(12)

def normal_logprob(self, x):
return stats.norm.logpdf(x, loc=1.0, scale=2.0)
@@ -572,7 +572,7 @@ class MonteCarloStandardErrorTest(chex.TestCase):

def setUp(self):
super().setUp()
-self.key = jax.random.PRNGKey(20220203)
+self.key = jax.random.key(20220203)

def generate_multivariate_target(self, rng=None):
"""Genrate a Multivariate Normal distribution as target."""
8 changes: 4 additions & 4 deletions tests/mcmc/test_trajectory.py
@@ -22,7 +22,7 @@ class TrajectoryTest(chex.TestCase):
def test_dynamic_progressive_integration_divergence(
self, step_size, should_diverge
):
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)

logdensity_fn = jax.scipy.stats.norm.logpdf

@@ -74,7 +74,7 @@ def test_dynamic_progressive_integration_divergence(
assert is_diverging.item() is should_diverge

def test_dynamic_progressive_equal_recursive(self):
-rng_key = jax.random.PRNGKey(23132)
+rng_key = jax.random.key(23132)

def logdensity_fn(x):
return -((1.0 - x[0]) ** 2) - 1.5 * (x[1] - x[0] ** 2) ** 2
@@ -202,7 +202,7 @@ def logdensity_fn(x):
def test_dynamic_progressive_expansion(
self, step_size, should_diverge, should_turn, expected_doublings
):
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)

def logdensity_fn(x):
return -0.5 * x**2
@@ -260,7 +260,7 @@ def logdensity_fn(x):
assert is_turning == should_turn

def test_static_integration_variable_num_steps(self):
-rng_key = jax.random.PRNGKey(0)
+rng_key = jax.random.key(0)

logdensity_fn = jax.scipy.stats.norm.logpdf
position = 1.0
2 changes: 1 addition & 1 deletion tests/optimizers/test_optimizers.py
@@ -23,7 +23,7 @@
class OptimizerTest(chex.TestCase):
def setUp(self):
super().setUp()
-self.key = jax.random.PRNGKey(1)
+self.key = jax.random.key(1)

@chex.all_variants(with_pmap=False)
def test_dual_averaging(self):
2 changes: 1 addition & 1 deletion tests/optimizers/test_pathfinder.py
@@ -14,7 +14,7 @@
class PathfinderTest(chex.TestCase):
def setUp(self):
super().setUp()
-self.key = jax.random.PRNGKey(1)
+self.key = jax.random.key(1)

@chex.all_variants(without_device=False, with_pmap=False)
@parameterized.parameters(
2 changes: 1 addition & 1 deletion tests/smc/test_kernel_compatibility.py
@@ -17,7 +17,7 @@ class SMCAndMCMCIntegrationTest(unittest.TestCase):

def setUp(self):
super().setUp()
-self.key = jax.random.PRNGKey(42)
+self.key = jax.random.key(42)
self.initial_particles = jax.random.multivariate_normal(
self.key, jnp.zeros(2), jnp.eye(2), (3,)
)
2 changes: 1 addition & 1 deletion tests/smc/test_resampling.py
@@ -41,7 +41,7 @@ def test_resampling_methods(self, num_samples, method_name):
x = jnp.array(np.random.randn(N), dtype="float32")
w = w / w.sum()

-resampling_keys = jax.random.split(jax.random.PRNGKey(42), batch_size)
+resampling_keys = jax.random.split(jax.random.key(42), batch_size)

resampling_idx = jax.vmap(
self.variant(resampling_methods[method_name], static_argnums=(2,)),
2 changes: 1 addition & 1 deletion tests/smc/test_smc.py
@@ -24,7 +24,7 @@ def _weighted_avg_and_std(values, weights):
class SMCTest(chex.TestCase):
def setUp(self):
super().setUp()
-self.key = jax.random.PRNGKey(42)
+self.key = jax.random.key(42)

@chex.variants(with_jit=True)
def test_smc(self):
4 changes: 2 additions & 2 deletions tests/smc/test_tempered_smc.py
@@ -37,7 +37,7 @@ class TemperedSMCTest(chex.TestCase):

def setUp(self):
super().setUp()
-self.key = jax.random.PRNGKey(42)
+self.key = jax.random.key(42)

def logdensity_fn(self, log_scale, coefs, preds, x):
"""Linear regression"""
@@ -174,7 +174,7 @@ def test_normalizing_constant(self):
num_particles = 200
num_dim = 2

-rng_key = jax.random.PRNGKey(2356)
+rng_key = jax.random.key(2356)
rng_key, cov_key = jax.random.split(rng_key, 2)
chol_cov = jax.random.uniform(cov_key, shape=(num_dim, num_dim))
iu = np.triu_indices(num_dim, 1)
2 changes: 1 addition & 1 deletion tests/test_benchmarks.py
@@ -38,7 +38,7 @@ def one_step(state, rng_key):


def run_regression(algorithm, **parameters):
-key = jax.random.PRNGKey(0)
+key = jax.random.key(0)
rng_key, init_key0, init_key1 = jax.random.split(key, 3)
x_data = jax.random.normal(init_key0, shape=(100_000, 1))
y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape)
(Diffs for the remaining changed files are not shown.)
