use the lion optimizer in MNIST example
thomas-meanwhile committed Oct 10, 2023
1 parent 13266c9 commit e1e6ebe
Showing 1 changed file with 5 additions and 4 deletions.
examples/mnist.py: 5 additions & 4 deletions
@@ -1,3 +1,5 @@
+# You need to have tensorflow and tensorflow-datasets installed to run this example
+
 import tensorflow_datasets as tfds
 import optax
 import time
@@ -10,8 +12,7 @@

 from jvt import ViT

-LEARNING_RATE = 2e-3
-MOMENTUM = 0.9
+LEARNING_RATE = 2e-4
 MAX_ITER = 8
 BATCH_SIZE = 256
 CKPT_DIR = 'checkpoints'
@@ -36,9 +37,9 @@ def accuracy(parameters, infer_fn) -> float:
     return jnp.mean(jnp.argmax(infer_fn(parameters, images), -1) == labels)


-def create_train_state(rng: jax.random.KeyArray, f: nn.Module):
+def create_train_state(rng: jax.Array, f: nn.Module):
     parameters = jax.jit(f.init)(rng, jnp.ones((1, 28, 28, 1)))
-    optimizer = optax.sgd(LEARNING_RATE, momentum=MOMENTUM)
+    optimizer = optax.lion(LEARNING_RATE)
     return train_state.TrainState.create(
         apply_fn=jax.jit(f.apply),
         params=parameters,

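For context on the change (not part of the commit): optax.lion implements the Lion optimizer (Chen et al., 2023, "Symbolic Discovery of Optimization Algorithms"), whose update is the sign of an interpolation between the momentum and the current gradient. Because every coordinate then moves with a fixed magnitude, Lion is typically run with a much smaller learning rate than SGD or Adam, which matches the drop from 2e-3 to 2e-4 in this diff. A minimal sketch of the per-array update rule, with placeholder names and assumed coefficients (b1, b2, weight_decay are illustrative, not values from the commit):

# Illustrative sketch of the update that optax.lion applies per parameter array.
# Names and hyperparameter values are placeholders, not taken from examples/mnist.py.
import jax.numpy as jnp

def lion_step(params, grads, momentum, lr=2e-4, b1=0.9, b2=0.99, weight_decay=0.0):
    # Update direction: sign of an interpolation between momentum and the gradient.
    direction = jnp.sign(b1 * momentum + (1.0 - b1) * grads)
    # Decoupled weight decay, then a fixed-magnitude step along the sign direction.
    new_params = params - lr * (direction + weight_decay * params)
    # Momentum tracks the gradient with a second, slower interpolation coefficient.
    new_momentum = b2 * momentum + (1.0 - b2) * grads
    return new_params, new_momentum

In the example itself the swap stays at the optax level: optax.lion(LEARNING_RATE) returns a GradientTransformation, so it drops into train_state.TrainState.create exactly where optax.sgd did.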