Skip to content

Commit

Permalink
Add JFT ViT training script tests (smoke test + reproducibility).
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 380287468
  • Loading branch information
dusenberrymw authored and copybara-github committed Jun 23, 2021
1 parent 77904f4 commit 9fb506c
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 11 deletions.
40 changes: 30 additions & 10 deletions baselines/jft/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training loop."""
"""Deterministic ViT on JFT-300M."""

from functools import partial # pylint: disable=g-importing-member so standard
import importlib
Expand All @@ -29,13 +29,10 @@
import flax
import flax.jax_utils as flax_utils
import jax
import jax.config
import jax.nn
import jax.numpy as jnp
import jax.profiler
import ml_collections
import numpy as np
import tensorflow as tf
from tensorflow.io import gfile

import dune.experimental.big_vision.fewshot as fewshot
import dune.experimental.big_vision.input_pipeline as input_pipeline
Expand Down Expand Up @@ -67,11 +64,14 @@ def main(argv):
config = FLAGS.config
workdir = FLAGS.workdir

if config.get("dataset_dir"):
logging.info("data_dir=%s", config.dataset_dir)
logging.info("data_dir contents: %s", os.listdir(config.dataset_dir))
logging.info("Workdir: %s", workdir)

save_checkpoint_path = None
if config.get("checkpoint_steps"):
tf.io.gfile.makedirs(workdir)
gfile.makedirs(workdir)
save_checkpoint_path = os.path.join(workdir, "checkpoint.npz")

# The pool is used to perform misc operations such as logging in async way.
Expand Down Expand Up @@ -306,7 +306,7 @@ def decay_fn(v, wd):
# 3. Initialize model from something, e.g., start a fine-tuning job.
# 4. Train from scratch.
resume_checkpoint_path = None
if save_checkpoint_path and tf.io.gfile.exists(save_checkpoint_path):
if save_checkpoint_path and gfile.exists(save_checkpoint_path):
resume_checkpoint_path = save_checkpoint_path
elif config.get("resume"):
resume_checkpoint_path = fillin(config.resume)
Expand Down Expand Up @@ -340,6 +340,8 @@ def decay_fn(v, wd):
# Prepare the learning-rate and pre-fetch it to device to avoid delays.
lr_fn = u.create_learning_rate_schedule(
batch_size, total_steps, steps_per_epoch, **config.get("lr", {}))
# TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
# necessary for TPUs.
lr_iter = u.prefetch_scalar(map(lr_fn, range(first_step, total_steps)),
config.get("prefetch_to_device", 1))

Expand All @@ -356,6 +358,12 @@ def decay_fn(v, wd):
rngs_loop = flax_utils.replicate(rng_loop)
checkpoint_writer = None

# Note: we return the train loss, val loss, and fewshot best l2s for use in
# reproducibility unit tests.
train_loss = -jnp.inf
val_loss = -jnp.inf
best_l2 = {1: -jnp.inf}

write_note(f"First step compilations...\n{chrono.note}")
# Using a python integer for step here, because opt.state.step is allocated
# on TPU during replication.
Expand All @@ -376,6 +384,7 @@ def decay_fn(v, wd):

# Checkpoint saving
if u.itstime(step, config.get("checkpoint_steps"), total_steps, host=0):
write_note("Checkpointing...")
chrono.pause()
u.checkpointing_timeout(checkpoint_writer,
config.get("checkpoint_timeout", 1))
Expand All @@ -388,6 +397,7 @@ def decay_fn(v, wd):
# Check whether we want to keep a copy of the current checkpoint.
copy_step = None
if u.itstime(step, config.get("keep_checkpoint_steps"), total_steps):
write_note("Keeping a checkpoint copy...")
copy_step = step

# Checkpoint should be a nested dictionary or FLAX dataclasses from
Expand All @@ -399,6 +409,8 @@ def decay_fn(v, wd):

# Report training progress
if u.itstime(step, config.log_training_steps, total_steps, host=0):
write_note("Reporting training progress...")
train_loss = loss_value[0] # Keep to return for reproducibility tests.
mw.measure("learning_rate", lr_repl[0])
mw.measure("training_loss", loss_value[0])
for name, value in extra_measurements.items():
Expand All @@ -407,6 +419,7 @@ def decay_fn(v, wd):

# Report validation performance
if u.itstime(step, config.log_eval_steps, total_steps):
write_note("Evaluating on the validation set...")
chrono.pause()
for val_name, (val_iter, val_steps) in val_ds.items():
ncorrect, loss, nseen = 0, 0, 0
Expand All @@ -420,17 +433,20 @@ def decay_fn(v, wd):
ncorrect += np.sum(np.array(batch_ncorrect[0]))
loss += np.sum(np.array(batch_losses[0]))
nseen += np.sum(np.array(batch_n[0]))
val_loss = loss / nseen # Keep to return for reproducibility tests.
mw.measure(f"{val_name}_prec@1", ncorrect / nseen)
mw.measure(f"{val_name}_loss", loss / nseen)
mw.measure(f"{val_name}_loss", val_loss)
chrono.resume()

if "fewshot" in config:
# Compute few-shot on-the-fly evaluation.
if u.itstime(step, config.fewshot.log_steps, total_steps):
chrono.pause()
write_note(f"Few-shot evaluation...\n{chrono.note}")
r = fewshotter.run_all(opt_repl.target, config.fewshot.datasets)
fewshotter.walk_results(mw.measure, *r)
# Keep `best_l2` to return for reproducibility tests.
results, best_l2 = fewshotter.run_all(opt_repl.target,
config.fewshot.datasets)
fewshotter.walk_results(mw.measure, results, best_l2)
chrono.resume()
mw.step_end()

Expand All @@ -439,6 +455,10 @@ def decay_fn(v, wd):
pool.join()
mw.close()

# Return final training loss, validation loss, and fewshot best l2s for
# reproducibility test cases.
return train_loss, val_loss, best_l2


if __name__ == "__main__":
app.run(main)
146 changes: 146 additions & 0 deletions baselines/jft/deterministic_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# coding=utf-8
# Copyright 2021 The Uncertainty Baselines Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the deterministic ViT on JFT-300M model script."""
import os
import pathlib
import shutil
import tempfile

