[DIST] Set data_sync_drop_remainder to True by default. (#143)
1. Set data_sync_drop_remainder to True by default.

Signed-off-by: langshi.cls <[email protected]>
francktcheng authored Apr 24, 2023
1 parent d65f685 commit 0545159
Showing 6 changed files with 30 additions and 22 deletions.
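The core change is the flip of the data_sync_drop_remainder option from False to True by default; the tutorials and the distributed-data test below opt back out by passing False explicitly. A minimal sketch of both styles, assuming hybridbackend.tensorflow is imported as hb as in the changed files (the option name and the hb.scope override come directly from this diff; the option presumably controls whether remainder batches are dropped when data is synchronized across workers, as its name suggests):

import hybridbackend.tensorflow as hb

# After this commit, data_sync_drop_remainder defaults to True inside a scope.
with hb.scope():
  pass  # build the input pipeline and model here

# To keep the pre-#143 behaviour, override the option explicitly,
# as the updated tutorials and test do.
with hb.scope(data_sync_drop_remainder=False):
  pass  # remainder batches are kept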
24 changes: 14 additions & 10 deletions docs/tutorial/ranking/criteo/train.py
@@ -112,13 +112,17 @@ def train(self, filenames):
         self._args.top_mlp_dims)
     loss = self.compute_loss(logits, labels)
     step = tf.train.get_or_create_global_step()
-    train_op = sgd_decay_optimize(
-      loss,
-      lr_initial_value=self._args.lr_initial_value,
-      lr_warmup_steps=self._args.lr_warmup_steps,
-      lr_decay_start_step=self._args.lr_decay_start_step,
-      lr_decay_steps=self._args.lr_decay_steps)
-    return step, loss, train_op
+    train_auc, train_auc_update_op = hb.metrics.auc(
+      labels=labels,
+      predictions=logits, name='train_auc')
+    with tf.control_dependencies([train_auc_update_op]):
+      train_op = sgd_decay_optimize(
+        loss,
+        lr_initial_value=self._args.lr_initial_value,
+        lr_warmup_steps=self._args.lr_warmup_steps,
+        lr_decay_start_step=self._args.lr_decay_start_step,
+        lr_decay_steps=self._args.lr_decay_steps)
+    return step, loss, train_op, train_auc
 
   def evaluate(self, filenames):
     r'''Evaluate model.
@@ -160,7 +164,7 @@ def main(args):
   train_filenames = args.filenames
   eval_filenames = args.filenames
   model = RankingModel(args)
-  step, loss, train_op = model.train(train_filenames)
+  step, loss, train_op, train_auc = model.train(train_filenames)
 
   hooks = []
   if args.eval_every_n_iter is not None:
@@ -171,7 +175,7 @@ def main(args):
   if args.log_every_n_iter is not None:
     hooks.append(
       tf.train.LoggingTensorHook(
-        {'step': step, 'loss': loss},
+        {'step': step, 'loss': loss, 'train_auc': train_auc},
         every_n_iter=args.log_every_n_iter))
   if args.train_max_steps is not None:
     hooks.append(tf.train.StopAtStepHook(args.train_max_steps))
@@ -236,5 +240,5 @@ def main(args):
     disable_imputation=parsed.disable_imputation,
     disable_transform=True,
     override_embedding_size=parsed.embedding_dim)
-  with hb.scope():
+  with hb.scope(data_sync_drop_remainder=False):
     main(parsed)
11 changes: 7 additions & 4 deletions docs/tutorial/ranking/taobao/train.py
@@ -113,8 +113,11 @@ def train(self, filenames):
     loss = self.compute_loss(logits, labels)
     step = tf.train.get_or_create_global_step()
     opt = tf.train.AdagradOptimizer(learning_rate=self._args.lr)
-    train_op = opt.minimize(loss, global_step=step)
-    return step, loss, train_op
+    train_auc, train_auc_update_op = hb.metrics.auc(
+      labels=labels, predictions=logits, name='train_auc')
+    with tf.control_dependencies([train_auc_update_op]):
+      train_op = opt.minimize(loss, global_step=step)
+    return step, loss, train_op, train_auc
 
   def evaluate(self, filenames):
     r'''Evaluate model.
@@ -148,7 +151,7 @@ def main(args):
   train_filenames = args.filenames
   eval_filenames = args.filenames
   model = RankingModel(args)
-  step, loss, train_op = model.train(train_filenames)
+  step, loss, train_op, train_auc = model.train(train_filenames)
 
   hooks = []
   if args.eval_every_n_iter is not None:
@@ -159,7 +162,7 @@ def main(args):
   if args.log_every_n_iter is not None:
     hooks.append(
       tf.train.LoggingTensorHook(
-        {'step': step, 'loss': loss},
+        {'step': step, 'loss': loss, 'train_auc': train_auc},
         every_n_iter=args.log_every_n_iter))
   if args.train_max_steps is not None:
     hooks.append(tf.train.StopAtStepHook(args.train_max_steps))
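Both tutorial diffs above make the same training-graph change: a streaming AUC from hb.metrics.auc whose update op is chained onto the training op via tf.control_dependencies, with the resulting metric tensor handed to the LoggingTensorHook. A condensed sketch of that pattern, with a toy model standing in for the tutorials' ranking models (the dense layer, loss, and learning rate are illustrative only):

import tensorflow as tf
import hybridbackend.tensorflow as hb

def build_train_graph(features, labels):
  # Toy binary-classification head; the tutorials build real ranking models.
  logits = tf.layers.dense(features, 1)
  labels = tf.reshape(tf.cast(labels, tf.float32), [-1, 1])
  loss = tf.losses.sigmoid_cross_entropy(labels, logits)
  step = tf.train.get_or_create_global_step()
  opt = tf.train.AdagradOptimizer(learning_rate=0.01)
  # Streaming AUC over training batches; the control dependency makes every
  # run of train_op also run the metric's update op.
  train_auc, train_auc_update_op = hb.metrics.auc(
    labels=labels, predictions=logits, name='train_auc')
  with tf.control_dependencies([train_auc_update_op]):
    train_op = opt.minimize(loss, global_step=step)
  return step, loss, train_op, train_auc

The returned train_auc tensor can then be added to the LoggingTensorHook tensor map exactly as both diffs do.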
2 changes: 1 addition & 1 deletion hybridbackend/tensorflow/data/__init__.py
@@ -42,4 +42,4 @@
 _ = (
   _ctx.get().options
   .register('data_batch_count', 1)
-  .register('data_sync_drop_remainder', False))
+  .register('data_sync_drop_remainder', True))
3 changes: 2 additions & 1 deletion (test file; path not shown in this view)
@@ -64,7 +64,8 @@ def _test_distributed(rank):
   batch_size = 10
 
   with tf.Graph().as_default():
-    with hb.scope(mode=tf.estimator.ModeKeys.TRAIN):
+    with hb.scope(
+        data_sync_drop_remainder=False, mode=tf.estimator.ModeKeys.TRAIN):
       with tf.device('/cpu:0'):
         ds = tf.data.Dataset.range(100 + rank * 50)
         ds = ds.batch(batch_size=batch_size)
6 changes: 3 additions & 3 deletions hybridbackend/tensorflow/estimator/estimator.py
@@ -183,9 +183,9 @@ def __init__(self, model_fn, **kwargs):
     '''
     kwargs['config'] = RunConfig.build(prototype=kwargs.pop('config', None))
     model_dir = kwargs.get('model_dir', None)
-    self._train_drop_remainder = kwargs.pop('train_drop_remainder', None)
-    self._eval_drop_remainder = kwargs.pop('eval_drop_remainder', None)
-    self._predict_drop_remainder = kwargs.pop('predict_drop_remainder', None)
+    self._train_drop_remainder = kwargs.pop('train_drop_remainder', True)
+    self._eval_drop_remainder = kwargs.pop('eval_drop_remainder', True)
+    self._predict_drop_remainder = kwargs.pop('predict_drop_remainder', True)
 
     super().__init__(
       wraps_model_fn(model_fn, model_dir, kwargs['config']),
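The estimator wrapper above and the Keras model wrapper below flip the same defaults for their per-mode drop-remainder overrides. A hedged usage sketch, assuming the class patched here is exposed as hb.estimator.Estimator (the kwarg names come from this diff; the model_fn body and model_dir are placeholders):

import tensorflow as tf
import hybridbackend.tensorflow as hb

def model_fn(features, labels, mode):
  # Placeholder model_fn for illustration only.
  logits = tf.layers.dense(features['x'], 1)
  labels = tf.reshape(tf.cast(labels, tf.float32), [-1, 1])
  loss = tf.losses.sigmoid_cross_entropy(labels, logits)
  train_op = tf.train.AdagradOptimizer(0.01).minimize(
    loss, global_step=tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

# train/eval/predict_drop_remainder now default to True; pass False explicitly
# to keep the previous behaviour.
estimator = hb.estimator.Estimator(
  model_fn,
  model_dir='/tmp/ranking_model',
  train_drop_remainder=False,
  eval_drop_remainder=False,
  predict_drop_remainder=False)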
6 changes: 3 additions & 3 deletions hybridbackend/tensorflow/keras/model.py
@@ -473,9 +473,9 @@ class HybridBackendKerasModel(cls, HybridBackendKerasModelBase):
   '''
   def __init__(self, *args, **kwargs):
     self._device_fn = device_function
-    self._train_drop_remainder = kwargs.pop('train_drop_remainder', None)
-    self._eval_drop_remainder = kwargs.pop('eval_drop_remainder', None)
-    self._predict_drop_remainder = kwargs.pop('predict_drop_remainder', None)
+    self._train_drop_remainder = kwargs.pop('train_drop_remainder', True)
+    self._eval_drop_remainder = kwargs.pop('eval_drop_remainder', True)
+    self._predict_drop_remainder = kwargs.pop('predict_drop_remainder', True)
     self._load_weights_dir = None
     self._load_weights_scope = None
     self._load_weights_skip_mismatched = True
