Commit f9dae13

jrandom->jr
1 parent fe9c473 commit f9dae13

18 files changed, +140 −142 lines changed
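Every hunk in this commit applies the same mechanical substitution, renaming the `jax.random` import alias:

```python
# Before this commit:
import jax.random as jrandom
key = jrandom.PRNGKey(0)

# After this commit:
import jax.random as jr
key = jr.PRNGKey(0)
```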

benchmarks/small_neural_ode.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@
 import jax.experimental.ode as experimental
 import jax.nn as jnn
 import jax.numpy as jnp
-import jax.random as jrandom
+import jax.random as jr
 import numpy as np
 import torch # pyright: ignore
 import torchdiffeq # pyright: ignore
@@ -44,7 +44,7 @@ def __init__(self):
 depth=1,
 activation=jnn.softplus,
 final_activation=jnn.tanh,
-key=jrandom.PRNGKey(0),
+key=jr.PRNGKey(0),
 )

 def __call__(self, t, y, args):
@@ -182,7 +182,7 @@ def run(multiple, grad, batch_size=64, t1=100):
 func_torch[2].weight.copy_(torch.tensor(np.asarray(func_jax.layers[1].weight))) # pyright: ignore
 func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias))) # pyright: ignore

-y0_jax = jrandom.normal(jrandom.PRNGKey(1), (batch_size, 4))
+y0_jax = jr.normal(jr.PRNGKey(1), (batch_size, 4))
 y0_torch = torch.tensor(np.asarray(y0_jax))

 time_torch(neural_ode_torch, y0_torch, t1, grad)
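The benchmark keeps the JAX and PyTorch sides in sync by round-tripping arrays through NumPy, as in the last hunk above. A minimal standalone sketch of that handoff (the `(64, 4)` shape matches the benchmark's default batch size; otherwise illustrative):

```python
import jax.random as jr
import numpy as np
import torch

# np.asarray exposes the JAX array to NumPy; torch.tensor then copies it.
y0_jax = jr.normal(jr.PRNGKey(1), (64, 4))
y0_torch = torch.tensor(np.asarray(y0_jax))
assert y0_torch.shape == (64, 4)
```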

diffrax/_brownian/path.py

Lines changed: 4 additions & 4 deletions
@@ -4,7 +4,7 @@
 import equinox.internal as eqxi
 import jax
 import jax.numpy as jnp
-import jax.random as jrandom
+import jax.random as jr
 import jax.tree_util as jtu
 from jaxtyping import Array, PRNGKeyArray, PyTree

@@ -81,8 +81,8 @@ def evaluate(
 t1 = cast(RealScalarLike, t1)
 t0_ = force_bitcast_convert_type(t0, jnp.int32)
 t1_ = force_bitcast_convert_type(t1, jnp.int32)
-key = jrandom.fold_in(self.key, t0_)
-key = jrandom.fold_in(key, t1_)
+key = jr.fold_in(self.key, t0_)
+key = jr.fold_in(key, t1_)
 key = split_by_tree(key, self.shape)
 return jtu.tree_map(
 lambda key, shape: self._evaluate_leaf(t0, t1, key, shape), key, self.shape
@@ -91,7 +91,7 @@ def evaluate(
 def _evaluate_leaf(
 self, t0: RealScalarLike, t1: RealScalarLike, key, shape: jax.ShapeDtypeStruct
 ):
-return jrandom.normal(key, shape.shape, shape.dtype) * jnp.sqrt(t1 - t0).astype(
+return jr.normal(key, shape.shape, shape.dtype) * jnp.sqrt(t1 - t0).astype(
 shape.dtype
 )

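The `evaluate`/`_evaluate_leaf` pattern above makes each Brownian increment a pure function of the key and the query times, so re-querying the same interval reproduces the same sample. A standalone sketch of the idea, not Diffrax's actual class; the integer conversion below is a hypothetical stand-in for `force_bitcast_convert_type`:

```python
import jax.numpy as jnp
import jax.random as jr

def brownian_increment(base_key, t0, t1, shape=()):
    # Fold both query times into the key: the sample is then deterministic
    # in (base_key, t0, t1), with the Brownian scaling sqrt(t1 - t0).
    key = jr.fold_in(base_key, jnp.round(t0 * 1_000_000).astype(jnp.int32))
    key = jr.fold_in(key, jnp.round(t1 * 1_000_000).astype(jnp.int32))
    return jr.normal(key, shape) * jnp.sqrt(t1 - t0)

key = jr.PRNGKey(0)
assert brownian_increment(key, 0.0, 0.5) == brownian_increment(key, 0.0, 0.5)
```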

diffrax/_brownian/tree.py

Lines changed: 5 additions & 5 deletions
@@ -5,7 +5,7 @@
 import jax
 import jax.lax as lax
 import jax.numpy as jnp
-import jax.random as jrandom
+import jax.random as jr
 import jax.tree_util as jtu
 from jaxtyping import Array, Float, PRNGKeyArray, PyTree

@@ -112,7 +112,7 @@ def _brownian_bridge(self, s, t, u, w_s, w_u, key, shape, dtype):
 mean = w_s + (w_u - w_s) * ((t - s) / (u - s))
 var = (u - t) * (t - s) / (u - s)
 std = jnp.sqrt(var)
-return mean + std * jrandom.normal(key, shape, dtype)
+return mean + std * jr.normal(key, shape, dtype)

 def _evaluate_leaf(
 self,
@@ -140,9 +140,9 @@ def _evaluate_leaf(
 # errors are only raised after everything has finished executing.
 τ = jnp.clip(τ, t0, t1).astype(dtype)

-key, init_key = jrandom.split(key, 2)
+key, init_key = jr.split(key, 2)
 thalf = t0 + 0.5 * (t1 - t0)
-w_t1 = jrandom.normal(init_key, shape, dtype) * jnp.sqrt(t1 - t0)
+w_t1 = jr.normal(init_key, shape, dtype) * jnp.sqrt(t1 - t0)
 w_thalf = self._brownian_bridge(t0, thalf, t1, 0, w_t1, key, shape, dtype)
 init_state = _State(
 s=t0,
@@ -164,7 +164,7 @@ def _cond_fun(_state):
 return (_state.u - _state.s) > self.tol

 def _body_fun(_state):
-_key1, _key2 = jrandom.split(_state.key, 2)
+_key1, _key2 = jr.split(_state.key, 2)
 _cond = τ > _state.t
 _s = jnp.where(_cond, _state.t, _state.s)
 _u = jnp.where(_cond, _state.u, _state.t)
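The `_brownian_bridge` hunk above samples from the exact conditional law of Brownian motion: given W(s) = w_s and W(u) = w_u, the value W(t) for s < t < u is normal with mean w_s + (w_u − w_s)(t − s)/(u − s) and variance (u − t)(t − s)/(u − s). A standalone sketch of that rule, scalar-shaped for simplicity:

```python
import jax.numpy as jnp
import jax.random as jr

def brownian_bridge(key, s, t, u, w_s, w_u):
    # Conditional mean interpolates linearly between the two endpoints;
    # conditional variance vanishes as t approaches either s or u.
    mean = w_s + (w_u - w_s) * ((t - s) / (u - s))
    std = jnp.sqrt((u - t) * (t - s) / (u - s))
    return mean + std * jr.normal(key, ())

w_half = brownian_bridge(jr.PRNGKey(0), 0.0, 0.5, 1.0, 0.0, 1.2)  # midpoint sample
```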

docs/usage/getting-started.md

Lines changed: 2 additions & 2 deletions
@@ -69,13 +69,13 @@ $y(0) = 1 \qquad \mathrm{d}y(t) = -y(t)\mathrm{d}t + \frac{t}{10}\mathrm{d}w(t)$
 over the interval $[0, 3]$.

 ```python
-import jax.random as jrandom
+import jax.random as jr
 from diffrax import diffeqsolve, ControlTerm, Euler, MultiTerm, ODETerm, SaveAt, VirtualBrownianTree

 t0, t1 = 1, 3
 drift = lambda t, y, args: -y
 diffusion = lambda t, y, args: 0.1 * t
-brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=jrandom.PRNGKey(0))
+brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=jr.PRNGKey(0))
 terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))
 solver = Euler()
 saveat = SaveAt(dense=True)
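The docs snippet is cut off at the hunk boundary. For reference, a runnable completion under the renamed alias; the final `diffeqsolve` call and the `dt0`/`y0` values are illustrative additions, not part of this diff:

```python
import jax.random as jr
from diffrax import (ControlTerm, Euler, MultiTerm, ODETerm, SaveAt,
                     VirtualBrownianTree, diffeqsolve)

t0, t1 = 1, 3
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.1 * t
brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=jr.PRNGKey(0))
terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))
solver = Euler()
saveat = SaveAt(dense=True)
# dt0 and y0 chosen for illustration; dense output can then be
# evaluated anywhere in [t0, t1].
sol = diffeqsolve(terms, solver, t0, t1, dt0=0.05, y0=1.0, saveat=saveat)
print(sol.evaluate(2.0))
```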

examples/continuous_normalising_flow.ipynb

Lines changed: 17 additions & 17 deletions
@@ -73,7 +73,7 @@
 "import jax.lax as lax\n",
 "import jax.nn as jnn\n",
 "import jax.numpy as jnp\n",
-"import jax.random as jrandom\n",
+"import jax.random as jr\n",
 "import matplotlib.pyplot as plt\n",
 "import optax # https://github.com/deepmind/optax\n",
 "import scipy.stats as stats\n",
@@ -111,7 +111,7 @@
 "\n",
 " def __init__(self, *, data_size, width_size, depth, key, **kwargs):\n",
 " super().__init__(**kwargs)\n",
-" keys = jrandom.split(key, depth + 1)\n",
+" keys = jr.split(key, depth + 1)\n",
 " layers = []\n",
 " if depth == 0:\n",
 " layers.append(\n",
@@ -150,7 +150,7 @@
 "\n",
 " def __init__(self, *, in_size, out_size, key, **kwargs):\n",
 " super().__init__(**kwargs)\n",
-" key1, key2, key3 = jrandom.split(key, 3)\n",
+" key1, key2, key3 = jr.split(key, 3)\n",
 " self.lin1 = eqx.nn.Linear(in_size, out_size, key=key1)\n",
 " self.lin2 = eqx.nn.Linear(1, out_size, key=key2)\n",
 " self.lin3 = eqx.nn.Linear(1, out_size, use_bias=False, key=key3)\n",
@@ -251,7 +251,7 @@
 " **kwargs,\n",
 " ):\n",
 " super().__init__(**kwargs)\n",
-" keys = jrandom.split(key, num_blocks)\n",
+" keys = jr.split(key, num_blocks)\n",
 " self.funcs = [\n",
 " Func(\n",
 " data_size=data_size,\n",
@@ -274,7 +274,7 @@
 " else:\n",
 " term = diffrax.ODETerm(approx_logp_wrapper)\n",
 " solver = diffrax.Tsit5()\n",
-" eps = jrandom.normal(key, y.shape)\n",
+" eps = jr.normal(key, y.shape)\n",
 " delta_log_likelihood = 0.0\n",
 " for func in reversed(self.funcs):\n",
 " y = (y, delta_log_likelihood)\n",
@@ -286,7 +286,7 @@
 "\n",
 " # Runs forward-in-time to draw samples from the CNF.\n",
 " def sample(self, *, key):\n",
-" y = jrandom.normal(key, (self.data_size,))\n",
+" y = jr.normal(key, (self.data_size,))\n",
 " for func in self.funcs:\n",
 " term = diffrax.ODETerm(func)\n",
 " solver = diffrax.Tsit5()\n",
@@ -300,7 +300,7 @@
 " t_so_far = self.t0\n",
 " t_end = self.t0 + (self.t1 - self.t0) * len(self.funcs)\n",
 " save_times = jnp.linspace(self.t0, t_end, 6)\n",
-" y = jrandom.normal(key, (self.data_size,))\n",
+" y = jr.normal(key, (self.data_size,))\n",
 " out = []\n",
 " for i, func in enumerate(self.funcs):\n",
 " if i == len(self.funcs) - 1:\n",
@@ -404,7 +404,7 @@
 "class DataLoader(eqx.Module):\n",
 " arrays: tuple[jnp.ndarray, ...]\n",
 " batch_size: int\n",
-" key: jrandom.PRNGKey\n",
+" key: jr.PRNGKey\n",
 "\n",
 " def __check_init__(self):\n",
 " dataset_size = self.arrays[0].shape[0]\n",
@@ -414,8 +414,8 @@
 " dataset_size = self.arrays[0].shape[0]\n",
 " num_batches = dataset_size // self.batch_size\n",
 " epoch = step // num_batches\n",
-" key = jrandom.fold_in(self.key, epoch)\n",
-" perm = jrandom.permutation(key, jnp.arange(dataset_size))\n",
+" key = jr.fold_in(self.key, epoch)\n",
+" perm = jr.permutation(key, jnp.arange(dataset_size))\n",
 " start = (step % num_batches) * self.batch_size\n",
 " slice_size = self.batch_size\n",
 " batch_indices = lax.dynamic_slice_in_dim(perm, start, slice_size)\n",
@@ -464,8 +464,8 @@
 " else:\n",
 " out_path = pathlib.Path(out_path)\n",
 "\n",
-" key = jrandom.PRNGKey(seed)\n",
-" model_key, loader_key, loss_key, sample_key = jrandom.split(key, 4)\n",
+" key = jr.PRNGKey(seed)\n",
+" model_key, loader_key, loss_key, sample_key = jr.split(key, 4)\n",
 "\n",
 " dataset, weights, mean, std, img, width, height = get_data(in_path)\n",
 " dataset_size, data_size = dataset.shape\n",
@@ -486,9 +486,9 @@
 " @eqx.filter_value_and_grad\n",
 " def loss(model, data, weight, loss_key):\n",
 " batch_size, _ = data.shape\n",
-" noise_key, train_key = jrandom.split(loss_key, 2)\n",
-" train_key = jrandom.split(key, batch_size)\n",
-" data = data + jrandom.normal(noise_key, data.shape) * 0.5 / std\n",
+" noise_key, train_key = jr.split(loss_key, 2)\n",
+" train_key = jr.split(key, batch_size)\n",
+" data = data + jr.normal(noise_key, data.shape) * 0.5 / std\n",
 " log_likelihood = jax.vmap(model.train)(data, key=train_key)\n",
 " return -jnp.mean(weight * log_likelihood) # minimise negative log-likelihood\n",
 "\n",
@@ -514,7 +514,7 @@
 " value = value + value_\n",
 " grads = jax.tree_util.tree_map(lambda a, b: a + b, grads, grads_)\n",
 " step = step + 1\n",
-" loss_key = jrandom.split(loss_key, 1)[0]\n",
+" loss_key = jr.split(loss_key, 1)[0]\n",
 " return value, grads, step, loss_key\n",
 "\n",
 " value, grads, step, loss_key = lax.fori_loop(\n",
@@ -537,7 +537,7 @@
 " print(f\"Step: {step}, Loss: {value}, Computation time: {end - start}\")\n",
 "\n",
 " num_samples = 5000\n",
-" sample_key = jrandom.split(sample_key, num_samples)\n",
+" sample_key = jr.split(sample_key, num_samples)\n",
 " samples = jax.vmap(model.sample)(key=sample_key)\n",
 " sample_flows = jax.vmap(model.sample_flow, out_axes=-1)(key=sample_key)\n",
 " fig, (*axs, ax, axtrue) = plt.subplots(\n",

examples/latent_ode.ipynb

Lines changed: 14 additions & 14 deletions
@@ -78,7 +78,7 @@
 "import jax\n",
 "import jax.nn as jnn\n",
 "import jax.numpy as jnp\n",
-"import jax.random as jrandom\n",
+"import jax.random as jr\n",
 "import matplotlib\n",
 "import matplotlib.pyplot as plt\n",
 "import numpy as np\n",
@@ -142,7 +142,7 @@
 " ):\n",
 " super().__init__(**kwargs)\n",
 "\n",
-" mkey, gkey, hlkey, lhkey, hdkey = jrandom.split(key, 5)\n",
+" mkey, gkey, hlkey, lhkey, hdkey = jr.split(key, 5)\n",
 "\n",
 " scale = jnp.ones(())\n",
 " mlp = eqx.nn.MLP(\n",
@@ -175,7 +175,7 @@
 " context = self.hidden_to_latent(hidden)\n",
 " mean, logstd = context[: self.latent_size], context[self.latent_size :]\n",
 " std = jnp.exp(logstd)\n",
-" latent = mean + jrandom.normal(key, (self.latent_size,)) * std\n",
+" latent = mean + jr.normal(key, (self.latent_size,)) * std\n",
 " return latent, mean, std\n",
 "\n",
 " # Decoder of the VAE\n",
@@ -209,7 +209,7 @@
 "\n",
 " # Run just the decoder during inference.\n",
 " def sample(self, ts, *, key):\n",
-" latent = jrandom.normal(key, (self.latent_size,))\n",
+" latent = jr.normal(key, (self.latent_size,))\n",
 " return self._sample(ts, latent)"
 ]
 },
@@ -231,13 +231,13 @@
 "outputs": [],
 "source": [
 "def get_data(dataset_size, *, key):\n",
-" ykey, tkey1, tkey2 = jrandom.split(key, 3)\n",
+" ykey, tkey1, tkey2 = jr.split(key, 3)\n",
 "\n",
-" y0 = jrandom.normal(ykey, (dataset_size, 2))\n",
+" y0 = jr.normal(ykey, (dataset_size, 2))\n",
 "\n",
 " t0 = 0\n",
-" t1 = 2 + jrandom.uniform(tkey1, (dataset_size,))\n",
-" ts = jrandom.uniform(tkey2, (dataset_size, 20)) * (t1[:, None] - t0) + t0\n",
+" t1 = 2 + jr.uniform(tkey1, (dataset_size,))\n",
+" ts = jr.uniform(tkey2, (dataset_size, 20)) * (t1[:, None] - t0) + t0\n",
 " ts = jnp.sort(ts)\n",
 " dt0 = 0.1\n",
 "\n",
@@ -273,8 +273,8 @@
 " assert all(array.shape[0] == dataset_size for array in arrays)\n",
 " indices = jnp.arange(dataset_size)\n",
 " while True:\n",
-" perm = jrandom.permutation(key, indices)\n",
-" (key,) = jrandom.split(key, 1)\n",
+" perm = jr.permutation(key, indices)\n",
+" (key,) = jr.split(key, 1)\n",
 " start = 0\n",
 " end = batch_size\n",
 " while start < dataset_size:\n",
@@ -311,8 +311,8 @@
 " depth=2,\n",
 " seed=5678,\n",
 "):\n",
-" key = jrandom.PRNGKey(seed)\n",
-" data_key, model_key, loader_key, train_key, sample_key = jrandom.split(key, 5)\n",
+" key = jr.PRNGKey(seed)\n",
+" data_key, model_key, loader_key, train_key, sample_key = jr.split(key, 5)\n",
 "\n",
 " ts, ys = get_data(dataset_size, key=data_key)\n",
 "\n",
@@ -328,14 +328,14 @@
 " @eqx.filter_value_and_grad\n",
 " def loss(model, ts_i, ys_i, key_i):\n",
 " batch_size, _ = ts_i.shape\n",
-" key_i = jrandom.split(key_i, batch_size)\n",
+" key_i = jr.split(key_i, batch_size)\n",
 " loss = jax.vmap(model.train)(ts_i, ys_i, key=key_i)\n",
 " return jnp.mean(loss)\n",
 "\n",
 " @eqx.filter_jit\n",
 " def make_step(model, opt_state, ts_i, ys_i, key_i):\n",
 " value, grads = loss(model, ts_i, ys_i, key_i)\n",
-" key_i = jrandom.split(key_i, 1)[0]\n",
+" key_i = jr.split(key_i, 1)[0]\n",
 " updates, opt_state = optim.update(grads, opt_state)\n",
 " model = eqx.apply_updates(model, updates)\n",
 " return value, model, opt_state, key_i\n",

examples/neural_cde.ipynb

Lines changed: 9 additions & 9 deletions
@@ -61,7 +61,7 @@
 "import jax\n",
 "import jax.nn as jnn\n",
 "import jax.numpy as jnp\n",
-"import jax.random as jrandom\n",
+"import jax.random as jr\n",
 "import jax.scipy as jsp\n",
 "import matplotlib\n",
 "import matplotlib.pyplot as plt\n",
@@ -136,7 +136,7 @@
 "\n",
 " def __init__(self, data_size, hidden_size, width_size, depth, *, key, **kwargs):\n",
 " super().__init__(**kwargs)\n",
-" ikey, fkey, lkey = jrandom.split(key, 3)\n",
+" ikey, fkey, lkey = jr.split(key, 3)\n",
 " self.initial = eqx.nn.MLP(data_size, hidden_size, width_size, depth, key=ikey)\n",
 " self.func = Func(data_size, hidden_size, width_size, depth, key=fkey)\n",
 " self.linear = eqx.nn.Linear(hidden_size, 1, key=lkey)\n",
@@ -195,9 +195,9 @@
 "outputs": [],
 "source": [
 "def get_data(dataset_size, add_noise, *, key):\n",
-" theta_key, noise_key = jrandom.split(key, 2)\n",
+" theta_key, noise_key = jr.split(key, 2)\n",
 " length = 100\n",
-" theta = jrandom.uniform(theta_key, (dataset_size,), minval=0, maxval=2 * math.pi)\n",
+" theta = jr.uniform(theta_key, (dataset_size,), minval=0, maxval=2 * math.pi)\n",
 " y0 = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=-1)\n",
 " ts = jnp.broadcast_to(jnp.linspace(0, 4 * math.pi, length), (dataset_size, length))\n",
 " matrix = jnp.array([[-0.3, 2], [-2, -0.3]])\n",
@@ -207,7 +207,7 @@
 " ys = jnp.concatenate([ts[:, :, None], ys], axis=-1) # time is a channel\n",
 " ys = ys.at[: dataset_size // 2, :, 1].multiply(-1)\n",
 " if add_noise:\n",
-" ys = ys + jrandom.normal(noise_key, ys.shape) * 0.1\n",
+" ys = ys + jr.normal(noise_key, ys.shape) * 0.1\n",
 " coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts, ys)\n",
 " labels = jnp.zeros((dataset_size,))\n",
 " labels = labels.at[: dataset_size // 2].set(1.0)\n",
@@ -227,8 +227,8 @@
 " assert all(array.shape[0] == dataset_size for array in arrays)\n",
 " indices = jnp.arange(dataset_size)\n",
 " while True:\n",
-" perm = jrandom.permutation(key, indices)\n",
-" (key,) = jrandom.split(key, 1)\n",
+" perm = jr.permutation(key, indices)\n",
+" (key,) = jr.split(key, 1)\n",
 " start = 0\n",
 " end = batch_size\n",
 " while end < dataset_size:\n",
@@ -264,8 +264,8 @@
 " depth=1,\n",
 " seed=5678,\n",
 "):\n",
-" key = jrandom.PRNGKey(seed)\n",
-" train_data_key, test_data_key, model_key, loader_key = jrandom.split(key, 4)\n",
+" key = jr.PRNGKey(seed)\n",
+" train_data_key, test_data_key, model_key, loader_key = jr.split(key, 4)\n",
 "\n",
 " ts, coeffs, labels, data_size = get_data(\n",
 " dataset_size, add_noise, key=train_data_key\n",