from absl import flags
from absl import logging
from absl.testing import parameterized
import ml_collections
import tensorflow as tf
import tensorflow_datasets as tfds
import deterministic # local file import

flags.adopt_module_key_flags(deterministic)
FLAGS = flags.FLAGS


def get_config():
  """Returns a miniature training config for testing the ViT script.

  All sizes (batch size, step count, model widths, shuffle buffer) are tiny,
  and every periodic action (training logs, validation, checkpointing and
  few-shot evaluation) is scheduled for every step.

  Returns:
    An `ml_collections.ConfigDict` holding the dataset, preprocessing, model,
    optimizer, learning-rate and few-shot evaluation settings.
  """
  num_classes = 21843

  # Label handling appended to every image preprocessing pipeline. The leading
  # empty entry yields the '|' that joins it onto the image ops.
  label_ops = '|'.join([
      '',
      'value_range(-1, 1)',
      f'onehot({num_classes})',
      'keep("image", "labels")',
  ])
  # TODO(dusenberrymw): Mocking doesn't seem to encode into jpeg format, so
  # training cannot use 'decode_jpeg_and_inception_crop(224)|flip_lr' and
  # shares the evaluation pipeline instead.
  image_pipeline = 'decode|resize_small(256)|central_crop(224)' + label_ops
  fewshot_pipeline = 'decode|resize(256)|central_crop(224)|value_range(-1,1)'
  fewshot_shots = [1]

  config = ml_collections.ConfigDict()

  # Dataset section.
  # TODO(dusenberrymw): JFT + mocking is broken. The intended settings are:
  #   dataset = 'jft/entity:1.0.0'
  #   val_split = 'test[:49511]'  # aka tiny_test/test[:5%] in task_adapt
  #   train_split = 'train'  # task_adapt used train+validation so +64167
  #   num_classes = 18291
  config.dataset = 'imagenet21k'
  config.val_split = 'full[:100]'
  config.train_split = 'full[100:200]'
  config.num_classes = num_classes

  config.init_head_bias = -10.0

  # Run sizing.
  config.trial = 0
  config.batch_size = 3
  config.total_steps = 1

  config.prefetch_to_device = 1

  # Input preprocessing.
  config.pp_train = image_pipeline
  config.pp_eval = image_pipeline
  config.shuffle_buffer_size = 10

  # Run every periodic action on every step.
  config.log_training_steps = 1
  config.log_eval_steps = 1
  config.checkpoint_steps = 1
  config.keep_checkpoint_steps = 1

  # Model section.
  config.model_name = 'resformer'
  config.model = ml_collections.ConfigDict({
      'resnet': None,
      'patches': {'size': [16, 16]},
      'hidden_size': 2,
      'transformer': {
          'attention_dropout_rate': 0.,
          'dropout_rate': 0.,
          'mlp_dim': 2,
          'num_heads': 1,
          'num_layers': 1,
      },
      'classifier': 'token',  # Or 'gap'
      'representation_size': 2,
  })

  # Optimizer section.
  config.optim_name = 'Adam'
  config.optim = ml_collections.ConfigDict({
      'weight_decay': 0.1,
      'beta1': 0.9,
      'beta2': 0.999,
  })
  config.weight_decay = None  # No explicit weight decay

  # Learning-rate schedule section.
  config.lr = ml_collections.ConfigDict({
      'base': 0.001,
      'warmup_steps': 2,
      'decay_type': 'linear',
      'linear_end': 1e-5,
  })

  # Few-shot eval section.
  config.fewshot = ml_collections.ConfigDict({
      'representation_layer': 'pre_logits',
      'log_steps': 1,
      'datasets': {
          'pets': ('oxford_iiit_pet', 'train', 'test'),
          'imagenet': ('imagenet2012_subset/10pct', 'train', 'validation'),
      },
      'pp_train': fewshot_pipeline,
      'pp_eval': fewshot_pipeline,
      'shots': fewshot_shots,
      'l2_regs': [2.0**-7],
      'walk_first': ('imagenet', fewshot_shots[0]),
  })

  return config


