-
Notifications
You must be signed in to change notification settings - Fork 0
/
c8002_jax_example.py
48 lines (39 loc) · 1.3 KB
/
c8002_jax_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import haiku as hk
import jax.numpy as jnp
import jax
def load_dataset(
    split: str,
    *,
    is_training: bool,
    batch_size: int,
):
    """Load MNIST as an endless iterator of numpy batches.

    Args:
      split: TFDS split name to load, e.g. "train" or "test".
      is_training: If true, shuffle examples (buffer of ``10 * batch_size``,
        fixed seed 0) before batching.
      batch_size: Number of examples per batch.

    Returns:
      An iterator produced by ``tfds.as_numpy``. The dataset is cached and
      repeated before batching, so it never ends and every batch holds exactly
      ``batch_size`` examples. Each batch is a dict of numpy arrays (for MNIST
      the keys are "image" and "label").
    """
    # The module-level imports do not bring `tfds` into scope, so import it
    # here; without this the first call raises NameError.
    import tensorflow_datasets as tfds

    ds = tfds.load("mnist:3.*.*", split=split).cache().repeat()
    if is_training:
        ds = ds.shuffle(10 * batch_size, seed=0)
    ds = ds.batch(batch_size)
    return iter(tfds.as_numpy(ds))
# Training stream: an endless, shuffled iterator of MNIST batches (100 each).
train = load_dataset("train", is_training=True, batch_size=100)
# `input_dataset` names the same batch iterator. It must be the iterator, not
# the `load_dataset` function: callers use `next()` on it and iterate it.
input_dataset = train
def softmax_cross_entropy(logits, labels):
    """Per-example cross-entropy between softmax(logits) and integer labels.

    The number of classes is the size of the last axis of ``logits``.
    ``labels`` holds integer class ids whose one-hot encoding broadcasts
    against ``logits``. Only the class axis is reduced, so the result has the
    shape of ``logits`` without its last axis.
    """
    num_classes = logits.shape[-1]
    log_probs = jax.nn.log_softmax(logits)
    targets = jax.nn.one_hot(labels, num_classes)
    # Keep only the log-probability of the target class for each example.
    picked = (log_probs * targets).sum(axis=-1)
    return -picked
def loss_fn(images, labels):
    """Mean softmax cross-entropy of a 300-100-10 ReLU MLP on one batch.

    Must be called inside ``hk.transform`` because it builds Haiku modules.

    Args:
      images: Batch of images with the batch on the leading axis. All
        trailing axes are flattened into one feature vector per example.
        Integer inputs (e.g. raw uint8 pixels) are cast to float32 and scaled
        by 1/255; floating-point inputs are used as-is.
      labels: Integer class ids, one per example.

    Returns:
      Scalar mean cross-entropy over the batch.
    """
    # hk.Linear acts only on the last axis, so an unflattened (batch, H, W, C)
    # image would be projected per pixel instead of per image.
    x = jnp.reshape(images, (images.shape[0], -1))
    # hk.Linear creates its parameters with the input's dtype, so integer
    # pixels must become floats first.
    # NOTE(review): assumes integer images are 0-255 pixel values.
    if jnp.issubdtype(x.dtype, jnp.integer):
        x = x.astype(jnp.float32) / 255.0
    mlp = hk.Sequential([
        hk.Linear(300), jax.nn.relu,
        hk.Linear(100), jax.nn.relu,
        hk.Linear(10),
    ])
    logits = mlp(x)
    return jnp.mean(softmax_cross_entropy(logits, labels))
#### core routine for training
# Turn `loss_fn` (which builds Haiku modules) into a pure (init, apply) pair.
# The MLP has no random layers, so `without_apply_rng` drops the rng argument
# from `apply`.
loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)
rng = jax.random.PRNGKey(42)

# `train` yields dicts of numpy arrays (tfds.as_numpy over MNIST feature
# dicts). Tuple-unpacking a dict would bind its *keys*, so index it instead.
init_batch = next(train)
dummy_images, dummy_labels = init_batch["image"], init_batch["label"]
params = loss_fn_t.init(rng, dummy_images, dummy_labels)


def update_rule(param, update, learning_rate=0.01):
    """Return ``param`` after one plain SGD step along ``update``."""
    return param - learning_rate * update


# Build and compile the gradient function once and reuse it every step. Batch
# shapes stay constant because the dataset is repeated before it is batched.
loss_grad = jax.jit(jax.grad(loss_fn_t.apply))

# `train` repeats forever, so bound the number of SGD steps explicitly.
NUM_TRAIN_STEPS = 1_000
for _ in range(NUM_TRAIN_STEPS):
    batch = next(train)
    images, labels = batch["image"], batch["label"]
    grads = loss_grad(params, images, labels)
    # `jax.tree_map` is deprecated/removed; `jax.tree_util.tree_map` is stable.
    params = jax.tree_util.tree_map(update_rule, params, grads)