Implement distributed training using horovod #1865

Open · wants to merge 2 commits into main
2 changes: 2 additions & 0 deletions doc/TRAINING_ADVANCED.rst
@@ -17,3 +17,5 @@ This document contains more advanced topics with regard to training models with
9. :ref:`parallel-training-optimization`
10. :ref:`data-importers`
11. :ref:`byte-output-mode`
12. :ref:`horovod-parallel-training`

22 changes: 22 additions & 0 deletions doc/TRAINING_HOROVOD.rst
@@ -0,0 +1,22 @@
.. _horovod-parallel-training:

Distributed training using Horovod
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

If you have suitable compute infrastructure, training can be distributed across multiple machines and GPUs using `Horovod <https://github.com/horovod/horovod>`_. A fast network between the machines is recommended.
Horovod is capable of using MPI and NVIDIA's NCCL for highly optimized inter-process communication.
It also offers `Gloo <https://github.com/facebookincubator/gloo>`_ as an easy-to-setup communication backend.

For more information about setup or tuning of Horovod please visit `Horovod's documentation <https://horovod.readthedocs.io/en/stable/summary_include.html>`_.

Horovod is able to run on heterogeneous systems (e.g. a different number or model of GPUs per machine).
However, this can cause unpredictable problems and would require manual intervention in the training code.
Therefore, we only support homogeneous systems, meaning the same hardware and also the same software configuration (OS, drivers, MPI, NCCL, TensorFlow, ...) on each machine.
The only exception is a different number of GPUs per machine, since this can be controlled via ``horovodrun -H``.

Detailed documentation on how to run Horovod is provided `here <https://horovod.readthedocs.io/en/stable/running.html>`_.
In short, the command to train on 4 machines using 4 GPUs each is:

.. code-block:: bash

horovodrun -np 16 -H server1:4,server2:4,server3:4,server4:4 python3 DeepSpeech.py --train_files [...] --horovod
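
If MPI is not available on the cluster, Horovod also ships the Gloo backend mentioned above; per Horovod's own documentation this is selected by passing ``--gloo`` to ``horovodrun`` (the flag name is taken from Horovod's docs, not from this PR).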
9 changes: 9 additions & 0 deletions setup.py
@@ -39,6 +39,9 @@ def main():
'tensorflow == 1.15.4'
]

horovod_pypi_dep = [
'horovod[tensorflow] == 0.21.3'
]
if os.environ.get('DS_NODECODER', ''):
install_requires = install_requires_base
else:
@@ -49,6 +52,12 @@
else:
install_requires = install_requires + tensorflow_pypi_dep

if os.environ.get('DS_WITH_HOROVOD', ''):
install_requires = install_requires + horovod_pypi_dep

setup(
name='coqui_stt_training',
version=version,
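Note: with this change, Horovod support must be selected at install time via the environment variable checked above, presumably something like ``DS_WITH_HOROVOD=y pip install .`` from a source checkout (the exact install invocation is an assumption, not part of this diff).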
195 changes: 124 additions & 71 deletions training/coqui_stt_training/train.py

Large diffs are not rendered by default.
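Since the ``train.py`` diff is not rendered, the following is a minimal sketch of the standard Horovod TensorFlow 1.x integration pattern (as described in Horovod's documentation), not the literal contents of this PR:

```python
# Minimal sketch of the canonical Horovod TF1 training setup; illustrative
# only -- the real train.py changes in this PR are not rendered above.
import tensorflow as tf  # TF 1.15, as pinned in setup.py
import horovod.tensorflow as hvd

hvd.init()

# Pin each process to a single GPU (mirrors util/config.py below).
config = tf.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank())

# Scale the learning rate by the worker count and wrap the optimizer so
# gradients are averaged across all processes via allreduce.
optimizer = tf.train.AdamOptimizer(learning_rate=0.001 * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer)

# Broadcast initial variable states from rank 0 so all workers start in
# sync, and let only the master process write checkpoints.
hooks = [hvd.BroadcastGlobalVariablesHook(0)]
checkpoint_dir = './checkpoints' if hvd.rank() == 0 else None
```

The scaled learning rate and rank-0-only checkpointing are Horovod's recommended defaults; whether this PR adopts them cannot be verified from the rendered diff.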

35 changes: 30 additions & 5 deletions training/coqui_stt_training/util/config.py
@@ -79,12 +79,37 @@ def initialize_globals():
# CPU device
c.cpu_device = '/cpu:0'

if FLAGS.horovod:
try:
import horovod.tensorflow as hvd
except ImportError as e:
print(
"Error importing Horovod. Did you install DeepSpeech with 'DS_WITH_HOROVOD=y'? "
"If you do not want to use Horovod, do not pass '--horovod'.")
raise e

hvd.init()

# Pin the GPU used by this process to its Horovod local rank (one GPU per process)
c.session_config.gpu_options.visible_device_list = str(hvd.local_rank())
c.num_devices = hvd.size()
c.is_master_process = hvd.rank() == 0
else:
# Available GPU devices
c.available_devices = get_available_gpus(c.session_config)

# If there is no GPU available, we fall back to CPU based operation
if not c.available_devices:
c.available_devices = [c.cpu_device]

c.num_devices = len(c.available_devices)

# Without Horovod there is only a single process, which is treated like the Horovod master
c.is_master_process = True


if FLAGS.bytes_output_mode:
c.alphabet = UTF8Alphabet()
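To make the pinning logic concrete, here is a hypothetical sketch of how the Horovod rank primitives map onto a 2-machine, 2-GPUs-each job (illustrative hosts and values, not from this PR):

```python
import horovod.tensorflow as hvd

hvd.init()
# For `horovodrun -np 4 -H a:2,b:2` (hypothetical hosts), each process sees:
#   host a, GPU 0: rank()=0, local_rank()=0, size()=4
#   host a, GPU 1: rank()=1, local_rank()=1, size()=4
#   host b, GPU 0: rank()=2, local_rank()=0, size()=4
#   host b, GPU 1: rank()=3, local_rank()=1, size()=4
# local_rank() selects the GPU within one machine (hence visible_device_list
# above), while rank() is globally unique, which is why rank 0 alone is
# treated as the master process.
print(hvd.rank(), hvd.local_rank(), hvd.size())
```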
29 changes: 21 additions & 8 deletions training/coqui_stt_training/util/feeding.py
Expand Up @@ -94,7 +94,8 @@ def create_dataset(sources,
limit=0,
exception_box=None,
process_ahead=None,
buffering=1 * MEGABYTE):
buffering=1 * MEGABYTE,
split_dataset=False):
epoch_counter = Counter() # survives restarts of the dataset and its generator

def generate_values():
@@ -135,14 +136,26 @@ def batch_fn(sample_ids, features, features_len, transcripts):

process_fn = partial(entry_to_features, train_phase=train_phase, augmentations=augmentations)


dataset = tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box),
output_types=(tf.string, tf.float32, tf.int32,
(tf.int64, tf.int32, tf.int64), tf.float64))
if split_dataset:
# With Horovod, Iterator.get_next() is not aware of the different devices,
# so each process takes its own shard: A.shard(n, i) contains all elements of A whose index mod n == i.
import horovod.tensorflow as hvd
dataset = dataset.shard(hvd.size(), hvd.rank())
dataset = dataset.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

if cache_path:
dataset = dataset.cache(cache_path)

dataset = (dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn))
if split_dataset:
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
else:
dataset = dataset.prefetch(Config.num_devices)

return dataset
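
As a standalone illustration of the sharding behaviour described in the comment above (not part of the PR):

```python
import tensorflow as tf

# Dataset.shard(n, i) keeps the elements whose index mod n == i, so with
# four Horovod processes the worker with rank 1 sees [1, 5] out of range(8).
dataset = tf.data.Dataset.range(8)
shard = dataset.shard(num_shards=4, index=1)  # yields elements 1 and 5
```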


@@ -178,5 +191,5 @@ def create_batch_set(bs, criteria):
ods = create_batch_set(outlier_batch_size,
lambda start, end, f, fl: end - start > int(outlier_duration_ms))
dataset = nds.concatenate(ods)
dataset = dataset.prefetch(Config.num_devices)
return dataset
2 changes: 2 additions & 0 deletions training/coqui_stt_training/util/flags.py
@@ -69,6 +69,8 @@ def create_flags():
f.DEFINE_boolean('train_cudnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
f.DEFINE_boolean('automatic_mixed_precision', False, 'whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.')

f.DEFINE_boolean('horovod', False, 'use Horovod for distributed training across multiple GPUs and machines')

# Sample limits

f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')