Skip to content

Commit f35b1df

Browse files
committed
[Distributed] Allow additional horovod DistributedOptimizer args.
Signed-off-by: 泊霆 <[email protected]>
1 parent e9cba36 commit f35b1df

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

tensorflow/python/distribute/hvd_strategy.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,15 @@ def __init__(self, learning_rate=0.001, *args, **kwargs):
397397
else:
398398
def horovod_optimizer(*args, **kwargs):
399399
from horovod.tensorflow import DistributedOptimizer
400-
return DistributedOptimizer(HvdOptimizer(*args, **kwargs))
400+
horovod_args = DistributedOptimizer.__code__.co_varnames
401+
horovod_real_kargs = {}
402+
candidate_keys = list(kwargs.keys())
403+
for kwarg in candidate_keys:
404+
if kwarg in horovod_args:
405+
value = kwargs[kwarg]
406+
del kwargs[kwarg]
407+
horovod_real_kargs[kwarg] = value
408+
return DistributedOptimizer(HvdOptimizer(*args, **kwargs), **horovod_real_kargs)
401409
return horovod_optimizer
402410

403411

0 commit comments

Comments
 (0)