class DeterministicTest(parameterized.TestCase, tf.test.TestCase):
  """Smoke and reproducibility test for the deterministic ViT script."""

  def test_deterministic_script(self):
    """Runs the script once on mocked data and checks the returned metrics."""
    # Configure the script through its flags; the workdir is a fresh
    # directory inside this test's temp dir.
    FLAGS.xm_runlocal = True
    FLAGS.config = get_config()
    FLAGS.workdir = tempfile.mkdtemp(dir=self.get_temp_dir())

    # The TFDS metadata directory is resolved from the root of the UB
    # directory, two levels above this file.
    repo_root = pathlib.Path(__file__).parents[2]
    data_dir = f'{repo_root}/.tfds/metadata'
    logging.info('data_dir contents: %s', os.listdir(data_dir))
    FLAGS.config.dataset_dir = data_dir

    # Smoke test: the whole script must run without raising on mocked data.
    with tfds.testing.mock_data(num_examples=100, data_dir=data_dir):
      metrics = deterministic.main(None)
    train_loss, val_loss, fewshot_best_l2s = metrics

    # Reproducibility: the returned metrics must match the recorded values.
    self.assertAllClose(train_loss, 224.325)
    self.assertAllClose(val_loss, 269.692)
    first_best_l2 = list(fewshot_best_l2s.values())[0]
    self.assertAllClose(first_best_l2, 0.0078125)

    # TODO(dusenberrymw): Check for ability to restart from previous checkpoint
    # (after failure, etc.).


if __name__ == '__main__':
tf.test.main()
2 changes: 1 addition & 1 deletion uncertainty_baselines/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def _load(self,
if self._download_data:
self._dataset_builder.download_and_prepare()
dataset = self._dataset_builder.as_dataset(
self._split, decoders=self._decoders)
split=self._split, decoders=self._decoders)

# Possibly cache the original dataset before preprocessing is applied.
if self._cache:
Expand Down

0 comments on commit 9fb506c

Please sign in to comment.