Skip to content

Commit

Permalink
Add eval_checkpoint_dir flags to import the checkpoint directory for `deterministic.py` and `sngp.py`.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 335086648
  • Loading branch information
zi-lin authored and copybara-github committed Oct 2, 2020
1 parent 0c896c9 commit adc2d41
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
16 changes: 15 additions & 1 deletion baselines/toxic_comments/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@

# Prediction mode.
# When `prediction_mode` is set, the checkpoint to restore is looked up in
# `eval_checkpoint_dir` instead of `output_dir`.
flags.DEFINE_bool('prediction_mode', False, 'Whether to predict only.')
# Defaults to None; a flags validator in this file rejects prediction mode
# when this flag is left unset.
flags.DEFINE_string('eval_checkpoint_dir', None,
'The directory to restore the model weights from for '
'prediction mode.')

FLAGS = flags.FLAGS

Expand All @@ -121,6 +124,14 @@
'psychiatric_or_mental_illness', 'other_disability')


@flags.multi_flags_validator(
    ['prediction_mode', 'eval_checkpoint_dir'],
    message='`eval_checkpoint_dir` should be provided in prediction mode')
def _check_checkpoint_dir_for_prediction_mode(flags_dict):
  """Checks that prediction mode is accompanied by a checkpoint directory.

  Args:
    flags_dict: Mapping from flag name to its value for the flags listed in
      the decorator.

  Returns:
    True when `prediction_mode` is off; otherwise whether
    `eval_checkpoint_dir` has been set to something other than None.
  """
  if not flags_dict['prediction_mode']:
    # Outside prediction mode the checkpoint directory is optional.
    return True
  return flags_dict['eval_checkpoint_dir'] is not None


def save_prediction(data, path):
  """Saves `data` as a NumPy array to `path` with '.npy' appended.

  Args:
    data: Values to store; converted with `np.array` before saving.
    path: Destination path without the '.npy' suffix.
  """
  output_path = path + '.npy'
  with tf.io.gfile.GFile(output_path, 'w') as npy_file:
    np.save(npy_file, np.array(data))
Expand Down Expand Up @@ -257,7 +268,10 @@ def main(argv):
}

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
if FLAGS.prediction_mode:
latest_checkpoint = tf.train.latest_checkpoint(FLAGS.eval_checkpoint_dir)
else:
latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
initial_epoch = 0
if latest_checkpoint:
# checkpoint.restore must be within a strategy.scope() so that optimizer
Expand Down
26 changes: 20 additions & 6 deletions baselines/toxic_comments/sngp.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@

# Optimization and evaluation flags
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_integer('per_core_batch_size', 64, 'Batch size per TPU core/GPU.')
flags.DEFINE_integer('eval_batch_size', 512, 'Batch size for CPU evaluation.')
flags.DEFINE_integer('per_core_batch_size', 32, 'Batch size per TPU core/GPU.')
flags.DEFINE_float(
'base_learning_rate', 5e-5,
'Base learning rate when total batch size is 128. It is '
Expand Down Expand Up @@ -170,6 +169,9 @@

# Prediction mode.
# When `prediction_mode` is set, the checkpoint to restore is looked up in
# `eval_checkpoint_dir` instead of `output_dir`.
flags.DEFINE_bool('prediction_mode', False, 'Whether to predict only.')
# Defaults to None; a flags validator in this file rejects prediction mode
# when this flag is left unset.
flags.DEFINE_string('eval_checkpoint_dir', None,
'The directory to restore the model weights from for '
'prediction mode.')

FLAGS = flags.FLAGS

Expand All @@ -185,6 +187,14 @@
'psychiatric_or_mental_illness', 'other_disability')


@flags.multi_flags_validator(
    ['prediction_mode', 'eval_checkpoint_dir'],
    message='`eval_checkpoint_dir` should be provided in prediction mode')
def _check_checkpoint_dir_for_prediction_mode(flags_dict):
  """Checks that prediction mode is accompanied by a checkpoint directory.

  Args:
    flags_dict: Mapping from flag name to its value for the flags listed in
      the decorator.

  Returns:
    True when `prediction_mode` is off; otherwise whether
    `eval_checkpoint_dir` has been set to something other than None.
  """
  if not flags_dict['prediction_mode']:
    # Outside prediction mode the checkpoint directory is optional.
    return True
  return flags_dict['eval_checkpoint_dir'] is not None


def save_prediction(data, path):
  """Saves `data` as a NumPy array to `path` with '.npy' appended.

  Args:
    data: Values to store; converted with `np.array` before saving.
    path: Destination path without the '.npy' suffix.
  """
  output_path = path + '.npy'
  with tf.io.gfile.GFile(output_path, 'w') as npy_file:
    np.save(npy_file, np.array(data))
Expand Down Expand Up @@ -239,6 +249,7 @@ def main(argv):
strategy = tf.distribute.experimental.TPUStrategy(resolver)

batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
test_batch_size = batch_size
data_buffer_size = batch_size * 10

train_dataset_builder = ds.WikipediaToxicityDataset(
Expand Down Expand Up @@ -280,7 +291,7 @@ def main(argv):
for dataset_name, dataset_builder in dataset_builders.items():
test_datasets[dataset_name] = dataset_builder.build(split=base.Split.TEST)
steps_per_eval[dataset_name] = (
dataset_builder.info['num_test_examples'] // FLAGS.eval_batch_size)
dataset_builder.info['num_test_examples'] // test_batch_size)

if FLAGS.use_bfloat16:
policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
Expand Down Expand Up @@ -341,7 +352,10 @@ def main(argv):
}

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
if FLAGS.prediction_mode:
latest_checkpoint = tf.train.latest_checkpoint(FLAGS.eval_checkpoint_dir)
else:
latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
initial_epoch = 0
if latest_checkpoint:
# checkpoint.restore must be within a strategy.scope() so that optimizer
Expand Down Expand Up @@ -476,7 +490,7 @@ def step_fn(inputs):
# If model returns a tuple of (logits, covmat), extract both.
logits, covmat = logits
else:
covmat = tf.eye(FLAGS.eval_batch_size)
covmat = tf.eye(test_batch_size)

if FLAGS.use_bfloat16:
logits = tf.cast(logits, tf.float32)
Expand Down Expand Up @@ -556,7 +570,7 @@ def step_fn(inputs):
# If model returns a tuple of (logits, covmat), extract both.
logits, covmat = logits
else:
covmat = tf.eye(FLAGS.eval_batch_size)
covmat = tf.eye(test_batch_size)

if FLAGS.use_bfloat16:
logits = tf.cast(logits, tf.float32)
Expand Down

0 comments on commit adc2d41

Please sign in to comment.