Skip to content

Commit

Permalink
Correctly forward config.optim_dtype.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 555396325
  • Loading branch information
andsteing authored and copybara-github committed Aug 10, 2023
1 parent ac6e056 commit a1d4cce
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions vit_jax/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def init_model():
optax.sgd(
learning_rate=lr_fn,
momentum=0.9,
accumulator_dtype='bfloat16',
accumulator_dtype=config.optim_dtype,
),
)

Expand Down Expand Up @@ -224,7 +224,6 @@ def init_model():
img_sec_core_test = (
config.batch_eval * ds_test.cardinality().numpy() /
(time.time() - lt0) / jax.device_count())
lt0 = time.time()

lr = float(lr_fn(step))
logging.info(f'Step: {step} ' # pylint: disable=logging-fstring-interpolation
Expand All @@ -237,6 +236,7 @@ def init_model():
accuracy_test=accuracy_test,
lr=lr,
img_sec_core_test=img_sec_core_test))
lt0, lstep = time.time(), step

# Store checkpoint.
if ((config.checkpoint_every and step % config.eval_every == 0) or
Expand All @@ -246,5 +246,6 @@ def init_model():
flax.jax_utils.unreplicate(opt_state_repl), step), step)
logging.info('Stored checkpoint at step %d to "%s"', step,
checkpoint_path)
lt0, lstep = time.time(), step

return flax.jax_utils.unreplicate(params_repl)

0 comments on commit a1d4cce

Please sign in to comment.