Skip to content

Commit e9cba36

Browse files
committed
[Distributed] Estimator input args sanity check.
Signed-off-by: 泊霆 <[email protected]>
1 parent 07c57f7 commit e9cba36

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

tensorflow/python/distribute/group_embedding_collective_strategy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def estimator(self, model_fn, **kwargs):
101101
from tensorflow.python.distribute.hvd_strategy import wraps_estimator
102102
_estimator = wraps_estimator(_estimator_lib.Estimator)
103103
elif self._hb:
104-
_estimator = hb.estimator.Estimator
104+
_estimator = self._hb.estimator.Estimator
105105

106106
return _estimator(model_fn, **kwargs)
107107

tensorflow/python/distribute/hvd_strategy.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1060,10 +1060,14 @@ def __init__(self, model_fn, **kwargs):
10601060
self._eval_drop_remainder = kwargs.pop('eval_drop_remainder', True)
10611061
self._predict_drop_remainder = kwargs.pop(
10621062
'predict_drop_remainder', True)
1063+
config = kwargs.get('config', None)
1064+
if config is None:
1065+
config = run_config_lib.RunConfig()
1066+
else:
1067+
kwargs.pop('config')
10631068

10641069
super().__init__(
1065-
wraps_model_fn(model_fn, model_dir, kwargs['config']),
1066-
**kwargs)
1070+
wraps_model_fn(model_fn, model_dir, config), **kwargs)
10671071

10681072
def _assert_members_are_not_overridden(self):
10691073
r'''disable the overridden check here.

0 commit comments

Comments
 (0)