Skip to content

Commit

Permalink
mnist example modifications
Browse files Browse the repository at this point in the history
  • Loading branch information
tflahaul committed Aug 27, 2022
1 parent b1f026a commit 5b40f35
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
models/

# C extensions
*.so
Expand Down
21 changes: 11 additions & 10 deletions examples/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
BATCH_SIZE = 256
CKPT_DIR = 'models'

train_set = tfds.as_numpy(tfds.load('mnist', split='train', batch_size=BATCH_SIZE, as_supervised=True))
test_set = tfds.as_numpy(tfds.load('mnist', split='test', batch_size=-1, as_supervised=True))
train_set = tfds.as_numpy(tfds.load('mnist', split='train', batch_size=BATCH_SIZE, as_supervised=True, data_dir='/tmp'))
test_set = tfds.as_numpy(tfds.load('mnist', split='test', batch_size=-1, as_supervised=True, data_dir='/tmp'))

@jax.jit
def apply_model(state, images, labels, key):
Expand All @@ -33,28 +33,29 @@ def loss_fn(parameters):

@partial(jax.jit, static_argnames='infer_fn')
def accuracy(parameters, infer_fn) -> float:
return jnp.mean(jnp.argmax(infer_fn(parameters, test_set[0]), -1) == test_set[1])
images, labels = test_set
return jnp.mean(jnp.argmax(infer_fn(parameters, images), -1) == labels)

def create_train_state(rng: jax.random.KeyArray, f: nn.Module):
parameters = jax.jit(f.init)(rng, jnp.ones((1, 28, 28, 1)))
optimizer = optax.sgd(LEARNING_RATE, momentum=MOMENTUM)
return train_state.TrainState.create(
apply_fn=f.apply,
apply_fn=jax.jit(f.apply),
params=parameters,
tx=optimizer)

def main() -> None:
train_fn = ViT(28, 10, 48, 3, 8, 192, enable_dropout=True, dropout_rate=0.1)
infer_fn = ViT(28, 10, 48, 3, 8, 192, enable_dropout=False).apply
key_p, key_d = jax.random.split(jax.random.PRNGKey(seed=3407))
state = create_train_state({'params': key_p, 'dropout': key_d}, train_fn)
train_fn = ViT(28, 10, 48, 3, 6, 192, enable_dropout=True, dropout_rate=0.1)
infer_fn = ViT(28, 10, 48, 3, 6, 192, enable_dropout=False).apply
kp, kd = jax.random.split(jax.random.PRNGKey(seed=3407))
state = create_train_state({'params': kp, 'dropout': kd}, train_fn)
print(f'Number of parameters: {sum(x.size for x in jax.tree_util.tree_leaves(state.params))}')
for epoch in range(1, MAX_ITER + 1):
running_loss, start = 0, time.time()
for images, labels in train_set:
loss, gradients = apply_model(state, images, labels, key_d)
loss, gradients = apply_model(state, images, labels, kd)
state = state.apply_gradients(grads=gradients)
_, key_d = jax.random.split(key_d)
kd, _ = jax.random.split(kd)
running_loss = running_loss + loss
acc = accuracy(state.params, infer_fn)
print(f'epoch {epoch:>2d}/{MAX_ITER}| loss={running_loss:.4f}, accuracy={acc:.3f}, time={time.time() - start:.2f}')
Expand Down

0 comments on commit 5b40f35

Please sign in to comment.