From 69dbe8b89f0a20249c496aa6c37e47984166df72 Mon Sep 17 00:00:00 2001 From: Szymon Migacz <1934379+szmigacz@users.noreply.github.com> Date: Tue, 9 Apr 2019 02:25:12 +0200 Subject: [PATCH] Update for GNMT reference (#222) * Update for GNMT reference * Fixed weight init for not shared embeddings (this doesn't affect the reference config) * Better description of vocabulary and text datasets * Added example logfile from GNMT training on reference config --- compliance/example_logs/gnmt_example.log | 74 ++ rnn_translator/README.md | 192 ++++- rnn_translator/download_dataset.sh | 6 +- rnn_translator/pytorch/Dockerfile | 6 +- rnn_translator/pytorch/README.md | 180 ++++- rnn_translator/pytorch/multiproc.py | 46 -- rnn_translator/pytorch/requirements.txt | 4 +- rnn_translator/pytorch/run.sh | 10 +- rnn_translator/pytorch/run_and_time.sh | 2 +- .../pytorch/scripts/benchmark_inference.sh | 101 --- .../pytorch/scripts/benchmark_training.sh | 36 - .../pytorch/scripts/docker/build.sh | 2 +- .../scripts/downloaders/filter_dataset.py | 50 -- rnn_translator/pytorch/scripts/evaluate.sh | 29 - .../pytorch/scripts/filter_dataset.py | 79 ++ .../pytorch/scripts/parse_train_benchmark.sh | 50 -- .../pytorch/scripts/trim_checkpoints.py | 48 -- rnn_translator/pytorch/seq2seq/data/config.py | 13 +- .../pytorch/seq2seq/data/dataset.py | 338 ++++++++- .../pytorch/seq2seq/data/sampler.py | 285 ++++++-- .../pytorch/seq2seq/data/tokenizer.py | 101 ++- .../pytorch/seq2seq/inference/beam_search.py | 99 ++- .../pytorch/seq2seq/inference/inference.py | 301 ++++++-- .../pytorch/seq2seq/models/__init__.py | 4 - .../pytorch/seq2seq/models/attention.py | 48 +- .../pytorch/seq2seq/models/decoder.py | 141 +++- .../pytorch/seq2seq/models/encoder.py | 49 +- rnn_translator/pytorch/seq2seq/models/gnmt.py | 53 +- .../pytorch/seq2seq/models/seq2seq_base.py | 42 ++ .../pytorch/seq2seq/train/distributed.py | 222 ------ .../pytorch/seq2seq/train/fp_optimizers.py | 105 ++- .../pytorch/seq2seq/train/lr_scheduler.py | 104 +++ .../pytorch/seq2seq/train/smoothing.py | 13 +- .../pytorch/seq2seq/train/trainer.py | 214 +++++- rnn_translator/pytorch/seq2seq/utils.py | 297 +++++++- rnn_translator/pytorch/train.py | 679 ++++++++++-------- rnn_translator/pytorch/translate.py | 292 ++++---- 37 files changed, 2854 insertions(+), 1461 deletions(-) create mode 100644 compliance/example_logs/gnmt_example.log delete mode 100644 rnn_translator/pytorch/multiproc.py delete mode 100644 rnn_translator/pytorch/scripts/benchmark_inference.sh delete mode 100644 rnn_translator/pytorch/scripts/benchmark_training.sh delete mode 100644 rnn_translator/pytorch/scripts/downloaders/filter_dataset.py delete mode 100644 rnn_translator/pytorch/scripts/evaluate.sh create mode 100644 rnn_translator/pytorch/scripts/filter_dataset.py delete mode 100644 rnn_translator/pytorch/scripts/parse_train_benchmark.sh delete mode 100644 rnn_translator/pytorch/scripts/trim_checkpoints.py delete mode 100644 rnn_translator/pytorch/seq2seq/models/__init__.py delete mode 100644 rnn_translator/pytorch/seq2seq/train/distributed.py create mode 100644 rnn_translator/pytorch/seq2seq/train/lr_scheduler.py diff --git a/compliance/example_logs/gnmt_example.log b/compliance/example_logs/gnmt_example.log new file mode 100644 index 000000000..89a6233c0 --- /dev/null +++ b/compliance/example_logs/gnmt_example.log @@ -0,0 +1,74 @@ +:::MLPv0.5.0 gnmt 1550049265.393176317 (train.py:268) run_start +:::MLPv0.5.0 gnmt 1550049265.394439697 (seq2seq/utils.py:113) run_set_random_seed: 1 +:::MLPv0.5.0 gnmt 
1550049265.412598848 (train.py:312) preproc_tokenize_training +:::MLPv0.5.0 gnmt 1550049265.413046360 (train.py:314) train_hp_max_sequence_length: 50 +:::MLPv0.5.0 gnmt 1550049274.456599236 (train.py:326) preproc_num_train_examples: 3498161 +:::MLPv0.5.0 gnmt 1550049275.559447765 (train.py:336) preproc_tokenize_eval +:::MLPv0.5.0 gnmt 1550049275.597178459 (train.py:346) preproc_num_eval_examples: 3003 +:::MLPv0.5.0 gnmt 1550049275.597592831 (train.py:350) preproc_vocab_size: 32317 +:::MLPv0.5.0 gnmt 1550049275.598169804 (seq2seq/models/gnmt.py:34) model_hp_num_layers: 4 +:::MLPv0.5.0 gnmt 1550049275.598643303 (seq2seq/models/gnmt.py:36) model_hp_hidden_size: 1024 +:::MLPv0.5.0 gnmt 1550049275.599111080 (seq2seq/models/gnmt.py:38) model_hp_dropout: 0.2 +:::MLPv0.5.0 gnmt 1550049279.603173733 (train.py:251) model_hp_loss_fn: "Cross Entropy with label smoothing" +:::MLPv0.5.0 gnmt 1550049279.603722811 (train.py:253) model_hp_loss_smoothing: 0.1 +:::MLPv0.5.0 gnmt 1550049280.493244410 (train.py:393) input_batch_size: 128 +:::MLPv0.5.0 gnmt 1550049280.493689060 (train.py:395) input_size: 3497728 +:::MLPv0.5.0 gnmt 1550049280.494647741 (seq2seq/data/sampler.py:254) input_order +:::MLPv0.5.0 gnmt 1550049280.496007442 (seq2seq/data/sampler.py:254) input_order +:::MLPv0.5.0 gnmt 1550049280.496948242 (train.py:409) eval_size: 3003 +:::MLPv0.5.0 gnmt 1550049280.497889280 (seq2seq/inference/beam_search.py:43) eval_hp_beam_size: 5 +:::MLPv0.5.0 gnmt 1550049280.498333216 (seq2seq/inference/beam_search.py:45) eval_hp_max_sequence_length: 150 +:::MLPv0.5.0 gnmt 1550049280.498714685 (seq2seq/inference/beam_search.py:47) eval_hp_length_normalization_constant: 5.0 +:::MLPv0.5.0 gnmt 1550049280.499088049 (seq2seq/inference/beam_search.py:49) eval_hp_length_normalization_factor: 0.6 +:::MLPv0.5.0 gnmt 1550049280.499464273 (seq2seq/inference/beam_search.py:51) eval_hp_coverage_penalty_factor: 0.1 +:::MLPv0.5.0 gnmt 1550049283.242700577 (seq2seq/train/trainer.py:115) opt_name: "adam" +:::MLPv0.5.0 gnmt 1550049283.243218422 (seq2seq/train/trainer.py:117) opt_learning_rate: 0.001 +:::MLPv0.5.0 gnmt 1550049283.243656635 (seq2seq/train/trainer.py:119) opt_hp_Adam_beta1: 0.9 +:::MLPv0.5.0 gnmt 1550049283.244005442 (seq2seq/train/trainer.py:121) opt_hp_Adam_beta2: 0.999 +:::MLPv0.5.0 gnmt 1550049283.244350195 (seq2seq/train/trainer.py:123) opt_hp_Adam_epsilon: 1e-08 +:::MLPv0.5.0 gnmt 1550049283.245311022 (seq2seq/train/lr_scheduler.py:78) opt_learning_rate_warmup_steps: 200 +:::MLPv0.5.0 gnmt 1550049283.245801687 (train.py:466) train_loop +:::MLPv0.5.0 gnmt 1550049283.246350050 (train.py:470) train_epoch: 0 +:::MLPv0.5.0 gnmt 1550049284.053482771 (seq2seq/data/sampler.py:205) input_order +:::MLPv0.5.0 gnmt 1550061484.477649927 (train.py:483) train_checkpoint +:::MLPv0.5.0 gnmt 1550061486.181115627 (train.py:490) eval_start: 0 +:::MLPv0.5.0 gnmt 1550061514.467880964 (train.py:495) eval_accuracy: {"epoch": 0, "value": 20.49} +:::MLPv0.5.0 gnmt 1550061514.468363047 (train.py:497) eval_target: 24.0 +:::MLPv0.5.0 gnmt 1550061514.468725204 (train.py:498) eval_stop +:::MLPv0.5.0 gnmt 1550061514.469318390 (train.py:470) train_epoch: 1 +:::MLPv0.5.0 gnmt 1550061515.311945915 (seq2seq/data/sampler.py:205) input_order +:::MLPv0.5.0 gnmt 1550073699.708187580 (train.py:483) train_checkpoint +:::MLPv0.5.0 gnmt 1550073705.950762510 (train.py:490) eval_start: 1 +:::MLPv0.5.0 gnmt 1550073733.111917734 (train.py:495) eval_accuracy: {"epoch": 1, "value": 21.72} +:::MLPv0.5.0 gnmt 1550073733.112442732 (train.py:497) eval_target: 24.0 
+:::MLPv0.5.0 gnmt 1550073733.112777948 (train.py:498) eval_stop +:::MLPv0.5.0 gnmt 1550073733.113427401 (train.py:470) train_epoch: 2 +:::MLPv0.5.0 gnmt 1550073733.976921797 (seq2seq/data/sampler.py:205) input_order +:::MLPv0.5.0 gnmt 1550085915.256659985 (train.py:483) train_checkpoint +:::MLPv0.5.0 gnmt 1550085921.003725052 (train.py:490) eval_start: 2 +:::MLPv0.5.0 gnmt 1550085948.763688564 (train.py:495) eval_accuracy: {"epoch": 2, "value": 22.52} +:::MLPv0.5.0 gnmt 1550085948.764161110 (train.py:497) eval_target: 24.0 +:::MLPv0.5.0 gnmt 1550085948.764480591 (train.py:498) eval_stop +:::MLPv0.5.0 gnmt 1550085948.765091658 (train.py:470) train_epoch: 3 +:::MLPv0.5.0 gnmt 1550085949.625080585 (seq2seq/data/sampler.py:205) input_order +:::MLPv0.5.0 gnmt 1550098119.481947660 (train.py:483) train_checkpoint +:::MLPv0.5.0 gnmt 1550098125.219326258 (train.py:490) eval_start: 3 +:::MLPv0.5.0 gnmt 1550098154.012802124 (train.py:495) eval_accuracy: {"epoch": 3, "value": 23.09} +:::MLPv0.5.0 gnmt 1550098154.013273716 (train.py:497) eval_target: 24.0 +:::MLPv0.5.0 gnmt 1550098154.013695478 (train.py:498) eval_stop +:::MLPv0.5.0 gnmt 1550098154.014324665 (train.py:470) train_epoch: 4 +:::MLPv0.5.0 gnmt 1550098154.874188900 (seq2seq/data/sampler.py:205) input_order +:::MLPv0.5.0 gnmt 1550110332.296045065 (train.py:483) train_checkpoint +:::MLPv0.5.0 gnmt 1550110338.542354584 (train.py:490) eval_start: 4 +:::MLPv0.5.0 gnmt 1550110366.501193285 (train.py:495) eval_accuracy: {"epoch": 4, "value": 23.22} +:::MLPv0.5.0 gnmt 1550110366.501719475 (train.py:497) eval_target: 24.0 +:::MLPv0.5.0 gnmt 1550110366.502229214 (train.py:498) eval_stop +:::MLPv0.5.0 gnmt 1550110366.502951622 (train.py:470) train_epoch: 5 +:::MLPv0.5.0 gnmt 1550110367.328888416 (seq2seq/data/sampler.py:205) input_order +:::MLPv0.5.0 gnmt 1550122538.255764246 (train.py:483) train_checkpoint +:::MLPv0.5.0 gnmt 1550122544.367511034 (train.py:490) eval_start: 5 +:::MLPv0.5.0 gnmt 1550122571.801926613 (train.py:495) eval_accuracy: {"epoch": 5, "value": 24.11} +:::MLPv0.5.0 gnmt 1550122571.802411556 (train.py:497) eval_target: 24.0 +:::MLPv0.5.0 gnmt 1550122571.802731991 (train.py:498) eval_stop +:::MLPv0.5.0 gnmt 1550122571.803371668 (train.py:522) run_stop: {"success": true} +:::MLPv0.5.0 gnmt 1550122571.803844213 (train.py:523) run_final diff --git a/rnn_translator/README.md b/rnn_translator/README.md index 3be80abb3..362c2da41 100644 --- a/rnn_translator/README.md +++ b/rnn_translator/README.md @@ -2,8 +2,15 @@ This problem uses recurrent neural network to do language translation. +## Requirements +* [Python 3.6](https://www.python.org) +* [CUDA 9.0](https://developer.nvidia.com/cuda-90-download-archive) +* [PyTorch 1.0.1](https://pytorch.org) +* [sacrebleu](https://pypi.org/project/sacrebleu/) + ### Recommended setup * [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) +* [pytorch/pytorch:1.0.1-cuda10.0-cudnn7-runtime container](https://hub.docker.com/r/pytorch/pytorch/tags/) # 2. Directions ### Steps to configure machine @@ -63,6 +70,21 @@ Verify data with: bash verify_dataset.sh +### Steps specific to the pytorch version to run and time + + cd pytorch + sudo docker build . 
--rm -t gnmt:latest + SEED=1 + NOW=`date "+%F-%T"` + sudo nvidia-docker run -it --rm --ipc=host \ + -v $(pwd)/../data:/data \ + gnmt:latest "./run_and_time.sh" $SEED |tee benchmark-$NOW.log + +### one can control which GPUs are used with the NV_GPU variable + sudo NV_GPU=0 nvidia-docker run -it --rm --ipc=host \ + -v $(pwd)/../data:/data \ + gnmt:latest "./run_and_time.sh" $SEED |tee benchmark-$NOW.log + # 3. Dataset/Environment ### Publication/Attribution We use [WMT16 English-German](http://www.statmt.org/wmt16/translation-task.html) @@ -75,15 +97,105 @@ segment text into subword units (BPE), by default it builds shared vocabulary of Preprocessing removes all pairs of sentences that can't be decoded by latin-1 encoder. +### Vocabulary +Vocabulary is generated by the following lines from the `download_dataset.sh` +script: + +``` +# Create vocabulary file for BPE +cat "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.en" "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.de" | \ + ${OUTPUT_DIR}/subword-nmt/get_vocab.py | cut -f1 -d ' ' > "${OUTPUT_DIR}/vocab.bpe.${merge_ops}" +``` + +Vocabulary is stored to the `rnn_translator/data/vocab.bpe.32000` plain text +file. Tokens are separated with a newline character, one token per line. The +vocabulary file doesn't contain special tokens like for example BOS +(begin-of-string) or EOS (end-of-string) tokens. + +Here are first 10 lines from the `rnn_translator/data/vocab.bpe.32000` file: +``` +, +. +the +in +of +and +die +der +to +und +``` + +### Text datasets +The `download_dataset.sh` script automatically creates training, validation and +test datasets. Datasets are stored as plain text files. Sentences are separated +with a newline character, and tokens within each sentence are separated with a +single space character. + +Training data: +* source language (English): `rnn_translator/data/train.tok.clean.bpe.32000.en` +* target language (German): `rnn_translator/data/train.tok.clean.bpe.32000.de` + +Validation data: +* source language (English): `rnn_translator/data/newstest_dev.tok.clean.bpe.32000.en` +* target language (German): `rnn_translator/data/newstest_dev.tok.clean.bpe.32000.de` + +Test data: +* source language (English): `rnn_translator/data/newstest2014.tok.bpe.32000.en` +* target language (German): `rnn_translator/data/newstest2014.de` + * notice that the `newstest2014.de` file isn't tokenized, BLEU evaluation is + performed by the sacrebleu package and sacrebleu expects plain text raw data + (tokenization is performed internally by sacrebleu) + +Here are first 5 lines from the `rnn_translator/data/train.tok.clean.bpe.32000.en` file: +``` +Res@@ um@@ ption of the session +I declare resumed the session of the European Parliament ad@@ jour@@ ned on Friday 17 December 1999 , and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant fes@@ tive period . +Although , as you will have seen , the d@@ read@@ ed ' millenn@@ ium bug ' failed to materi@@ alise , still the people in a number of countries suffered a series of natural disasters that truly were d@@ read@@ ful . +You have requested a debate on this subject in the course of the next few days , during this part-session . +In the meantime , I should like to observe a minute ' s silence , as a number of Members have requested , on behalf of all the victims concerned , particularly those of the terrible stor@@ ms , in the various countries of the European Union . 
+``` + +And here are first 5 lines from the `rnn_translator/data/train.tok.clean.bpe.32000.de` file: +``` +Wiederaufnahme der Sitzungsperiode +Ich erkläre die am Freitag , dem 17. Dezember unterbro@@ ch@@ ene Sitzungsperiode des Europäischen Parlaments für wieder@@ aufgenommen , wünsche Ihnen nochmals alles Gute zum Jahres@@ wechsel und hoffe , daß Sie schöne Ferien hatten . +Wie Sie feststellen konnten , ist der ge@@ für@@ ch@@ tete " Mill@@ en@@ i@@ um-@@ Bu@@ g " nicht eingetreten . Doch sind Bürger einiger unserer Mitgliedstaaten Opfer von schrecklichen Naturkatastrophen geworden . +Im Parlament besteht der Wunsch nach einer Aussprache im Verlauf dieser Sitzungsperiode in den nächsten Tagen . +Heute möchte ich Sie bitten - das ist auch der Wunsch einiger Kolleginnen und Kollegen - , allen Opfern der St@@ ür@@ me , insbesondere in den verschiedenen Ländern der Europäischen Union , in einer Schwei@@ ge@@ minute zu ge@@ denken . +``` + ### Training and test data separation Training uses WMT16 English-German dataset, validation is on concatenation of newstest2015 and newstest2016, BLEU evaluation is done on newstest2014. + +### Data filtering +Training is executed only on pairs of sentences which satisfy the following equation: +``` + min_len <= src sentence sequence length <= max_len AND + min_len <= tgt sentence sequence length <= max_len +``` +`min_len` is set to 0, `max_len` is set to 50. Source and target sequence +lengths include special BOS (begin-of-sentence) and EOS (end-of-sentence) +tokens. + +Filtering is implemented in `pytorch/seq2seq/data/dataset.py`, class +`LazyParallelDataset`. + ### Training data order -By default training script does bucketing by sequence length. Before each epoch -dataset is randomly shuffled and split into chunks of 80 batches each. Within -each chunk it's sorted by (src + tgt) sequence length and then batches are -reshuffled within each chunk. +Training script does bucketing by sequence length. Bucketing algorithm uses 5 +equal-width buckets (`num_buckets = 5`). Pairs of training sentences are +assigned to buckets by the value of +`max(src_sentence_len // bucket_width, tgt_sentence_len // bucket_width)`, where +`bucket_width = (max_len + num_buckets - 1) // num_buckets`. +Before each training epoch batches are randomly sampled from the buckets (last +incomplete batches are dropped for each bucket), then all batches are +reshuffled. + + +Bucketing is implemented in `pytorch/seq2seq/data/sampler.py`, class +`BucketingSampler`. # 4. 
Model ### Publication/Attribution @@ -106,47 +218,91 @@ GNMT-like models from [tensorflow/nmt](https://github.com/tensorflow/nmt) and * general: * encoder and decoder are using shared embeddings * data-parallel multi-gpu training - * dynamic loss scaling with backoff for Tensor Cores (mixed precision) training * trained with label smoothing loss (smoothing factor 0.1) * encoder: - * 4-layer LSTM, hidden size 1024, first layer is bidirectional, the rest are - undirectional - * with residual connections starting from 3rd layer - * uses standard LSTM layer (accelerated by cudnn) + * 4-layer LSTM, hidden size 1024, first layer is bidirectional, the rest of + layers are unidirectional + * with residual connections starting from 3rd LSTM layer + * uses standard pytorch nn.LSTM layer + * dropout is applied on input to all LSTM layers, probability of dropout is + set to 0.2 + * hidden state of LSTM layers is initialized with zeros + * weights and bias of LSTM layers is initialized with uniform(-0.1, 0.1) + distribution * decoder: * 4-layer unidirectional LSTM with hidden size 1024 and fully-connected classifier - * with residual connections starting from 3rd layer - * uses standard LSTM layer (accelerated by cudnn) + * with residual connections starting from 3rd LSTM layer + * uses standard pytorch nn.LSTM layer + * dropout is applied on input to all LSTM layers, probability of dropout is + set to 0.2 + * hidden state of LSTM layers is initialized with zeros + * weights and bias of LSTM layers is initialized with uniform(-0.1, 0.1) + distribution + * weights and bias of fully-connected classifier is initialized with + uniform(-0.1, 0.1) distribution * attention: * normalized Bahdanau attention * model uses `gnmt_v2` attention mechanism * output from first LSTM layer of decoder goes into attention, then re-weighted context is concatenated with the input to all subsequent LSTM layers in decoder at the current timestep + * linear transform of keys and queries is initialized with uniform(-0.1, 0.1), + normalization scalar is initialized with 1.0 / sqrt(1024), + normalization bias is initialized with zero * inference: - * beam search with default beam size 5 - * with coverage penalty and length normalization + * beam search with beam size of 5 + * with coverage penalty and length normalization, coverage penalty factor is + set to 0.1, length normalization factor is set to 0.6 and length + normalization constant is set to 5.0 * BLEU computed by [sacrebleu](https://pypi.org/project/sacrebleu/) + +Implementation: +* base Seq2Seq model: `pytorch/seq2seq/models/seq2seq_base.py`, class `Seq2Seq` +* GNMT model: `pytorch/seq2seq/models/gnmt.py`, class `GNMT` +* encoder: `pytorch/seq2seq/models/encoder.py`, class `ResidualRecurrentEncoder` +* decoder: `pytorch/seq2seq/models/decoder.py`, class `ResidualRecurrentDecoder` +* attention: `pytorch/seq2seq/models/attention.py`, class `BahdanauAttention` +* inference (including BLEU evaluation and detokenization): `pytorch/seq2seq/inference/inference.py`, class `Translator` +* beam search: `pytorch/seq2seq/inference/beam_search.py`, class `SequenceGenerator` + ### Loss function Cross entropy loss with label smoothing (smoothing factor = 0.1), padding is not considered part of the loss. +Loss function is implemented in `pytorch/seq2seq/train/smoothing.py`, class +`LabelSmoothing`. + ### Optimizer -Adam optimizer with learning rate 5e-4. +Adam optimizer with learning rate 1e-3, beta1 = 0.9, beta2 = 0.999, epsilon = +1e-8 and no weight decay. 
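+
+For illustration only, an optimizer with these settings could be constructed in
+plain PyTorch as follows (a minimal sketch, not the reference implementation;
+`model` is a placeholder standing in for the GNMT network):
+
+```
+import torch
+
+model = torch.nn.LSTM(1024, 1024)  # placeholder module, stands in for GNMT
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-3,
+                             betas=(0.9, 0.999), eps=1e-8,
+                             weight_decay=0)
+```
+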
+Network is trained with gradient clipping, max L2 norm of gradients is set to 5.0. + +Optimizer is implemented in `pytorch/seq2seq/train/fp_optimizers.py`, class +`Fp32Optimizer`. + +### Learning rate schedule +Model is trained with exponential learning rate warmup for 200 steps and with +step learning rate decay. Decay is started after 2/3 of training steps, decays +for a total of 4 times, at regularly spaced intervals, decay factor is 0.5. + +Learning rate scheduler is implemented in +`pytorch/seq2seq/train/lr_scheduler.py`, class `WarmupMultiStepLR`. # 5. Quality ### Quality metric -BLEU score on newstest2014 dataset. -BLEU scores reported by [sacrebleu](https://pypi.org/project/sacrebleu/) package +Uncased BLEU score on newstest2014 en-de dataset. +BLEU scores reported by [sacrebleu](https://pypi.org/project/sacrebleu/) +package (version 1.2.10). Sacrebleu is executed with the following flags: +`--score-only -lc --tokenize intl`. ### Quality target -Uncased BLEU score of 21.80. +Uncased BLEU score of 24.00. ### Evaluation frequency Evaluation of BLEU score is done after every epoch. ### Evaluation thoroughness -Evaluation uses all of `newstest2014.en`. +Evaluation uses all of `newstest2014.en` (3003 sentences). diff --git a/rnn_translator/download_dataset.sh b/rnn_translator/download_dataset.sh index e4ec0aed4..972c81e13 100644 --- a/rnn_translator/download_dataset.sh +++ b/rnn_translator/download_dataset.sh @@ -16,6 +16,8 @@ set -e +export LANG=C.UTF-8 +export LC_ALL=C.UTF-8 OUTPUT_DIR=${1:-"data"} echo "Writing to ${OUTPUT_DIR}. To change this, set the OUTPUT_DIR environment variable." @@ -136,8 +138,8 @@ cat "${OUTPUT_DIR}/newstest2015.tok.clean.de" \ > "${OUTPUT_DIR}/newstest_dev.tok.clean.de" # Filter datasets -python3 pytorch/scripts/downloaders/filter_dataset.py -f1 ${OUTPUT_DIR}/train.tok.clean.en -f2 ${OUTPUT_DIR}/train.tok.clean.de -python3 pytorch/scripts/downloaders/filter_dataset.py -f1 ${OUTPUT_DIR}/newstest_dev.tok.clean.en -f2 ${OUTPUT_DIR}/newstest_dev.tok.clean.de +python3 pytorch/scripts/filter_dataset.py -f1 ${OUTPUT_DIR}/train.tok.clean.en -f2 ${OUTPUT_DIR}/train.tok.clean.de +python3 pytorch/scripts/filter_dataset.py -f1 ${OUTPUT_DIR}/newstest_dev.tok.clean.en -f2 ${OUTPUT_DIR}/newstest_dev.tok.clean.de # Generate Subword Units (BPE) # Clone Subword NMT diff --git a/rnn_translator/pytorch/Dockerfile b/rnn_translator/pytorch/Dockerfile index a059d38d4..a187921c3 100644 --- a/rnn_translator/pytorch/Dockerfile +++ b/rnn_translator/pytorch/Dockerfile @@ -1,9 +1,9 @@ -FROM pytorch/pytorch:0.4_cuda9_cudnn7 +FROM pytorch/pytorch:1.0.1-cuda10.0-cudnn7-runtime ENV LANG C.UTF-8 ENV LC_ALL C.UTF-8 ADD . /workspace/pytorch -RUN pip install -r /workspace/pytorch/requirements.txt - WORKDIR /workspace/pytorch + +RUN pip install -r requirements.txt diff --git a/rnn_translator/pytorch/README.md b/rnn_translator/pytorch/README.md index 26d84b4b4..362c2da41 100644 --- a/rnn_translator/pytorch/README.md +++ b/rnn_translator/pytorch/README.md @@ -5,17 +5,16 @@ This problem uses recurrent neural network to do language translation. 
## Requirements * [Python 3.6](https://www.python.org) * [CUDA 9.0](https://developer.nvidia.com/cuda-90-download-archive) -* [PyTorch 0.4.0](https://pytorch.org) +* [PyTorch 1.0.1](https://pytorch.org) * [sacrebleu](https://pypi.org/project/sacrebleu/) ### Recommended setup * [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) -* [pytorch/pytorch:0.4_cuda9_cudnn7 container](https://hub.docker.com/r/pytorch/pytorch/tags/) +* [pytorch/pytorch:1.0.1-cuda10.0-cudnn7-runtime container](https://hub.docker.com/r/pytorch/pytorch/tags/) # 2. Directions ### Steps to configure machine -Common steps for all rnn-translation tests To setup the environment on Ubuntu 16.04 (16 CPUs, one P100, 100 GB disk), you can use these commands. This may vary on a different operating system or graphics card. @@ -73,6 +72,7 @@ Verify data with: ### Steps specific to the pytorch version to run and time + cd pytorch sudo docker build . --rm -t gnmt:latest SEED=1 NOW=`date "+%F-%T"` @@ -81,8 +81,8 @@ Verify data with: gnmt:latest "./run_and_time.sh" $SEED |tee benchmark-$NOW.log ### one can control which GPUs are used with the NV_GPU variable - sudo NV_GPU=0 nvidia-docker run -it --rm --ipc=host \ - -v $(pwd)/../data:/data \ + sudo NV_GPU=0 nvidia-docker run -it --rm --ipc=host \ + -v $(pwd)/../data:/data \ gnmt:latest "./run_and_time.sh" $SEED |tee benchmark-$NOW.log # 3. Dataset/Environment @@ -97,15 +97,105 @@ segment text into subword units (BPE), by default it builds shared vocabulary of Preprocessing removes all pairs of sentences that can't be decoded by latin-1 encoder. +### Vocabulary +Vocabulary is generated by the following lines from the `download_dataset.sh` +script: + +``` +# Create vocabulary file for BPE +cat "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.en" "${OUTPUT_DIR}/train.tok.clean.bpe.${merge_ops}.de" | \ + ${OUTPUT_DIR}/subword-nmt/get_vocab.py | cut -f1 -d ' ' > "${OUTPUT_DIR}/vocab.bpe.${merge_ops}" +``` + +Vocabulary is stored to the `rnn_translator/data/vocab.bpe.32000` plain text +file. Tokens are separated with a newline character, one token per line. The +vocabulary file doesn't contain special tokens like for example BOS +(begin-of-string) or EOS (end-of-string) tokens. + +Here are first 10 lines from the `rnn_translator/data/vocab.bpe.32000` file: +``` +, +. +the +in +of +and +die +der +to +und +``` + +### Text datasets +The `download_dataset.sh` script automatically creates training, validation and +test datasets. Datasets are stored as plain text files. Sentences are separated +with a newline character, and tokens within each sentence are separated with a +single space character. 
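+
+As a small illustration (not part of the reference scripts), such a file can be
+loaded into a list of token sequences like this (the path is an example):
+
+```
+# one sentence per line, tokens separated by single spaces
+with open('data/train.tok.clean.bpe.32000.en') as f:
+    sentences = [line.split() for line in f]
+
+print(sentences[0])  # ['Res@@', 'um@@', 'ption', 'of', 'the', 'session']
+```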
+ +Training data: +* source language (English): `rnn_translator/data/train.tok.clean.bpe.32000.en` +* target language (German): `rnn_translator/data/train.tok.clean.bpe.32000.de` + +Validation data: +* source language (English): `rnn_translator/data/newstest_dev.tok.clean.bpe.32000.en` +* target language (German): `rnn_translator/data/newstest_dev.tok.clean.bpe.32000.de` + +Test data: +* source language (English): `rnn_translator/data/newstest2014.tok.bpe.32000.en` +* target language (German): `rnn_translator/data/newstest2014.de` + * notice that the `newstest2014.de` file isn't tokenized, BLEU evaluation is + performed by the sacrebleu package and sacrebleu expects plain text raw data + (tokenization is performed internally by sacrebleu) + +Here are first 5 lines from the `rnn_translator/data/train.tok.clean.bpe.32000.en` file: +``` +Res@@ um@@ ption of the session +I declare resumed the session of the European Parliament ad@@ jour@@ ned on Friday 17 December 1999 , and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant fes@@ tive period . +Although , as you will have seen , the d@@ read@@ ed ' millenn@@ ium bug ' failed to materi@@ alise , still the people in a number of countries suffered a series of natural disasters that truly were d@@ read@@ ful . +You have requested a debate on this subject in the course of the next few days , during this part-session . +In the meantime , I should like to observe a minute ' s silence , as a number of Members have requested , on behalf of all the victims concerned , particularly those of the terrible stor@@ ms , in the various countries of the European Union . +``` + +And here are first 5 lines from the `rnn_translator/data/train.tok.clean.bpe.32000.de` file: +``` +Wiederaufnahme der Sitzungsperiode +Ich erkläre die am Freitag , dem 17. Dezember unterbro@@ ch@@ ene Sitzungsperiode des Europäischen Parlaments für wieder@@ aufgenommen , wünsche Ihnen nochmals alles Gute zum Jahres@@ wechsel und hoffe , daß Sie schöne Ferien hatten . +Wie Sie feststellen konnten , ist der ge@@ für@@ ch@@ tete " Mill@@ en@@ i@@ um-@@ Bu@@ g " nicht eingetreten . Doch sind Bürger einiger unserer Mitgliedstaaten Opfer von schrecklichen Naturkatastrophen geworden . +Im Parlament besteht der Wunsch nach einer Aussprache im Verlauf dieser Sitzungsperiode in den nächsten Tagen . +Heute möchte ich Sie bitten - das ist auch der Wunsch einiger Kolleginnen und Kollegen - , allen Opfern der St@@ ür@@ me , insbesondere in den verschiedenen Ländern der Europäischen Union , in einer Schwei@@ ge@@ minute zu ge@@ denken . +``` + ### Training and test data separation Training uses WMT16 English-German dataset, validation is on concatenation of newstest2015 and newstest2016, BLEU evaluation is done on newstest2014. + +### Data filtering +Training is executed only on pairs of sentences which satisfy the following equation: +``` + min_len <= src sentence sequence length <= max_len AND + min_len <= tgt sentence sequence length <= max_len +``` +`min_len` is set to 0, `max_len` is set to 50. Source and target sequence +lengths include special BOS (begin-of-sentence) and EOS (end-of-sentence) +tokens. + +Filtering is implemented in `pytorch/seq2seq/data/dataset.py`, class +`LazyParallelDataset`. + ### Training data order -By default training script does bucketing by sequence length. Before each epoch -dataset is randomly shuffled and split into chunks of 80 batches each. 
Within -each chunk it's sorted by (src + tgt) sequence length and then batches are -reshuffled within each chunk. +Training script does bucketing by sequence length. Bucketing algorithm uses 5 +equal-width buckets (`num_buckets = 5`). Pairs of training sentences are +assigned to buckets by the value of +`max(src_sentence_len // bucket_width, tgt_sentence_len // bucket_width)`, where +`bucket_width = (max_len + num_buckets - 1) // num_buckets`. +Before each training epoch batches are randomly sampled from the buckets (last +incomplete batches are dropped for each bucket), then all batches are +reshuffled. + + +Bucketing is implemented in `pytorch/seq2seq/data/sampler.py`, class +`BucketingSampler`. # 4. Model ### Publication/Attribution @@ -128,47 +218,91 @@ GNMT-like models from [tensorflow/nmt](https://github.com/tensorflow/nmt) and * general: * encoder and decoder are using shared embeddings * data-parallel multi-gpu training - * dynamic loss scaling with backoff for Tensor Cores (mixed precision) training * trained with label smoothing loss (smoothing factor 0.1) * encoder: - * 4-layer LSTM, hidden size 1024, first layer is bidirectional, the rest are - undirectional - * with residual connections starting from 3rd layer - * uses standard LSTM layer (accelerated by cudnn) + * 4-layer LSTM, hidden size 1024, first layer is bidirectional, the rest of + layers are unidirectional + * with residual connections starting from 3rd LSTM layer + * uses standard pytorch nn.LSTM layer + * dropout is applied on input to all LSTM layers, probability of dropout is + set to 0.2 + * hidden state of LSTM layers is initialized with zeros + * weights and bias of LSTM layers is initialized with uniform(-0.1, 0.1) + distribution * decoder: * 4-layer unidirectional LSTM with hidden size 1024 and fully-connected classifier - * with residual connections starting from 3rd layer - * uses standard LSTM layer (accelerated by cudnn) + * with residual connections starting from 3rd LSTM layer + * uses standard pytorch nn.LSTM layer + * dropout is applied on input to all LSTM layers, probability of dropout is + set to 0.2 + * hidden state of LSTM layers is initialized with zeros + * weights and bias of LSTM layers is initialized with uniform(-0.1, 0.1) + distribution + * weights and bias of fully-connected classifier is initialized with + uniform(-0.1, 0.1) distribution * attention: * normalized Bahdanau attention * model uses `gnmt_v2` attention mechanism * output from first LSTM layer of decoder goes into attention, then re-weighted context is concatenated with the input to all subsequent LSTM layers in decoder at the current timestep + * linear transform of keys and queries is initialized with uniform(-0.1, 0.1), + normalization scalar is initialized with 1.0 / sqrt(1024), + normalization bias is initialized with zero * inference: - * beam search with default beam size 5 - * with coverage penalty and length normalization + * beam search with beam size of 5 + * with coverage penalty and length normalization, coverage penalty factor is + set to 0.1, length normalization factor is set to 0.6 and length + normalization constant is set to 5.0 * BLEU computed by [sacrebleu](https://pypi.org/project/sacrebleu/) + +Implementation: +* base Seq2Seq model: `pytorch/seq2seq/models/seq2seq_base.py`, class `Seq2Seq` +* GNMT model: `pytorch/seq2seq/models/gnmt.py`, class `GNMT` +* encoder: `pytorch/seq2seq/models/encoder.py`, class `ResidualRecurrentEncoder` +* decoder: `pytorch/seq2seq/models/decoder.py`, class 
`ResidualRecurrentDecoder` +* attention: `pytorch/seq2seq/models/attention.py`, class `BahdanauAttention` +* inference (including BLEU evaluation and detokenization): `pytorch/seq2seq/inference/inference.py`, class `Translator` +* beam search: `pytorch/seq2seq/inference/beam_search.py`, class `SequenceGenerator` + ### Loss function Cross entropy loss with label smoothing (smoothing factor = 0.1), padding is not considered part of the loss. +Loss function is implemented in `pytorch/seq2seq/train/smoothing.py`, class +`LabelSmoothing`. + ### Optimizer -Adam optimizer with learning rate 5e-4. +Adam optimizer with learning rate 1e-3, beta1 = 0.9, beta2 = 0.999, epsilon = +1e-8 and no weight decay. +Network is trained with gradient clipping, max L2 norm of gradients is set to 5.0. + +Optimizer is implemented in `pytorch/seq2seq/train/fp_optimizers.py`, class +`Fp32Optimizer`. + +### Learning rate schedule +Model is trained with exponential learning rate warmup for 200 steps and with +step learning rate decay. Decay is started after 2/3 of training steps, decays +for a total of 4 times, at regularly spaced intervals, decay factor is 0.5. + +Learning rate scheduler is implemented in +`pytorch/seq2seq/train/lr_scheduler.py`, class `WarmupMultiStepLR`. # 5. Quality ### Quality metric -BLEU score on newstest2014 dataset. -BLEU scores reported by [sacrebleu](https://pypi.org/project/sacrebleu/) package +Uncased BLEU score on newstest2014 en-de dataset. +BLEU scores reported by [sacrebleu](https://pypi.org/project/sacrebleu/) +package (version 1.2.10). Sacrebleu is executed with the following flags: +`--score-only -lc --tokenize intl`. ### Quality target -Uncased BLEU score of 21.80. +Uncased BLEU score of 24.00. ### Evaluation frequency Evaluation of BLEU score is done after every epoch. ### Evaluation thoroughness -Evaluation uses all of `newstest2014.en`. +Evaluation uses all of `newstest2014.en` (3003 sentences). 
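+
+For reference, the reported score can also be reproduced with the sacrebleu
+command line interface; a minimal sketch (file names are examples, assuming
+detokenized system output in `newstest2014_out.de`):
+
+```
+import subprocess
+
+# equivalent to: sacrebleu newstest2014.de --input newstest2014_out.de \
+#                          --score-only -lc --tokenize intl
+result = subprocess.run(
+    ['sacrebleu', 'newstest2014.de',
+     '--input', 'newstest2014_out.de',
+     '--score-only', '-lc', '--tokenize', 'intl'],
+    stdout=subprocess.PIPE, universal_newlines=True)
+print(result.stdout.strip())
+```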
diff --git a/rnn_translator/pytorch/multiproc.py b/rnn_translator/pytorch/multiproc.py deleted file mode 100644 index 74b26e884..000000000 --- a/rnn_translator/pytorch/multiproc.py +++ /dev/null @@ -1,46 +0,0 @@ -import sys -import subprocess - -import torch - -def main(): - argslist = list(sys.argv)[1:] - world_size = torch.cuda.device_count() - - if '--world-size' in argslist: - argslist[argslist.index('--world-size') + 1] = str(world_size) - else: - argslist.append('--world-size') - argslist.append(str(world_size)) - - workers = [] - - for i in range(world_size): - if '--rank' in argslist: - argslist[argslist.index('--rank') + 1] = str(i) - else: - argslist.append('--rank') - argslist.append(str(i)) - stdout = None if i == 0 else subprocess.DEVNULL - worker = subprocess.Popen([str(sys.executable)] + argslist, stdout=stdout) - workers.append(worker) - - returncode = 0 - try: - for worker in workers: - worker_returncode = worker.wait() - if worker_returncode != 0: - returncode = 1 - except KeyboardInterrupt: - print('Pressed CTRL-C, TERMINATING') - for worker in workers: - worker.terminate() - for worker in workers: - worker.wait() - raise - - sys.exit(returncode) - - -if __name__ == "__main__": - main() diff --git a/rnn_translator/pytorch/requirements.txt b/rnn_translator/pytorch/requirements.txt index bff621450..fc40a4961 100644 --- a/rnn_translator/pytorch/requirements.txt +++ b/rnn_translator/pytorch/requirements.txt @@ -1,3 +1,3 @@ sacrebleu==1.2.10 -numpy==1.14.2 -mlperf-compliance==0.0.4 +git+git://github.com/NVIDIA/apex.git@9041a868a1a253172d94b113a963375b9badd030#egg=apex +mlperf-compliance==0.0.10 diff --git a/rnn_translator/pytorch/run.sh b/rnn_translator/pytorch/run.sh index a8bcde92c..1d38a395a 100755 --- a/rnn_translator/pytorch/run.sh +++ b/rnn_translator/pytorch/run.sh @@ -3,16 +3,12 @@ set -e DATASET_DIR='/data' -RESULTS_DIR='gnmt_wmt16' SEED=${1:-"1"} -TARGET=${2:-"21.80"} +TARGET=${2:-"24.00"} # run training -python3 -m multiproc train.py \ - --save ${RESULTS_DIR} \ +python3 train.py \ --dataset-dir ${DATASET_DIR} \ --seed $SEED \ - --target-bleu $TARGET \ - --epochs 8 \ - --batch-size 128 + --target-bleu $TARGET diff --git a/rnn_translator/pytorch/run_and_time.sh b/rnn_translator/pytorch/run_and_time.sh index de872ef18..0747c1535 100755 --- a/rnn_translator/pytorch/run_and_time.sh +++ b/rnn_translator/pytorch/run_and_time.sh @@ -13,7 +13,7 @@ echo "STARTING TIMING RUN AT $start_fmt" # run benchmark seed=${1:-"1"} -target=21.80 +target=24.00 echo "running benchmark" ./run.sh $seed $target diff --git a/rnn_translator/pytorch/scripts/benchmark_inference.sh b/rnn_translator/pytorch/scripts/benchmark_inference.sh deleted file mode 100644 index 35494d39c..000000000 --- a/rnn_translator/pytorch/scripts/benchmark_inference.sh +++ /dev/null @@ -1,101 +0,0 @@ -#!/bin/bash - -set -e - -DATASET_DIR='../data/wmt16_de_en' -RESULTS_DIR='gnmt_wmt16' - -# sort by length (ascending) -cat ${DATASET_DIR}/newstest2014.tok.bpe.32000.en \ - | awk '{ print length, $0 }' \ - | sort -n -s \ - | cut -d" " -f2- > /tmp/newstest2014.tok.bpe.32000.en.sorted - -batches=(512 256 128 64 32) -beams=(1 2 5 10) -maths=(fp16 fp32) - -model=../results/${RESULTS_DIR}/model_best.pth - -odir=../results/inference_benchmark -mkdir -p $odir - -echo RUNNING on unsorted dataset -rm -rf $odir/fp16_perf_unsorted.log -rm -rf $odir/fp32_perf_unsorted.log -rm -rf $odir/fp16_bleu.log -rm -rf $odir/fp32_bleu.log -ifile=${DATASET_DIR}/newstest2014.tok.bpe.32000.en - -for math in "${maths[@]}" -do - for batch in 
"${batches[@]}" - do - for beam in "${beams[@]}" - do - echo RUNNING: batch_size: $batch, beam_size: $beam, math: $math - - # run translation - python3 translate.py \ - -i $ifile \ - -m $model \ - --math $math \ - --print-freq 1 \ - --beam-size $beam \ - --batch-size $batch \ - -o /tmp/output.tok &> /tmp/log.log - - tok_per_sec=`cat /tmp/log.log \ - |tail -n 1 \ - |cut -f 1 \ - |cut -d ':' -f 2 |tr -d ' '` - - # detokenize output - perl ${DATASET_DIR}/mosesdecoder/scripts/tokenizer/detokenizer.perl -l de \ - < /tmp/output.tok \ - > /tmp/output.detok 2> /dev/null - - bleu=`sacrebleu ${DATASET_DIR}/newstest2014.de --input /tmp/output.detok -lc --tokenize intl -b` - - echo -e $tok_per_sec '\t\t batch: '$batch 'beam: ' $beam >> $odir/${math}_perf_unsorted.log - echo -e $bleu '\t\t batch: '$batch 'beam: ' $beam >> $odir/${math}_bleu.log - echo Tokens per second: $tok_per_sec, BLEU: $bleu - done - done -done - - -echo RUNNING on sorted dataset -rm -rf $odir/fp16_perf_sorted.log -rm -rf $odir/fp32_perf_sorted.log -ifile=/tmp/newstest2014.tok.bpe.32000.en.sorted - - -for math in "${maths[@]}" -do - for batch in "${batches[@]}" - do - for beam in "${beams[@]}" - do - echo RUNNING: batch_size: $batch, beam_size: $beam, math: $math - - # run translation - python3 translate.py \ - -i $ifile \ - -m $model \ - --math $math \ - --print-freq 1 \ - --beam-size $beam \ - --batch-size $batch \ - -o /tmp/output.tok &> /tmp/log.log - - tok_per_sec=`cat /tmp/log.log \ - |tail -n 1 \ - |cut -f 1 \ - |cut -d ':' -f 2 |tr -d ' '` - - echo -e $tok_per_sec '\t\t batch: '$batch 'beam: ' $beam >> $odir/${math}_perf_sorted.log - echo Tokens per second: $tok_per_sec - done - done -done diff --git a/rnn_translator/pytorch/scripts/benchmark_training.sh b/rnn_translator/pytorch/scripts/benchmark_training.sh deleted file mode 100644 index 1589b795c..000000000 --- a/rnn_translator/pytorch/scripts/benchmark_training.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash - -DATASET_DIR='../data/wmt16_de_en' - -hiddens=(1024) -batches=(128) -maths=(fp16 fp32) -gpus=(1 2 4 8) - -for hidden in "${hiddens[@]}" -do - for math in "${maths[@]}" - do - for batch in "${batches[@]}" - do - for gpu in "${gpus[@]}" - do - export CUDA_VISIBLE_DEVICES=`seq -s "," 0 $((gpu - 1))` - python3 -m multiproc train.py \ - --save benchmark_gpu_${gpu}_math_${math}_batch_${batch}_hidden_${hidden} \ - --dataset-dir ${DATASET_DIR} \ - --seed 12345 \ - --epochs 1 \ - --math ${math} \ - --print-freq 1 \ - --batch-size ${batch} \ - --disable-eval \ - --max-size $((128 * ${batch} * ${gpu})) \ - --max-length-train 48 \ - --max-length-val 150 \ - --model-config "{'num_layers': 8, 'hidden_size': ${hidden}, 'dropout':0.2, 'share_embedding': True}" \ - --optimization-config "{'optimizer': 'Adam', 'lr': 5e-4}" - done - done - done -done diff --git a/rnn_translator/pytorch/scripts/docker/build.sh b/rnn_translator/pytorch/scripts/docker/build.sh index c4297f4d6..e0e15f6cd 100644 --- a/rnn_translator/pytorch/scripts/docker/build.sh +++ b/rnn_translator/pytorch/scripts/docker/build.sh @@ -1,3 +1,3 @@ #!/bin/bash -docker build pytorch --rm -t gnmt:latest +docker build . 
--rm -t gnmt:latest diff --git a/rnn_translator/pytorch/scripts/downloaders/filter_dataset.py b/rnn_translator/pytorch/scripts/downloaders/filter_dataset.py deleted file mode 100644 index 081e53bbf..000000000 --- a/rnn_translator/pytorch/scripts/downloaders/filter_dataset.py +++ /dev/null @@ -1,50 +0,0 @@ -import argparse -import string -from collections import Counter - -def parse_args(): - parser = argparse.ArgumentParser(description='Clean dataset') - parser.add_argument('-f1', '--file1', help='file1') - parser.add_argument('-f2', '--file2', help='file2') - return parser.parse_args() - - -def save_output(fname, data): - with open(fname, 'w') as f: - f.writelines(data) - -def main(): - args = parse_args() - - c = Counter() - skipped = 0 - valid = 0 - data1 = [] - data2 = [] - - with open(args.file1) as f1, open(args.file2) as f2: - for idx, lines in enumerate(zip(f1, f2)): - line1, line2 = lines - if idx % 100000 == 1: - print('Processed {} lines'.format(idx)) - try: - line1.encode('latin1') - line2.encode('latin1') - except UnicodeEncodeError: - skipped += 1 - else: - data1.append(line1) - data2.append(line2) - valid += 1 - c.update(line1) - - ratio = valid / (skipped + valid) - print('Skipped: {}, Valid: {}, Valid ratio {}'.format(skipped, valid, ratio)) - print('Character frequency:', c) - - save_output(args.file1, data1) - save_output(args.file2, data2) - - -if __name__ == '__main__': - main() diff --git a/rnn_translator/pytorch/scripts/evaluate.sh b/rnn_translator/pytorch/scripts/evaluate.sh deleted file mode 100644 index c1fb95a9c..000000000 --- a/rnn_translator/pytorch/scripts/evaluate.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash - -set -e - -DATASET_DIR='../data/wmt16_de_en' -RESULTS_DIR='gnmt_wmt16' - -# evaluate best checkpoint on newstest2014 -python3 -u translate.py \ - --math fp16 \ - --model ../results/${RESULTS_DIR}/model_best.pth \ - --input ${DATASET_DIR}/newstest2014.tok.bpe.32000.en \ - --output ../results/${RESULTS_DIR}/newstest2014_out.tok.de \ - |tee ../results/${RESULTS_DIR}/inference.log - -# detokenize output -perl ${DATASET_DIR}/mosesdecoder/scripts/tokenizer/detokenizer.perl -l de \ - < ../results/${RESULTS_DIR}/newstest2014_out.tok.de \ - > ../results/${RESULTS_DIR}/newstest2014_out.de - -# compute uncased BLEU -cat ../results/${RESULTS_DIR}/newstest2014_out.de \ - |sacrebleu ${DATASET_DIR}/newstest2014.de -lc \ - --tokenize intl |tee ../results/${RESULTS_DIR}/bleu_nt14_lc.log - -# compute cased BLEU -cat ../results/${RESULTS_DIR}/newstest2014_out.de \ - |sacrebleu ${DATASET_DIR}/newstest2014.de \ - --tokenize intl |tee ../results/${RESULTS_DIR}/bleu_nt14.log diff --git a/rnn_translator/pytorch/scripts/filter_dataset.py b/rnn_translator/pytorch/scripts/filter_dataset.py new file mode 100644 index 000000000..3168d476e --- /dev/null +++ b/rnn_translator/pytorch/scripts/filter_dataset.py @@ -0,0 +1,79 @@ +import argparse +from collections import Counter + + +def parse_args(): + parser = argparse.ArgumentParser(description='Clean dataset') + parser.add_argument('-f1', '--file1', help='file1') + parser.add_argument('-f2', '--file2', help='file2') + return parser.parse_args() + + +def save_output(fname, data): + with open(fname, 'w') as f: + f.writelines(data) + + +def main(): + """ + Discards all pairs of sentences which can't be decoded by latin-1 encoder. + + It aims to filter out sentences with rare unicode glyphs and pairs which + are most likely not valid English-German sentences. 
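+
+    Note: the filtered sentences are written back in place, overwriting the
+    original input files.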
+ + Examples of discarded sentences: + + ✿★★★Hommage au king de la pop ★★★✿ ✿★★★Que son âme repos... + + Для их осуществления нам, прежде всего, необходимо преодолеть + возражения рыночных фундаменталистов, которые хотят ликвидировать или + уменьшить роль МВФ. + + practised as a scientist in various medical departments of the ⇗Medical + University of Hanover , the ⇗University of Ulm , and the ⇗RWTH Aachen + (rheumatology, pharmacology, physiology, pathology, microbiology, + immunology and electron-microscopy). + + The same shift】 and press 【】 【alt out with a smaller diameter + circle. + + Brought to you by ABMSUBS ♥leira(Coordinator/Translator) + ♥chibichan93(Timer/Typesetter) ♥ja... + + Some examples: &0u - ☺ &0U - ☻ &tel - ☏ &PI - ¶ &SU - ☼ &cH- - ♥ &M2=♫ + &sn - ﺵ SGML maps SGML to unicode. + """ + args = parse_args() + + c = Counter() + skipped = 0 + valid = 0 + data1 = [] + data2 = [] + + with open(args.file1) as f1, open(args.file2) as f2: + for idx, lines in enumerate(zip(f1, f2)): + line1, line2 = lines + if idx % 100000 == 1: + print('Processed {} lines'.format(idx)) + try: + line1.encode('latin1') + line2.encode('latin1') + except UnicodeEncodeError: + skipped += 1 + else: + data1.append(line1) + data2.append(line2) + valid += 1 + c.update(line1) + + ratio = valid / (skipped + valid) + print('Skipped: {}, Valid: {}, Valid ratio {}'.format(skipped, valid, ratio)) + print('Character frequency:', c) + + save_output(args.file1, data1) + save_output(args.file2, data2) + + +if __name__ == '__main__': + main() diff --git a/rnn_translator/pytorch/scripts/parse_train_benchmark.sh b/rnn_translator/pytorch/scripts/parse_train_benchmark.sh deleted file mode 100644 index b98b240c7..000000000 --- a/rnn_translator/pytorch/scripts/parse_train_benchmark.sh +++ /dev/null @@ -1,50 +0,0 @@ -#!/bin/bash - -hiddens=(1024) -batches=(128) -maths=(fp16 fp32) -gpus=(1 2 4 8) - -sentences=3418306 - -echo -e [parameters] "\t\t\t" [tokens / s] [second per epoch] - -for batch in "${batches[@]}" -do - for hidden in "${hiddens[@]}" - do - for math in "${maths[@]}" - do - for gpu in "${gpus[@]}" - do - dir=../results/benchmark_gpu_${gpu}_math_${math}_batch_${batch}_hidden_${hidden}/ - if [ ! -d $dir ]; then - echo Directory $dir does not exist - continue - fi - - total_tokens_per_s=0 - for gpu_id in `seq 0 $((gpu - 1))` - do - tokens_per_s=`cat ${dir}/log_gpu_${gpu_id}.log \ - |grep TRAIN \ - |cut -f 4 \ - |sed -E -n 's/.*\(([0-9]+)\).*/\1/p' \ - |tail -n 1` - total_tokens_per_s=$((total_tokens_per_s + tokens_per_s)) - done - - batch_time=`cat ${dir}/log_gpu_0.log \ - |grep TRAIN \ - |cut -f 2 \ - |sed -E -n 's/.*\(([.0-9]+)\).*/\1/p' \ - |tail -n 1` - - n_batches=$(( $sentences / ($batch * $gpu))) - epoch_time=`awk "BEGIN {print $n_batches * $batch_time}"` - - echo -e math: $math batch: $batch gpus: $gpu "\t\t" $total_tokens_per_s "\t" $epoch_time - done - done - done -done diff --git a/rnn_translator/pytorch/scripts/trim_checkpoints.py b/rnn_translator/pytorch/scripts/trim_checkpoints.py deleted file mode 100644 index 07c9f6d57..000000000 --- a/rnn_translator/pytorch/scripts/trim_checkpoints.py +++ /dev/null @@ -1,48 +0,0 @@ -import argparse -import glob -import os -import subprocess -import sys -import time - -import torch - - -def parse_args(): - """ - Parse arguments. 
- """ - parser = argparse.ArgumentParser(description='Trim training checkpoints') - parser.add_argument('--path', required=True, - help='path to directory with checkpoints (*.pth)') - parser.add_argument('--suffix', default='trim', - help='suffix appended to the name of output file') - return parser.parse_args() - - -def get_checkpoints(path): - """ - Gets all *.pth checkpoints from a given directory. - - :param path: - """ - checkpoints = glob.glob(os.path.join(path, '*.pth')) - return checkpoints - -def main(): - # Add parent folder to sys.path - sys.path.insert(1, os.path.join(sys.path[0], '..')) - args = parse_args() - - checkpoints = get_checkpoints(args.path) - print('All checkpoints:', checkpoints) - - for checkpoint in checkpoints: - print('Processing ', checkpoint) - chkpt = torch.load(checkpoint) - chkpt['optimizer'] = None - output_file = checkpoint.replace('pth', args.suffix + '.pth') - torch.save(chkpt, output_file) - -if __name__ == "__main__": - main() diff --git a/rnn_translator/pytorch/seq2seq/data/config.py b/rnn_translator/pytorch/seq2seq/data/config.py index 69af2a111..0582e0414 100644 --- a/rnn_translator/pytorch/seq2seq/data/config.py +++ b/rnn_translator/pytorch/seq2seq/data/config.py @@ -3,19 +3,30 @@ BOS_TOKEN = '' EOS_TOKEN = '<\s>' +# special PAD, UNKNOWN, BEGIN-OF-STRING, END-OF-STRING tokens PAD, UNK, BOS, EOS = [0, 1, 2, 3] +# path to the BPE vocabulary file, relative to the data directory, it should +# point to file generated by subword-nmt/get_vocab.py VOCAB_FNAME = 'vocab.bpe.32000' +# paths to source and target training files, relative to the data directory, it +# should point to BPE-encoded files, generated by subword-nmt/apply_bpe.py SRC_TRAIN_FNAME = 'train.tok.clean.bpe.32000.en' TGT_TRAIN_FNAME = 'train.tok.clean.bpe.32000.de' +# paths to source and target validation files, relative to the data directory, +# it should point to BPE-encoded files, generated by subword-nmt/apply_bpe.py SRC_VAL_FNAME = 'newstest_dev.tok.clean.bpe.32000.en' TGT_VAL_FNAME = 'newstest_dev.tok.clean.bpe.32000.de' +# path to the test source file, relative to the data directory, it should point +# to BPE-encoded file, generated by subword-nmt/apply_bpe.py SRC_TEST_FNAME = 'newstest2014.tok.bpe.32000.en' -TGT_TEST_FNAME = 'newstest2014.tok.bpe.32000.de' +# path to the test target file, relative to the data directory, it should point +# to plaintext file, tokenization is performed by the sacrebleu package TGT_TEST_TARGET_FNAME = 'newstest2014.de' +# path to the moses detokenizer, relative to the data directory DETOKENIZER = 'mosesdecoder/scripts/tokenizer/detokenizer.perl' diff --git a/rnn_translator/pytorch/seq2seq/data/dataset.py b/rnn_translator/pytorch/seq2seq/data/dataset.py index e0fbfa867..5d63884ee 100644 --- a/rnn_translator/pytorch/seq2seq/data/dataset.py +++ b/rnn_translator/pytorch/seq2seq/data/dataset.py @@ -1,16 +1,33 @@ import logging +from operator import itemgetter import torch -from torch.utils.data import Dataset -from torch.utils.data.sampler import SequentialSampler, RandomSampler -from seq2seq.data.sampler import BucketingSampler from torch.utils.data import DataLoader +from torch.utils.data import Dataset import seq2seq.data.config as config +from seq2seq.data.sampler import BucketingSampler +from seq2seq.data.sampler import DistributedSampler +from seq2seq.data.sampler import ShardingSampler +from seq2seq.data.sampler import StaticDistributedSampler -def build_collate_fn(batch_first=False, sort=False): +def build_collate_fn(batch_first=False, 
parallel=True, sort=False): + """ + Factory for collate_fn functions. + + :param batch_first: if True returns batches in (batch, seq) format, if + False returns in (seq, batch) format + :param parallel: if True builds batches from parallel corpus (src, tgt) + :param sort: if True sorts by src sequence length within each batch + """ def collate_seq(seq): + """ + Builds batches for training or inference. + Batches are returned as pytorch tensors, with padding. + + :param seq: list of sequences + """ lengths = [len(s) for s in seq] batch_length = max(lengths) @@ -26,68 +43,127 @@ def collate_seq(seq): return (seq_tensor, lengths) - def collate(seqs): + def parallel_collate(seqs): + """ + Builds batches from parallel dataset (src, tgt), optionally sorts batch + by src sequence length. + + :param seqs: tuple of (src, tgt) sequences + """ src_seqs, tgt_seqs = zip(*seqs) if sort: - key = lambda item: len(item[1]) - indices, src_seqs = zip(*sorted(enumerate(src_seqs), key=key, - reverse=True)) + indices, src_seqs = zip(*sorted(enumerate(src_seqs), + key=lambda item: len(item[1]), + reverse=True)) tgt_seqs = [tgt_seqs[idx] for idx in indices] + + return tuple([collate_seq(s) for s in [src_seqs, tgt_seqs]]) + + def single_collate(src_seqs): + """ + Builds batches from text dataset, optionally sorts batch by src + sequence length. + + :param src_seqs: source sequences + """ + if sort: + indices, src_seqs = zip(*sorted(enumerate(src_seqs), + key=lambda item: len(item[1]), + reverse=True)) else: indices = range(len(src_seqs)) - return tuple([collate_seq(s) for s in [src_seqs, tgt_seqs]] + [indices]) + return collate_seq(src_seqs), tuple(indices) - return collate + if parallel: + return parallel_collate + else: + return single_collate -class ParallelDataset(Dataset): - def __init__(self, src_fname, tgt_fname, tokenizer, - min_len, max_len, sort=False, max_size=None): +class TextDataset(Dataset): + def __init__(self, src_fname, tokenizer, min_len=None, max_len=None, + sort=False, max_size=None): + """ + Constructor for the TextDataset. Builds monolingual dataset. + + :param src_fname: path to the file with data + :param tokenizer: tokenizer + :param min_len: minimum sequence length + :param max_len: maximum sequence length + :param sort: sorts dataset by sequence length + :param max_size: loads at most 'max_size' samples from the input file, + if None loads the entire dataset + """ self.min_len = min_len self.max_len = max_len + self.parallel = False + self.sorted = False self.src = self.process_data(src_fname, tokenizer, max_size) - self.tgt = self.process_data(tgt_fname, tokenizer, max_size) - assert len(self.src) == len(self.tgt) - self.filter_data(min_len, max_len) - assert len(self.src) == len(self.tgt) + if min_len is not None and max_len is not None: + self.filter_data(min_len, max_len) - lengths = [len(s) + len(t) for (s, t) in zip(self.src, self.tgt)] + lengths = [len(s) for s in self.src] self.lengths = torch.tensor(lengths) if sort: self.sort_by_length() def sort_by_length(self): + """ + Sorts dataset by the sequence length. + """ self.lengths, indices = self.lengths.sort(descending=True) self.src = [self.src[idx] for idx in indices] - self.tgt = [self.tgt[idx] for idx in indices] + self.indices = indices.tolist() + self.sorted = True + + def unsort(self, array): + """ + "Unsorts" given array (restores original order of elements before + dataset was sorted by sequence length). 
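+        Typically used to restore the original order of model outputs after
+        inference on a dataset that was sorted by sequence length.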
+ + :param array: array to be "unsorted" + """ + if self.sorted: + inverse = sorted(enumerate(self.indices), key=itemgetter(1)) + array = [array[i[0]] for i in inverse] + return array def filter_data(self, min_len, max_len): - logging.info(f'filtering data, min len: {min_len}, max len: {max_len}') + """ + Preserves only samples which satisfy the following inequality: + min_len <= sample sequence length <= max_len - initial_len = len(self.src) + :param min_len: minimum sequence length + :param max_len: maximum sequence length + """ + logging.info(f'Filtering data, min len: {min_len}, max len: {max_len}') + initial_len = len(self.src) filtered_src = [] - filtered_tgt = [] - for src, tgt in zip(self.src, self.tgt): - if min_len <= len(src) <= max_len and \ - min_len <= len(tgt) <= max_len: + for src in self.src: + if min_len <= len(src) <= max_len: filtered_src.append(src) - filtered_tgt.append(tgt) self.src = filtered_src - self.tgt = filtered_tgt - filtered_len = len(self.src) - logging.info(f'pairs before: {initial_len}, after: {filtered_len}') + logging.info(f'Pairs before: {initial_len}, after: {filtered_len}') def process_data(self, fname, tokenizer, max_size): - logging.info(f'processing data from {fname}') + """ + Loads data from the input file. + + :param fname: input file name + :param tokenizer: tokenizer + :param max_size: loads at most 'max_size' samples from the input file, + if None loads the entire dataset + """ + logging.info(f'Processing data from {fname}') data = [] with open(fname) as dfile: for idx, line in enumerate(dfile): @@ -102,22 +178,208 @@ def __len__(self): return len(self.src) def __getitem__(self, idx): - return self.src[idx], self.tgt[idx] + return self.src[idx] - def get_loader(self, batch_size=1, shuffle=False, num_workers=0, batch_first=False, - drop_last=False, distributed=False, bucket=True): + def get_loader(self, batch_size=1, seeds=None, shuffle=False, + num_workers=0, batch_first=False, pad=False, + batching=None, batching_opt={}): - collate_fn = build_collate_fn(batch_first, sort=True) + collate_fn = build_collate_fn(batch_first, parallel=self.parallel, + sort=True) if shuffle: - sampler = BucketingSampler(self, batch_size, bucket) + if batching == 'random': + sampler = DistributedSampler(self, batch_size, seeds) + elif batching == 'sharding': + sampler = ShardingSampler(self, batch_size, seeds, + batching_opt['shard_size']) + elif batching == 'bucketing': + sampler = BucketingSampler(self, batch_size, seeds, + batching_opt['num_buckets']) + else: + raise NotImplementedError else: - sampler = SequentialSampler(self) + sampler = StaticDistributedSampler(self, batch_size, pad) return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, sampler=sampler, num_workers=num_workers, - pin_memory=False, - drop_last=drop_last) + pin_memory=True, + drop_last=False) + + +class ParallelDataset(TextDataset): + def __init__(self, src_fname, tgt_fname, tokenizer, + min_len, max_len, sort=False, max_size=None): + """ + Constructor for the ParallelDataset. + Tokenization is done when the data is loaded from the disk. 
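+        All samples are tokenized up front and kept in memory (in contrast to
+        LazyParallelDataset, which tokenizes samples on the fly).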
+ + :param src_fname: path to the file with src language data + :param tgt_fname: path to the file with tgt language data + :param tokenizer: tokenizer + :param min_len: minimum sequence length + :param max_len: maximum sequence length + :param sort: sorts dataset by sequence length + :param max_size: loads at most 'max_size' samples from the input file, + if None loads the entire dataset + """ + + self.min_len = min_len + self.max_len = max_len + self.parallel = True + self.sorted = False + + self.src = self.process_data(src_fname, tokenizer, max_size) + self.tgt = self.process_data(tgt_fname, tokenizer, max_size) + assert len(self.src) == len(self.tgt) + + self.filter_data(min_len, max_len) + assert len(self.src) == len(self.tgt) + + src_lengths = [len(s) for s in self.src] + tgt_lengths = [len(t) for t in self.tgt] + self.src_lengths = torch.tensor(src_lengths) + self.tgt_lengths = torch.tensor(tgt_lengths) + self.lengths = self.src_lengths + self.tgt_lengths + + if sort: + self.sort_by_length() + + def sort_by_length(self): + """ + Sorts dataset by the sequence length. + """ + self.lengths, indices = self.lengths.sort(descending=True) + + self.src = [self.src[idx] for idx in indices] + self.tgt = [self.tgt[idx] for idx in indices] + self.src_lengths = [self.src_lengths[idx] for idx in indices] + self.tgt_lengths = [self.tgt_lengths[idx] for idx in indices] + self.indices = indices.tolist() + self.sorted = True + + def filter_data(self, min_len, max_len): + """ + Preserves only samples which satisfy the following inequality: + min_len <= src sample sequence length <= max_len AND + min_len <= tgt sample sequence length <= max_len + + :param min_len: minimum sequence length + :param max_len: maximum sequence length + """ + logging.info(f'Filtering data, min len: {min_len}, max len: {max_len}') + + initial_len = len(self.src) + filtered_src = [] + filtered_tgt = [] + for src, tgt in zip(self.src, self.tgt): + if min_len <= len(src) <= max_len and \ + min_len <= len(tgt) <= max_len: + filtered_src.append(src) + filtered_tgt.append(tgt) + + self.src = filtered_src + self.tgt = filtered_tgt + filtered_len = len(self.src) + logging.info(f'Pairs before: {initial_len}, after: {filtered_len}') + + def __getitem__(self, idx): + return self.src[idx], self.tgt[idx] + + +class LazyParallelDataset(TextDataset): + def __init__(self, src_fname, tgt_fname, tokenizer, + min_len, max_len, sort=False, max_size=None): + """ + Constructor for the LazyParallelDataset. + Tokenization is done on the fly. 
+ + :param src_fname: path to the file with src language data + :param tgt_fname: path to the file with tgt language data + :param tokenizer: tokenizer + :param min_len: minimum sequence length + :param max_len: maximum sequence length + :param sort: sorts dataset by sequence length + :param max_size: loads at most 'max_size' samples from the input file, + if None loads the entire dataset + """ + self.min_len = min_len + self.max_len = max_len + self.parallel = True + self.sorted = False + self.tokenizer = tokenizer + + self.raw_src = self.process_raw_data(src_fname, max_size) + self.raw_tgt = self.process_raw_data(tgt_fname, max_size) + assert len(self.raw_src) == len(self.raw_tgt) + + logging.info(f'Filtering data, min len: {min_len}, max len: {max_len}') + # Subtracting 2 because EOS and BOS are added later during tokenization + self.filter_raw_data(min_len - 2, max_len - 2) + assert len(self.raw_src) == len(self.raw_tgt) + + # Adding 2 because EOS and BOS are added later during tokenization + src_lengths = [i + 2 for i in self.src_len] + tgt_lengths = [i + 2 for i in self.tgt_len] + self.src_lengths = torch.tensor(src_lengths) + self.tgt_lengths = torch.tensor(tgt_lengths) + self.lengths = self.src_lengths + self.tgt_lengths + + def process_raw_data(self, fname, max_size): + """ + Loads data from the input file. + + :param fname: input file name + :param max_size: loads at most 'max_size' samples from the input file, + if None loads the entire dataset + """ + logging.info(f'Processing data from {fname}') + data = [] + with open(fname) as dfile: + for idx, line in enumerate(dfile): + if max_size and idx == max_size: + break + data.append(line) + return data + + def filter_raw_data(self, min_len, max_len): + """ + Preserves only samples which satisfy the following inequality: + min_len <= src sample sequence length <= max_len AND + min_len <= tgt sample sequence length <= max_len + + :param min_len: minimum sequence length + :param max_len: maximum sequence length + """ + initial_len = len(self.raw_src) + filtered_src = [] + filtered_tgt = [] + filtered_src_len = [] + filtered_tgt_len = [] + for src, tgt in zip(self.raw_src, self.raw_tgt): + src_len = src.count(' ') + 1 + tgt_len = tgt.count(' ') + 1 + if min_len <= src_len <= max_len and \ + min_len <= tgt_len <= max_len: + filtered_src.append(src) + filtered_tgt.append(tgt) + filtered_src_len.append(src_len) + filtered_tgt_len.append(tgt_len) + + self.raw_src = filtered_src + self.raw_tgt = filtered_tgt + self.src_len = filtered_src_len + self.tgt_len = filtered_tgt_len + filtered_len = len(self.raw_src) + logging.info(f'Pairs before: {initial_len}, after: {filtered_len}') + + def __getitem__(self, idx): + src = torch.tensor(self.tokenizer.segment(self.raw_src[idx])) + tgt = torch.tensor(self.tokenizer.segment(self.raw_tgt[idx])) + return src, tgt + + def __len__(self): + return len(self.raw_src) diff --git a/rnn_translator/pytorch/seq2seq/data/sampler.py b/rnn_translator/pytorch/seq2seq/data/sampler.py index b7ae1d84a..ffcda0ead 100644 --- a/rnn_translator/pytorch/seq2seq/data/sampler.py +++ b/rnn_translator/pytorch/seq2seq/data/sampler.py @@ -1,14 +1,25 @@ -import torch -from torch.utils.data.sampler import Sampler +import logging +import torch from mlperf_compliance import mlperf_log +from torch.utils.data.sampler import Sampler -from seq2seq.utils import get_world_size, get_rank +from seq2seq.utils import get_rank +from seq2seq.utils import get_world_size +from seq2seq.utils import gnmt_print -class BucketingSampler(Sampler): 
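LazyParallelDataset above filters on raw whitespace token counts and stores lengths offset by 2, because BOS and EOS are only attached when a sample is tokenized in __getitem__. A tiny illustration of that bookkeeping (the sentence is made up):

    raw = 'ein Haus am See'
    raw_len = raw.count(' ') + 1        # 4 whitespace-delimited tokens
    stored_len = raw_len + 2            # + BOS and EOS added later by the tokenizer

    max_len = 50
    keeps_sample = raw_len <= max_len - 2   # filtering compares against max_len - 2
    print(raw_len, stored_len, keeps_sample)    # 4 6 True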
+class DistributedSampler(Sampler): + def __init__(self, dataset, batch_size, seeds, world_size=None, rank=None): + """ + Constructor for the DistributedSampler. - def __init__(self, dataset, batch_size, bucket=True, world_size=None, rank=None): + :param dataset: dataset + :param batch_size: local batch size + :param seeds: list of seeds, one seed for each training epoch + :param world_size: number of distributed workers + :param rank: rank of the current process + """ if world_size is None: world_size = get_world_size() if rank is None: @@ -18,71 +29,255 @@ def __init__(self, dataset, batch_size, bucket=True, world_size=None, rank=None) self.world_size = world_size self.rank = rank self.epoch = 0 - self.bucket = bucket + self.seeds = seeds self.batch_size = batch_size self.global_batch_size = batch_size * world_size self.data_len = len(self.dataset) + self.num_samples = self.data_len // self.global_batch_size \ * self.global_batch_size + def init_rng(self): + """ + Creates new RNG, seed depends on current epoch idx. + """ + rng = torch.Generator() + seed = self.seeds[self.epoch] + logging.info(f'Sampler for epoch {self.epoch} uses seed {seed}') + rng.manual_seed(seed) + return rng + + def distribute_batches(self, indices): + """ + Assigns batches to workers. + Consecutive ranks are getting consecutive batches. + + :param indices: torch.tensor with batch indices + """ + assert len(indices) == self.num_samples + + indices = indices.view(-1, self.batch_size) + indices = indices[self.rank::self.world_size].contiguous() + indices = indices.view(-1) + indices = indices.tolist() + + assert len(indices) == self.num_samples // self.world_size + return indices + + def reshuffle_batches(self, indices, rng): + """ + Permutes global batches + + :param indices: torch.tensor with batch indices + :param rng: instance of torch.Generator + """ + indices = indices.view(-1, self.global_batch_size) + num_batches = indices.shape[0] + order = torch.randperm(num_batches, generator=rng) + indices = indices[order, :] + indices = indices.view(-1) + return indices + def __iter__(self): - mlperf_log.gnmt_print(key=mlperf_log.INPUT_ORDER) + gnmt_print(key=mlperf_log.INPUT_ORDER, sync=False) + rng = self.init_rng() + # generate permutation + indices = torch.randperm(self.data_len, generator=rng) - # deterministically shuffle based on epoch - g = torch.Generator() - g.manual_seed(self.epoch) + # make indices evenly divisible by (batch_size * world_size) + indices = indices[:self.num_samples] + + # assign batches to workers + indices = self.distribute_batches(indices) + return iter(indices) + + def set_epoch(self, epoch): + """ + Sets current epoch index. + Epoch index is used to seed RNG in __iter__() function. + + :param epoch: index of current epoch + """ + self.epoch = epoch + def __len__(self): + return self.num_samples // self.world_size + + +class ShardingSampler(DistributedSampler): + def __init__(self, dataset, batch_size, seeds, shard_size, + world_size=None, rank=None): + """ + Constructor for the ShardingSampler. 
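distribute_batches() above hands consecutive local batches to consecutive ranks, unlike the strided assignment used by the stock PyTorch DistributedSampler (noted in the old code's comments). A self-contained sketch of that slicing with toy sizes:

    import torch

    world_size, batch_size = 2, 3
    indices = torch.arange(2 * world_size * batch_size)   # 12 samples, 4 local batches

    per_rank = []
    for rank in range(world_size):
        view = indices.view(-1, batch_size)                # each row is one local batch
        own = view[rank::world_size].contiguous().view(-1)
        per_rank.append(own.tolist())

    # batch 0 -> rank 0, batch 1 -> rank 1, batch 2 -> rank 0, batch 3 -> rank 1:
    # per_rank == [[0, 1, 2, 6, 7, 8], [3, 4, 5, 9, 10, 11]]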
+ + :param dataset: dataset + :param batch_size: local batch size + :param seeds: list of seeds, one seed for each training epoch + :param shard_size: number of global batches within one shard + :param world_size: number of distributed workers + :param rank: rank of the current process + """ + + super().__init__(dataset, batch_size, seeds, world_size, rank) + + self.shard_size = shard_size + self.num_samples = self.data_len // self.global_batch_size \ + * self.global_batch_size + + def __iter__(self): + gnmt_print(key=mlperf_log.INPUT_ORDER, sync=False) + rng = self.init_rng() # generate permutation - indices = torch.randperm(self.data_len, generator=g) + indices = torch.randperm(self.data_len, generator=rng) # make indices evenly divisible by (batch_size * world_size) indices = indices[:self.num_samples] + # splits the dataset into chunks of 'self.shard_size' global batches + # each, sorts by (src + tgt) sequence length within each chunk, + # reshuffles all global batches + shard_size = self.global_batch_size * self.shard_size + nshards = (self.num_samples + shard_size - 1) // shard_size - if self.bucket: - # begin shards - batches_in_shard = 80 - shard_size = self.global_batch_size * batches_in_shard - nshards = (self.num_samples + shard_size - 1) // shard_size + lengths = self.dataset.lengths[indices] - lengths = self.dataset.lengths[indices] + shards = [indices[i * shard_size:(i+1) * shard_size] for i in range(nshards)] + len_shards = [lengths[i * shard_size:(i+1) * shard_size] for i in range(nshards)] - shards = [indices[i * shard_size:(i+1) * shard_size] for i in range(nshards)] - len_shards = [lengths[i * shard_size:(i+1) * shard_size] for i in range(nshards)] + # sort by (src + tgt) sequence length within each shard + indices = [] + for len_shard in len_shards: + _, ind = len_shard.sort() + indices.append(ind) - indices = [] - for len_shard in len_shards: - _, ind = len_shard.sort() - indices.append(ind) + output = tuple(shard[idx] for shard, idx in zip(shards, indices)) - output = tuple(shard[idx] for shard,idx in zip(shards, indices)) - indices = torch.cat(output) - # global reshuffle - indices = indices.view(-1, self.global_batch_size) - order = torch.randperm(indices.shape[0], generator=g) - indices = indices[order, :] - indices = indices.view(-1) - # end shards + # build batches + indices = torch.cat(output) + # perform global reshuffle of all global batches + indices = self.reshuffle_batches(indices, rng) + # distribute batches to individual workers + indices = self.distribute_batches(indices) + return iter(indices) - assert len(indices) == self.num_samples +class BucketingSampler(DistributedSampler): + def __init__(self, dataset, batch_size, seeds, num_buckets, + world_size=None, rank=None): + """ + Constructor for the BucketingSampler. 
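The ShardingSampler above shuffles globally, then sorts by (src + tgt) length only inside fixed-size shards, so each global batch contains sequences of similar length while the shards themselves stay randomly composed. A compact sketch of that shard-then-sort step with toy lengths:

    import torch

    torch.manual_seed(0)
    lengths = torch.randint(5, 50, (32,))     # toy (src + tgt) lengths
    global_batch_size, shard_size = 4, 2      # one shard = 2 global batches = 8 samples

    perm = torch.randperm(len(lengths))       # global shuffle
    samples_per_shard = global_batch_size * shard_size
    pieces = []
    for start in range(0, len(perm), samples_per_shard):
        shard = perm[start:start + samples_per_shard]
        order = lengths[shard].sort()[1]      # sort by length inside this shard only
        pieces.append(shard[order])
    indices = torch.cat(pieces)

    # consecutive groups of global_batch_size indices now have similar lengths,
    # which reduces padding without sorting the whole epoch deterministically
    batches = indices.view(-1, global_batch_size)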
- # build indices for each individual worker - # ranks are getting consecutive batches, - # default pytorch DistributedSampler assigns strided batches - # with offset = length / world_size - indices = indices.view(-1, self.batch_size) - indices = indices[self.rank::self.world_size].contiguous() - indices = indices.view(-1) - indices = indices.tolist() + :param dataset: dataset + :param batch_size: local batch size + :param seeds: list of seeds, one seed for each training epoch + :param num_buckets: number of buckets + :param world_size: number of distributed workers + :param rank: rank of the current process + """ - assert len(indices) == self.num_samples // self.world_size + super().__init__(dataset, batch_size, seeds, world_size, rank) + + self.num_buckets = num_buckets + bucket_width = (dataset.max_len + num_buckets - 1) // num_buckets + + # assign sentences to buckets based on src and tgt sequence lengths + bucket_ids = torch.max(dataset.src_lengths // bucket_width, + dataset.tgt_lengths // bucket_width) + bucket_ids.clamp_(0, num_buckets - 1) + + # build buckets + all_indices = torch.tensor(range(self.data_len)) + self.buckets = [] + self.num_samples = 0 + global_bs = self.global_batch_size + for bid in range(num_buckets): + # gather indices for current bucket + indices = all_indices[bucket_ids == bid] + self.buckets.append(indices) + + # count number of samples in current bucket + samples = len(indices) // global_bs * global_bs + self.num_samples += samples + + def __iter__(self): + gnmt_print(key=mlperf_log.INPUT_ORDER, sync=False) + rng = self.init_rng() + global_bs = self.global_batch_size + + indices = [] + for bid in range(self.num_buckets): + # random shuffle within current bucket + perm = torch.randperm(len(self.buckets[bid]), generator=rng) + bucket_indices = self.buckets[bid][perm] + + # make bucket_indices evenly divisible by global batch size + length = len(bucket_indices) // global_bs * global_bs + bucket_indices = bucket_indices[:length] + assert len(bucket_indices) % self.global_batch_size == 0 + + # add samples from current bucket to indices for current epoch + indices.append(bucket_indices) + + indices = torch.cat(indices) + assert len(indices) % self.global_batch_size == 0 + + # perform global reshuffle of all global batches + indices = self.reshuffle_batches(indices, rng) + # distribute batches to individual workers + indices = self.distribute_batches(indices) return iter(indices) - def __len__(self): - return self.num_samples // self.world_size - def set_epoch(self, epoch): - self.epoch = epoch +class StaticDistributedSampler(Sampler): + def __init__(self, dataset, batch_size, pad, world_size=None, rank=None): + """ + Constructor for the StaticDistributedSampler. 
+ + :param dataset: dataset + :param batch_size: local batch size + :param pad: if True: pads dataset to a multiple of global_batch_size + samples + :param world_size: number of distributed workers + :param rank: rank of the current process + """ + if world_size is None: + world_size = get_world_size() + if rank is None: + rank = get_rank() + + self.world_size = world_size + + global_batch_size = batch_size * world_size + + gnmt_print(key=mlperf_log.INPUT_ORDER, sync=False) + data_len = len(dataset) + num_samples = (data_len + global_batch_size - 1) \ + // global_batch_size * global_batch_size + self.num_samples = num_samples + + indices = list(range(data_len)) + if pad: + # pad dataset to a multiple of global_batch_size samples, uses + # sample with idx 0 as pad + indices += [0] * (num_samples - len(indices)) + else: + # temporary pad to a multiple of global batch size, pads with "-1" + # which is later removed from the list of indices + indices += [-1] * (num_samples - len(indices)) + indices = torch.tensor(indices) + + indices = indices.view(-1, batch_size) + indices = indices[rank::world_size].contiguous() + indices = indices.view(-1) + # remove temporary pad + indices = indices[indices != -1] + indices = indices.tolist() + self.indices = indices + + def __iter__(self): + return iter(self.indices) + + def __len__(self): + return len(self.indices) diff --git a/rnn_translator/pytorch/seq2seq/data/tokenizer.py b/rnn_translator/pytorch/seq2seq/data/tokenizer.py index b197a1af7..b2d954996 100644 --- a/rnn_translator/pytorch/seq2seq/data/tokenizer.py +++ b/rnn_translator/pytorch/seq2seq/data/tokenizer.py @@ -1,44 +1,105 @@ import logging from collections import defaultdict +from functools import partial import seq2seq.data.config as config -def default(): - return config.UNK class Tokenizer: - def __init__(self, vocab_fname, separator='@@'): + """ + Tokenizer class. + """ + def __init__(self, vocab_fname=None, pad=1, separator='@@'): + """ + Constructor for the Tokenizer class. - self.separator = separator + :param vocab_fname: path to the file with vocabulary + :param pad: pads vocabulary to a multiple of 'pad' tokens + :param separator: tokenization separator + """ + if vocab_fname: + self.separator = separator - logging.info(f'building vocabulary from {vocab_fname}') - vocab = [config.PAD_TOKEN, config.UNK_TOKEN, - config.BOS_TOKEN, config.EOS_TOKEN] + logging.info(f'Building vocabulary from {vocab_fname}') + vocab = [config.PAD_TOKEN, config.UNK_TOKEN, + config.BOS_TOKEN, config.EOS_TOKEN] - with open(vocab_fname) as vfile: - for line in vfile: - vocab.append(line.strip()) + with open(vocab_fname) as vfile: + for line in vfile: + vocab.append(line.strip()) - logging.info(f'size of vocabulary: {len(vocab)}') - self.vocab_size = len(vocab) + self.pad_vocabulary(vocab, pad) + self.vocab_size = len(vocab) + logging.info(f'Size of vocabulary: {self.vocab_size}') - self.tok2idx = defaultdict(default) - for idx, token in enumerate(vocab): - self.tok2idx[token] = idx + self.tok2idx = defaultdict(partial(int, config.UNK)) + for idx, token in enumerate(vocab): + self.tok2idx[token] = idx - self.idx2tok = {} - for key, value in self.tok2idx.items(): - self.idx2tok[value] = key + self.idx2tok = {} + for key, value in self.tok2idx.items(): + self.idx2tok[value] = key + + def pad_vocabulary(self, vocab, pad): + """ + Pads vocabulary to a multiple of 'pad' tokens. 
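The BucketingSampler defined above assigns every sentence pair to one of num_buckets buckets using the larger of its src/tgt length, clamping overlong samples into the last bucket. A minimal sketch of that bucket-id computation with toy values:

    import torch

    max_len, num_buckets = 50, 5
    bucket_width = (max_len + num_buckets - 1) // num_buckets    # 10 tokens per bucket

    src_lengths = torch.tensor([3, 12, 27, 49, 80])   # toy lengths, last one overlong
    tgt_lengths = torch.tensor([5, 18, 21, 44, 70])

    bucket_ids = torch.max(src_lengths // bucket_width,
                           tgt_lengths // bucket_width)
    bucket_ids.clamp_(0, num_buckets - 1)
    # bucket_ids == tensor([0, 1, 2, 4, 4]); the overlong pair lands in the last bucket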
+ + :param vocab: list with vocabulary + :param pad: integer + """ + vocab_size = len(vocab) + padded_vocab_size = (vocab_size + pad - 1) // pad * pad + for i in range(0, padded_vocab_size - vocab_size): + token = f'madeupword{i:04d}' + vocab.append(token) + assert len(vocab) % pad == 0 + + def get_state(self): + logging.info(f'Saving state of the tokenizer') + state = { + 'separator': self.separator, + 'vocab_size': self.vocab_size, + 'tok2idx': self.tok2idx, + 'idx2tok': self.idx2tok, + } + return state + + def set_state(self, state): + logging.info(f'Restoring state of the tokenizer') + self.separator = state['separator'] + self.vocab_size = state['vocab_size'] + self.tok2idx = state['tok2idx'] + self.idx2tok = state['idx2tok'] def segment(self, line): + """ + Tokenizes single sentence and adds special BOS and EOS tokens. + + :param line: sentence + + returns: list representing tokenized sentence + """ line = line.strip().split() entry = [self.tok2idx[i] for i in line] entry = [config.BOS] + entry + [config.EOS] return entry def detokenize(self, inputs, delim=' '): + """ + Detokenizes single sentence and removes token separator characters. + + :param inputs: sequence of tokens + :param delim: tokenization delimiter + + returns: string representing detokenized sentence + """ detok = delim.join([self.idx2tok[idx] for idx in inputs]) - detok = detok.replace( - self.separator+ ' ', '').replace(self.separator, '') + detok = detok.replace(self.separator + ' ', '') + detok = detok.replace(self.separator, '') + + detok = detok.replace(config.BOS_TOKEN, '') + detok = detok.replace(config.EOS_TOKEN, '') + detok = detok.replace(config.PAD_TOKEN, '') + detok = detok.strip() return detok diff --git a/rnn_translator/pytorch/seq2seq/inference/beam_search.py b/rnn_translator/pytorch/seq2seq/inference/beam_search.py index d5cabacf5..661dd3ab2 100644 --- a/rnn_translator/pytorch/seq2seq/inference/beam_search.py +++ b/rnn_translator/pytorch/seq2seq/inference/beam_search.py @@ -1,20 +1,33 @@ import torch - from mlperf_compliance import mlperf_log from seq2seq.data.config import BOS from seq2seq.data.config import EOS +from seq2seq.utils import gnmt_print -class SequenceGenerator(object): - def __init__(self, - model, - beam_size=5, - max_seq_len=100, - cuda=False, - len_norm_factor=0.6, - len_norm_const=5, +class SequenceGenerator: + """ + Generator for the autoregressive inference with beam search decoding. + """ + def __init__(self, model, beam_size=5, max_seq_len=100, cuda=False, + len_norm_factor=0.6, len_norm_const=5, cov_penalty_factor=0.1): + """ + Constructor for the SequenceGenerator. + + Beam search decoding supports coverage penalty and length + normalization. For details, refer to Section 7 of the GNMT paper + (https://arxiv.org/pdf/1609.08144.pdf). 
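pad_vocabulary() above rounds the vocabulary size up to a multiple of 'pad' by appending filler tokens; the previous Classifier implementation elsewhere in this patch padded its output to a multiple of 8 for the same reason (fp16 HMMA kernels). The rounding is plain ceiling division, e.g. with made-up numbers:

    vocab_size, pad = 31794, 8
    padded = (vocab_size + pad - 1) // pad * pad    # 31800, the next multiple of 8
    fillers = padded - vocab_size                   # 6 'madeupwordNNNN' entries appended
    assert padded % pad == 0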
+ + :param model: model which implements generate method + :param beam_size: decoder beam size + :param max_seq_len: maximum decoder sequence length + :param cuda: whether to use cuda + :param len_norm_factor: length normalization factor + :param len_norm_const: length normalization constant + :param cov_penalty_factor: coverage penalty factor + """ self.model = model self.cuda = cuda @@ -26,18 +39,31 @@ def __init__(self, self.batch_first = self.model.batch_first - mlperf_log.gnmt_print(key=mlperf_log.EVAL_HP_BEAM_SIZE, - value=self.beam_size) - mlperf_log.gnmt_print(key=mlperf_log.EVAL_HP_MAX_SEQ_LEN, - value=self.max_seq_len) - mlperf_log.gnmt_print(key=mlperf_log.EVAL_HP_LEN_NORM_CONST, - value=self.len_norm_const) - mlperf_log.gnmt_print(key=mlperf_log.EVAL_HP_LEN_NORM_FACTOR, - value=self.len_norm_factor) - mlperf_log.gnmt_print(key=mlperf_log.EVAL_HP_COV_PENALTY_FACTOR, - value=self.cov_penalty_factor) + gnmt_print(key=mlperf_log.EVAL_HP_BEAM_SIZE, + value=self.beam_size, sync=False) + gnmt_print(key=mlperf_log.EVAL_HP_MAX_SEQ_LEN, + value=self.max_seq_len, sync=False) + gnmt_print(key=mlperf_log.EVAL_HP_LEN_NORM_CONST, + value=self.len_norm_const, sync=False) + gnmt_print(key=mlperf_log.EVAL_HP_LEN_NORM_FACTOR, + value=self.len_norm_factor, sync=False) + gnmt_print(key=mlperf_log.EVAL_HP_COV_PENALTY_FACTOR, + value=self.cov_penalty_factor, sync=False) def greedy_search(self, batch_size, initial_input, initial_context=None): + """ + Greedy decoder. + + :param batch_size: decoder batch size + :param initial_input: initial input, usually tensor of BOS tokens + :param initial_context: initial context, usually [encoder_context, + src_seq_lengths, None] + + returns: (translation, lengths, counter) + translation: (batch_size, max_seq_len) - indices of target tokens + lengths: (batch_size) - lengths of generated translations + counter: number of iterations of the decoding loop + """ max_seq_len = self.max_seq_len translation = torch.zeros(batch_size, max_seq_len, dtype=torch.int64) @@ -68,7 +94,8 @@ def greedy_search(self, batch_size, initial_input, initial_context=None): counter += 1 words = words.view(word_view) - words, logprobs, attn, context = self.model.generate(words, context, 1) + output = self.model.generate(words, context, 1) + words, logprobs, attn, context = output words = words.view(-1) translation[active, idx] = words @@ -91,19 +118,34 @@ def greedy_search(self, batch_size, initial_input, initial_context=None): return translation, lengths, counter def beam_search(self, batch_size, initial_input, initial_context=None): + """ + Beam search decoder. 
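greedy_search() above repeatedly feeds the decoder its own previous output and drops sequences from the active set once they emit EOS, until max_seq_len is reached. A stripped-down sketch of that loop; fake_generate below is a stand-in for model.generate(), not the patch's model:

    import torch

    BOS, EOS, MAX_LEN, BATCH = 2, 3, 8, 4

    def fake_generate(words):
        # stand-in: returns one random "top-1" token per active sequence
        return torch.randint(3, 10, (words.size(0), 1))

    words = torch.full((BATCH, 1), BOS, dtype=torch.int64)
    translation = torch.zeros(BATCH, MAX_LEN, dtype=torch.int64)
    lengths = torch.ones(BATCH, dtype=torch.int64)
    active = torch.arange(BATCH)

    for idx in range(MAX_LEN):
        words = fake_generate(words).view(-1)
        translation[active, idx] = words
        lengths[active] += 1
        running = words != EOS            # keep only sequences that have not finished
        active = active[running]
        if len(active) == 0:
            break
        words = words[running].view(-1, 1)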
+ + :param batch_size: decoder batch size + :param initial_input: initial input, usually tensor of BOS tokens + :param initial_context: initial context, usually [encoder_context, + src_seq_lengths, None] + + returns: (translation, lengths, counter) + translation: (batch_size, max_seq_len) - indices of target tokens + lengths: (batch_size) - lengths of generated translations + counter: number of iterations of the decoding loop + """ beam_size = self.beam_size norm_const = self.len_norm_const norm_factor = self.len_norm_factor max_seq_len = self.max_seq_len cov_penalty_factor = self.cov_penalty_factor - translation = torch.zeros(batch_size * beam_size, max_seq_len, dtype=torch.int64) + translation = torch.zeros(batch_size * beam_size, max_seq_len, + dtype=torch.int64) lengths = torch.ones(batch_size * beam_size, dtype=torch.int64) scores = torch.zeros(batch_size * beam_size, dtype=torch.float32) active = torch.arange(0, batch_size * beam_size, dtype=torch.int64) base_mask = torch.arange(0, batch_size * beam_size, dtype=torch.int64) - global_offset = torch.arange(0, batch_size * beam_size, beam_size, dtype=torch.int64) + global_offset = torch.arange(0, batch_size * beam_size, beam_size, + dtype=torch.int64) eos_beam_fill = torch.tensor([0] + (beam_size - 1) * [float('-inf')]) @@ -135,21 +177,23 @@ def beam_search(self, batch_size, initial_input, initial_context=None): _, seq, feature = context[0].shape context[0].unsqueeze_(1) context[0] = context[0].expand(-1, beam_size, -1, -1) - context[0] = context[0].contiguous().view(batch_size * beam_size, seq, feature) + context[0] = context[0].contiguous().view(batch_size * beam_size, + seq, feature) # context[0]: (batch * beam, seq, feature) else: # context[0] (encoder state): (seq, batch, feature) seq, _, feature = context[0].shape context[0].unsqueeze_(2) context[0] = context[0].expand(-1, -1, beam_size, -1) - context[0] = context[0].contiguous().view(seq, batch_size * beam_size, feature) + context[0] = context[0].contiguous().view(seq, batch_size * + beam_size, feature) # context[0]: (seq, batch * beam, feature) - #context[1] (encoder seq length): (batch) + # context[1] (encoder seq length): (batch) context[1].unsqueeze_(1) context[1] = context[1].expand(-1, beam_size) context[1] = context[1].contiguous().view(batch_size * beam_size) - #context[1]: (batch * beam) + # context[1]: (batch * beam) accu_attn_scores = torch.zeros(batch_size * beam_size, seq) if self.cuda: @@ -168,7 +212,8 @@ def beam_search(self, batch_size, initial_input, initial_context=None): lengths[active[~eos_mask.view(-1)]] += 1 - words, logprobs, attn, context = self.model.generate(words, context, beam_size) + output = self.model.generate(words, context, beam_size) + words, logprobs, attn, context = output attn = attn.float().squeeze(attn_query_dim) attn = attn.masked_fill(eos_mask.view(-1).unsqueeze(1), 0) diff --git a/rnn_translator/pytorch/seq2seq/inference/inference.py b/rnn_translator/pytorch/seq2seq/inference/inference.py index e0ddbd93f..5ec3a4bff 100644 --- a/rnn_translator/pytorch/seq2seq/inference/inference.py +++ b/rnn_translator/pytorch/seq2seq/inference/inference.py @@ -1,29 +1,64 @@ +import contextlib +import logging +import os +import subprocess +import time + import torch +import torch.distributed as dist -from seq2seq.data.config import BOS -from seq2seq.data.config import EOS +import seq2seq.data.config as config from seq2seq.inference.beam_search import SequenceGenerator -from seq2seq.utils import batch_padded_sequences +from seq2seq.utils import AverageMeter 
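The hypothesis scores in beam_search() above are adjusted with GNMT-style length normalization (plus a coverage penalty over attention scores). With the constants logged by this reference (len_norm_const=5.0, len_norm_factor=0.6) the length penalty from Section 7 of the GNMT paper is ((5 + len) / (5 + 1)) ** 0.6; a small numeric illustration with a made-up log-probability:

    def length_penalty(length, norm_const=5.0, norm_factor=0.6):
        # lp(Y) from https://arxiv.org/pdf/1609.08144.pdf, Section 7
        return ((norm_const + length) / (norm_const + 1.0)) ** norm_factor

    log_prob = -6.0                         # made-up summed log-probability
    for length in (5, 10, 20):
        print(length, round(log_prob / length_penalty(length), 3))
    # dividing by a larger factor for longer hypotheses makes scores of different
    # lengths comparable, so beam search does not always prefer short outputs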
+from seq2seq.utils import barrier +from seq2seq.utils import get_rank +from seq2seq.utils import get_world_size + +def gather_predictions(preds): + world_size = get_world_size() + if world_size > 1: + all_preds = [preds.new(preds.size(0), preds.size(1)) for i in range(world_size)] + dist.all_gather(all_preds, preds) + preds = torch.cat(all_preds) + return preds -class Translator(object): - def __init__(self, model, tok, +class Translator: + """ + Translator runs validation on test dataset, executes inference, optionally + computes BLEU score using sacrebleu. + """ + def __init__(self, + model, + tokenizer, + loader, beam_size=5, len_norm_factor=0.6, len_norm_const=5.0, cov_penalty_factor=0.1, max_seq_len=50, - cuda=False): + cuda=False, + print_freq=1, + dataset_dir=None, + save_path=None, + target_bleu=None): self.model = model - self.tok = tok - self.insert_target_start = [BOS] - self.insert_src_start = [BOS] - self.insert_src_end = [EOS] + self.tokenizer = tokenizer + self.loader = loader + self.insert_target_start = [config.BOS] + self.insert_src_start = [config.BOS] + self.insert_src_end = [config.EOS] self.batch_first = model.batch_first self.cuda = cuda self.beam_size = beam_size + self.print_freq = print_freq + self.dataset_dir = dataset_dir + self.target_bleu = target_bleu + self.save_path = save_path + + self.distributed = (get_world_size() > 1) self.generator = SequenceGenerator( model=self.model, @@ -34,55 +69,221 @@ def __init__(self, model, tok, len_norm_const=len_norm_const, cov_penalty_factor=cov_penalty_factor) - def translate(self, input_sentences): - stats = {} - batch_size = len(input_sentences) - beam_size = self.beam_size + def build_eval_path(self, epoch, iteration): + """ + Appends index of the current epoch and index of the current iteration + to the name of the file with results. + + :param epoch: index of the current epoch + :param iteration: index of the current iteration + """ + if iteration is not None: + eval_fname = f'eval_epoch_{epoch}_iter_{iteration}' + else: + eval_fname = f'eval_epoch_{epoch}' + eval_path = os.path.join(self.save_path, eval_fname) + return eval_path - src_tok = [torch.tensor(self.tok.segment(line)) for line in input_sentences] + def run(self, calc_bleu=True, epoch=None, iteration=None, eval_path=None, + summary=False, reference_path=None): + """ + Runs translation on test dataset. 
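gather_predictions() above uses torch.distributed.all_gather so that every worker ends up with the full set of translations before rank 0 writes them out and scores BLEU. A minimal sketch of the same pattern with a single-process fallback; initializing the process group is assumed to happen elsewhere:

    import torch
    import torch.distributed as dist

    def gather_rows(local: torch.Tensor) -> torch.Tensor:
        # concatenate equally-shaped tensors from all ranks along dim 0
        if not (dist.is_available() and dist.is_initialized()):
            return local                                   # single-process run
        buffers = [torch.empty_like(local) for _ in range(dist.get_world_size())]
        dist.all_gather(buffers, local)                    # every rank gets every shard
        return torch.cat(buffers, dim=0)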
- bos = [self.insert_target_start] * (batch_size * beam_size) - bos = torch.LongTensor(bos) - if self.batch_first: - bos = bos.view(-1, 1) + :param calc_bleu: if True compares results with reference and computes + BLEU score + :param epoch: index of the current epoch + :param iteration: index of the current iteration + :param eval_path: path to the file for saving results + :param summary: if True prints summary + :param reference_path: path to the file with reference translation + """ + if self.cuda: + test_bleu = torch.cuda.FloatTensor([0]) + break_training = torch.cuda.LongTensor([0]) else: - bos = bos.view(1, -1) + test_bleu = torch.FloatTensor([0]) + break_training = torch.LongTensor([0]) - src = batch_padded_sequences(src_tok, self.batch_first, sort=True) - src, src_length, indices = src + if eval_path is None: + eval_path = self.build_eval_path(epoch, iteration) + detok_eval_path = eval_path + '.detok' - src_length = torch.LongTensor(src_length) - stats['total_enc_len'] = int(src_length.sum()) + with contextlib.suppress(FileNotFoundError): + os.remove(eval_path) + os.remove(detok_eval_path) - if self.cuda: - src = src.cuda() - src_length = src_length.cuda() - bos = bos.cuda() + rank = get_rank() + logging.info(f'Running evaluation on test set') + self.model.eval() + torch.cuda.empty_cache() - with torch.no_grad(): - context = self.model.encode(src, src_length) - context = [context, src_length, None] + output = self.evaluate(epoch, iteration, summary) + output = output[:len(self.loader.dataset)] + output = self.loader.dataset.unsort(output) - if beam_size == 1: - generator = self.generator.greedy_search - else: - generator = self.generator.beam_search + if rank == 0: + with open(eval_path, 'a') as eval_file: + eval_file.writelines(output) + if calc_bleu: + self.run_detokenizer(eval_path) + test_bleu[0] = self.run_sacrebleu(detok_eval_path, reference_path) + if summary: + logging.info(f'BLEU on test dataset: {test_bleu[0]:.2f}') + + if self.target_bleu and test_bleu[0] >= self.target_bleu: + logging.info(f'Target accuracy reached') + break_training[0] = 1 - preds, lengths, counter = generator(batch_size, bos, context) + barrier() + torch.cuda.empty_cache() + logging.info(f'Finished evaluation on test set') - preds = preds.cpu() - lengths = lengths.cpu() + if self.distributed: + dist.broadcast(break_training, 0) + dist.broadcast(test_bleu, 0) + + return test_bleu[0].item(), break_training[0].item() + + def evaluate(self, epoch, iteration, summary): + """ + Runs evaluation on test dataset. 
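run() above lets rank 0 decide whether the target BLEU was reached and then broadcasts that flag, so all workers leave the training loop together. A sketch of that control-flow pattern; process-group setup is assumed to exist:

    import torch
    import torch.distributed as dist

    def should_stop(metric: float, target: float, device='cpu') -> bool:
        distributed = dist.is_available() and dist.is_initialized()
        flag = torch.zeros(1, dtype=torch.int64, device=device)
        if not distributed or dist.get_rank() == 0:
            flag[0] = int(metric >= target)      # only rank 0 evaluates the criterion
        if distributed:
            dist.broadcast(flag, src=0)          # everyone adopts rank 0's decision
        return bool(flag.item())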
+ + :param epoch: index of the current epoch + :param iteration: index of the current iteration + :param summary: if True prints summary + """ + batch_time = AverageMeter(False) + tot_tok_per_sec = AverageMeter(False) + iterations = AverageMeter(False) + enc_seq_len = AverageMeter(False) + dec_seq_len = AverageMeter(False) + stats = {} output = [] - for idx, pred in enumerate(preds): - end = lengths[idx] - 1 - pred = pred[1: end] - pred = pred.tolist() - out = self.tok.detokenize(pred) - output.append(out) - - stats['total_dec_len'] = int(lengths.sum()) - stats['iters'] = counter - - output = [output[indices.index(i)] for i in range(len(output))] - return output, stats + + for i, (src, indices) in enumerate(self.loader): + translate_timer = time.time() + src, src_length = src + + batch_size = self.loader.batch_size + global_batch_size = batch_size * get_world_size() + beam_size = self.beam_size + + bos = [self.insert_target_start] * (batch_size * beam_size) + bos = torch.LongTensor(bos) + if self.batch_first: + bos = bos.view(-1, 1) + else: + bos = bos.view(1, -1) + + src_length = torch.LongTensor(src_length) + stats['total_enc_len'] = int(src_length.sum()) + + if self.cuda: + src = src.cuda() + src_length = src_length.cuda() + bos = bos.cuda() + + with torch.no_grad(): + context = self.model.encode(src, src_length) + context = [context, src_length, None] + + if beam_size == 1: + generator = self.generator.greedy_search + else: + generator = self.generator.beam_search + preds, lengths, counter = generator(batch_size, bos, context) + + stats['total_dec_len'] = lengths.sum().item() + stats['iters'] = counter + + indices = torch.tensor(indices).to(preds) + preds = preds.scatter(0, indices.unsqueeze(1).expand_as(preds), preds) + + preds = gather_predictions(preds).cpu() + + for pred in preds: + pred = pred.tolist() + detok = self.tokenizer.detokenize(pred) + output.append(detok + '\n') + + elapsed = time.time() - translate_timer + batch_time.update(elapsed, batch_size) + + total_tokens = stats['total_dec_len'] + stats['total_enc_len'] + ttps = total_tokens / elapsed + tot_tok_per_sec.update(ttps, batch_size) + + iterations.update(stats['iters']) + enc_seq_len.update(stats['total_enc_len'] / batch_size, batch_size) + dec_seq_len.update(stats['total_dec_len'] / batch_size, batch_size) + + if i % self.print_freq == 0: + log = [] + log += f'TEST ' + if epoch is not None: + log += f'[{epoch}]' + if iteration is not None: + log += f'[{iteration}]' + log += f'[{i}/{len(self.loader)}]\t' + log += f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + log += f'Decoder iters {iterations.val:.1f} ({iterations.avg:.1f})\t' + log += f'Tok/s {tot_tok_per_sec.val:.0f} ({tot_tok_per_sec.avg:.0f})' + log = ''.join(log) + logging.info(log) + + tot_tok_per_sec.reduce('sum') + enc_seq_len.reduce('mean') + dec_seq_len.reduce('mean') + batch_time.reduce('mean') + iterations.reduce('sum') + + if summary and get_rank() == 0: + time_per_sentence = (batch_time.avg / global_batch_size) + log = [] + log += f'TEST SUMMARY:\n' + log += f'Lines translated: {len(self.loader.dataset)}\t' + log += f'Avg total tokens/s: {tot_tok_per_sec.avg:.0f}\n' + log += f'Avg time per batch: {batch_time.avg:.3f} s\t' + log += f'Avg time per sentence: {1000*time_per_sentence:.3f} ms\n' + log += f'Avg encoder seq len: {enc_seq_len.avg:.2f}\t' + log += f'Avg decoder seq len: {dec_seq_len.avg:.2f}\t' + log += f'Total decoder iterations: {int(iterations.sum)}' + log = ''.join(log) + logging.info(log) + + return output + + def 
run_detokenizer(self, eval_path): + """ + Executes moses detokenizer on eval_path file and saves result to + eval_path + ".detok" file. + + :param eval_path: path to the tokenized input + """ + logging.info('Running detokenizer') + detok_path = os.path.join(self.dataset_dir, config.DETOKENIZER) + detok_eval_path = eval_path + '.detok' + + with open(detok_eval_path, 'w') as detok_eval_file, \ + open(eval_path, 'r') as eval_file: + subprocess.run(['perl', f'{detok_path}'], stdin=eval_file, + stdout=detok_eval_file, stderr=subprocess.DEVNULL) + + def run_sacrebleu(self, detok_eval_path, reference_path): + """ + Executes sacrebleu and returns BLEU score. + + :param detok_eval_path: path to the test file + :param reference_path: path to the reference file + """ + if reference_path is None: + reference_path = os.path.join(self.dataset_dir, + config.TGT_TEST_TARGET_FNAME) + sacrebleu_params = '--score-only -lc --tokenize intl' + logging.info(f'Running sacrebleu (parameters: {sacrebleu_params})') + sacrebleu = subprocess.run([f'sacrebleu --input {detok_eval_path} \ + {reference_path} {sacrebleu_params}'], + stdout=subprocess.PIPE, shell=True) + test_bleu = float(sacrebleu.stdout.strip()) + return test_bleu diff --git a/rnn_translator/pytorch/seq2seq/models/__init__.py b/rnn_translator/pytorch/seq2seq/models/__init__.py deleted file mode 100644 index 6217b239e..000000000 --- a/rnn_translator/pytorch/seq2seq/models/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .seq2seq_base import Seq2Seq -from .gnmt import GNMT, ResidualRecurrentDecoder, ResidualRecurrentEncoder - -__all__ = ['GNMT'] diff --git a/rnn_translator/pytorch/seq2seq/models/attention.py b/rnn_translator/pytorch/seq2seq/models/attention.py index 7ead49fee..230f1bca3 100644 --- a/rnn_translator/pytorch/seq2seq/models/attention.py +++ b/rnn_translator/pytorch/seq2seq/models/attention.py @@ -8,12 +8,23 @@ class BahdanauAttention(nn.Module): """ - It should be very similar to tf.contrib.seq2seq.BahdanauAttention + Bahdanau Attention (https://arxiv.org/abs/1409.0473) + Implementation is very similar to tf.contrib.seq2seq.BahdanauAttention """ - def __init__(self, query_size, key_size, num_units, normalize=False, - dropout=0, batch_first=False): - + batch_first=False, init_weight=0.1): + """ + Constructor for the BahdanauAttention. 
+ + :param query_size: feature dimension for query + :param key_size: feature dimension for keys + :param num_units: internal feature dimension + :param normalize: whether to normalize energy term + :param batch_first: if True batch size is the 1st dimension, if False + the sequence is first and batch size is second + :param init_weight: range for uniform initializer used to initialize + Linear key and query transform layers and linear_att vector + """ super(BahdanauAttention, self).__init__() self.normalize = normalize @@ -22,10 +33,11 @@ def __init__(self, query_size, key_size, num_units, normalize=False, self.linear_q = nn.Linear(query_size, num_units, bias=False) self.linear_k = nn.Linear(key_size, num_units, bias=False) + nn.init.uniform_(self.linear_q.weight.data, -init_weight, init_weight) + nn.init.uniform_(self.linear_k.weight.data, -init_weight, init_weight) self.linear_att = Parameter(torch.Tensor(num_units)) - self.dropout = nn.Dropout(dropout) self.mask = None if self.normalize: @@ -35,11 +47,14 @@ def __init__(self, query_size, key_size, num_units, normalize=False, self.register_parameter('normalize_scalar', None) self.register_parameter('normalize_bias', None) - self.reset_parameters() + self.reset_parameters(init_weight) - def reset_parameters(self): + def reset_parameters(self, init_weight): + """ + Sets initial random values for trainable parameters. + """ stdv = 1. / math.sqrt(self.num_units) - self.linear_att.data.uniform_(-stdv, stdv) + self.linear_att.data.uniform_(-init_weight, init_weight) if self.normalize: self.normalize_scalar.data.fill_(stdv) @@ -61,7 +76,8 @@ def set_mask(self, context_len, context): else: max_len = context.size(0) - indices = torch.arange(0, max_len, dtype=torch.int64, device=context.device) + indices = torch.arange(0, max_len, dtype=torch.int64, + device=context.device) self.mask = indices >= (context_len.unsqueeze(1)) def calc_score(self, att_query, att_keys): @@ -71,7 +87,7 @@ def calc_score(self, att_query, att_keys): :param att_query: b x t_q x n :param att_keys: b x t_k x n - return b x t_q x t_k scores + returns: b x t_q x t_k scores """ b, t_k, n = att_keys.size() @@ -83,16 +99,12 @@ def calc_score(self, att_query, att_keys): if self.normalize: sum_qk = sum_qk + self.normalize_bias - - tmp = self.linear_att.to(torch.float32) - linear_att = tmp / tmp.norm() - linear_att = linear_att.to(self.normalize_scalar) - + linear_att = self.linear_att / self.linear_att.norm() linear_att = linear_att * self.normalize_scalar else: linear_att = self.linear_att - out = F.tanh(sum_qk).matmul(linear_att) + out = torch.tanh(sum_qk).matmul(linear_att) return out def forward(self, query, keys): @@ -124,7 +136,6 @@ def forward(self, query, keys): # FC layers to transform query and key processed_query = self.linear_q(query) - # TODO move this out of decoder for efficiency during inference processed_key = self.linear_k(keys) # scores: (b x t_q x t_k) @@ -132,7 +143,7 @@ def forward(self, query, keys): if self.mask is not None: mask = self.mask.unsqueeze(1).expand(b, t_q, t_k) - # TODO I can't use -INF because of overflow check in pytorch + # I can't use -INF because of overflow check in pytorch scores.data.masked_fill_(mask, -65504.0) # Normalize the scores, softmax over t_k @@ -140,7 +151,6 @@ def forward(self, query, keys): # Calculate the weighted average of the attention inputs according to # the scores - scores_normalized = self.dropout(scores_normalized) # context: (b x t_q x n) context = torch.bmm(scores_normalized, keys) diff --git 
a/rnn_translator/pytorch/seq2seq/models/decoder.py b/rnn_translator/pytorch/seq2seq/models/decoder.py index ccfc3d00d..7e22a1c75 100644 --- a/rnn_translator/pytorch/seq2seq/models/decoder.py +++ b/rnn_translator/pytorch/seq2seq/models/decoder.py @@ -3,19 +3,36 @@ import torch import torch.nn as nn -from seq2seq.models.attention import BahdanauAttention import seq2seq.data.config as config +from seq2seq.models.attention import BahdanauAttention +from seq2seq.utils import init_lstm_ class RecurrentAttention(nn.Module): - - def __init__(self, input_size, context_size, hidden_size, num_layers=1, - bias=True, batch_first=False, dropout=0): + """ + LSTM wrapped with an attention module. + """ + def __init__(self, input_size=1024, context_size=1024, hidden_size=1024, + num_layers=1, batch_first=False, dropout=0.2, + init_weight=0.1): + """ + Constructor for the RecurrentAttention. + + :param input_size: number of features in input tensor + :param context_size: number of features in output from encoder + :param hidden_size: internal hidden size + :param num_layers: number of layers in LSTM + :param batch_first: if True the model uses (batch,seq,feature) tensors, + if false the model uses (seq, batch, feature) + :param dropout: probability of dropout (on input to LSTM layer) + :param init_weight: range for the uniform initializer + """ super(RecurrentAttention, self).__init__() - self.rnn = nn.LSTM(input_size, hidden_size, num_layers, bias, - batch_first) + self.rnn = nn.LSTM(input_size, hidden_size, num_layers, bias=True, + batch_first=batch_first) + init_lstm_(self.rnn, init_weight) self.attn = BahdanauAttention(hidden_size, context_size, context_size, normalize=True, batch_first=batch_first) @@ -23,66 +40,119 @@ def __init__(self, input_size, context_size, hidden_size, num_layers=1, self.dropout = nn.Dropout(dropout) def forward(self, inputs, hidden, context, context_len): + """ + Execute RecurrentAttention. + + :param inputs: tensor with inputs + :param hidden: hidden state for LSTM layer + :param context: context tensor from encoder + :param context_len: vector of encoder sequence lengths + + :returns (rnn_outputs, hidden, attn_output, attn_scores) + """ # set attention mask, sequences have different lengths, this mask # allows to include only valid elements of context in attention's # softmax self.attn.set_mask(context_len, context) + inputs = self.dropout(inputs) rnn_outputs, hidden = self.rnn(inputs, hidden) attn_outputs, scores = self.attn(rnn_outputs, context) - rnn_outputs = self.dropout(rnn_outputs) return rnn_outputs, hidden, attn_outputs, scores class Classifier(nn.Module): - - def __init__(self, in_features, out_features, math='fp32'): + """ + Fully-connected classifier + """ + def __init__(self, in_features, out_features, init_weight=0.1): + """ + Constructor for the Classifier. + + :param in_features: number of input features + :param out_features: number of output features (size of vocabulary) + :param init_weight: range for the uniform initializer + """ super(Classifier, self).__init__() - - self.out_features = out_features - - # padding required to trigger HMMA kernels - if math == 'fp16': - out_features = (out_features + 7) // 8 * 8 - self.classifier = nn.Linear(in_features, out_features) + nn.init.uniform_(self.classifier.weight.data, -init_weight, init_weight) + nn.init.uniform_(self.classifier.bias.data, -init_weight, init_weight) def forward(self, x): + """ + Execute the classifier. 
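BahdanauAttention.calc_score() in the attention module above implements normalized additive attention: queries and keys are broadcast-added, passed through tanh, projected onto a learned vector, and padded key positions are filled with a large negative value before the softmax. A compact standalone sketch of the unnormalized variant with toy dimensions:

    import torch

    b, t_q, t_k, n = 2, 3, 4, 8
    att_query = torch.randn(b, t_q, n)     # queries already passed through linear_q
    att_keys = torch.randn(b, t_k, n)      # keys already passed through linear_k
    linear_att = torch.randn(n)

    sum_qk = att_query.unsqueeze(2) + att_keys.unsqueeze(1)   # (b, t_q, t_k, n)
    scores = torch.tanh(sum_qk).matmul(linear_att)            # (b, t_q, t_k)

    key_lengths = torch.tensor([4, 2])                        # true key lengths
    mask = torch.arange(t_k) >= key_lengths.unsqueeze(1)      # (b, t_k), True on pads
    scores = scores.masked_fill(mask.unsqueeze(1), -65504.0)  # same fill value as above
    weights = torch.softmax(scores, dim=-1)                   # softmax over keys
    context = torch.bmm(weights, att_keys)                    # (b, t_q, n) context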
+ + :param x: output from decoder + """ out = self.classifier(x) - out = out[..., :self.out_features] return out class ResidualRecurrentDecoder(nn.Module): - - def __init__(self, vocab_size, hidden_size=128, num_layers=8, bias=True, - dropout=0, batch_first=False, math='fp32', embedder=None): - + """ + Decoder with Embedding, LSTM layers, attention, residual connections and + optinal dropout. + + Attention implemented in this module is different than the attention + discussed in the GNMT arxiv paper. In this model the output from the first + LSTM layer of the decoder goes into the attention module, then the + re-weighted context is concatenated with inputs to all subsequent LSTM + layers in the decoder at the current timestep. + + Residual connections are enabled after 3rd LSTM layer, dropout is applied + on inputs to LSTM layers. + """ + def __init__(self, vocab_size, hidden_size=1024, num_layers=4, dropout=0.2, + batch_first=False, embedder=None, init_weight=0.1): + """ + Constructor of the ResidualRecurrentDecoder. + + :param vocab_size: size of vocabulary + :param hidden_size: hidden size for LSMT layers + :param num_layers: number of LSTM layers + :param dropout: probability of dropout (on input to LSTM layers) + :param batch_first: if True the model uses (batch,seq,feature) tensors, + if false the model uses (seq, batch, feature) + :param embedder: instance of nn.Embedding, if None constructor will + create new embedding layer + :param init_weight: range for the uniform initializer + """ super(ResidualRecurrentDecoder, self).__init__() self.num_layers = num_layers self.att_rnn = RecurrentAttention(hidden_size, hidden_size, hidden_size, num_layers=1, - batch_first=batch_first) + batch_first=batch_first, + dropout=dropout) self.rnn_layers = nn.ModuleList() for _ in range(num_layers - 1): self.rnn_layers.append( - nn.LSTM(2 * hidden_size, hidden_size, num_layers=1, bias=bias, + nn.LSTM(2 * hidden_size, hidden_size, num_layers=1, bias=True, batch_first=batch_first)) + for lstm in self.rnn_layers: + init_lstm_(lstm, init_weight) + if embedder is not None: self.embedder = embedder else: self.embedder = nn.Embedding(vocab_size, hidden_size, - padding_idx=config.PAD) + padding_idx=config.PAD) + nn.init.uniform_(self.embedder.weight.data, -init_weight, init_weight) - self.classifier = Classifier(hidden_size, vocab_size, math) + self.classifier = Classifier(hidden_size, vocab_size) self.dropout = nn.Dropout(p=dropout) def init_hidden(self, hidden): + """ + Converts flattened hidden state (from sequence generator) into a tuple + of hidden states. + + :param hidden: None or flattened hidden state for decoder RNN layers + """ if hidden is not None: # per-layer chunks hidden = hidden.chunk(self.num_layers) @@ -95,10 +165,19 @@ def init_hidden(self, hidden): return hidden def append_hidden(self, h): + """ + Appends the hidden vector h to the list of internal hidden states. + + :param h: hidden vector + """ if self.inference: self.next_hidden.append(h) def package_hidden(self): + """ + Flattens the hidden state from all LSTM layers into one tensor (for + the sequence generator). + """ if self.inference: hidden = torch.cat(tuple(itertools.chain(*self.next_hidden))) else: @@ -106,6 +185,14 @@ def package_hidden(self): return hidden def forward(self, inputs, context, inference=False): + """ + Execute the decoder. 
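init_hidden() and package_hidden() above flatten the per-layer (h, c) LSTM states into a single tensor so the beam-search generator can carry, reorder and re-feed one context tensor, and later split it back into per-layer tuples. A small sketch of that flatten/chunk round trip with toy sizes:

    import itertools
    import torch

    num_layers, batch, hidden = 4, 2, 8
    # per-layer (h, c) pairs as produced by single-layer nn.LSTM modules
    states = [(torch.randn(1, batch, hidden), torch.randn(1, batch, hidden))
              for _ in range(num_layers)]

    # flatten: concatenate h and c of every layer into one tensor
    flat = torch.cat(tuple(itertools.chain(*states)))    # (2 * num_layers, batch, hidden)

    # restore: chunk per layer, then split each chunk back into (h, c)
    restored = [chunk.chunk(2) for chunk in flat.chunk(num_layers)]
    assert torch.equal(restored[0][0], states[0][0])
    assert torch.equal(restored[3][1], states[3][1])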
+ + :param inputs: tensor with inputs to the decoder + :param context: state of encoder, encoder sequence lengths and hidden + state of decoder's LSTM layers + :param inference: if True stores and repackages hidden state + """ self.inference = inference enc_context, enc_len, hidden = context @@ -116,15 +203,15 @@ def forward(self, inputs, context, inference=False): x, h, attn, scores = self.att_rnn(x, hidden[0], enc_context, enc_len) self.append_hidden(h) - x = self.dropout(x) x = torch.cat((x, attn), dim=2) + x = self.dropout(x) x, h = self.rnn_layers[0](x, hidden[1]) self.append_hidden(h) for i in range(1, len(self.rnn_layers)): residual = x - x = self.dropout(x) x = torch.cat((x, attn), dim=2) + x = self.dropout(x) x, h = self.rnn_layers[i](x, hidden[i + 1]) self.append_hidden(h) x = x + residual diff --git a/rnn_translator/pytorch/seq2seq/models/encoder.py b/rnn_translator/pytorch/seq2seq/models/encoder.py index 0f74a6e4c..0b1944aa2 100644 --- a/rnn_translator/pytorch/seq2seq/models/encoder.py +++ b/rnn_translator/pytorch/seq2seq/models/encoder.py @@ -3,41 +3,78 @@ from torch.nn.utils.rnn import pad_packed_sequence import seq2seq.data.config as config +from seq2seq.utils import init_lstm_ class ResidualRecurrentEncoder(nn.Module): + """ + Encoder with Embedding, LSTM layers, residual connections and optional + dropout. - def __init__(self, vocab_size, hidden_size=128, num_layers=8, bias=True, - dropout=0, batch_first=False, embedder=None): + The first LSTM layer is bidirectional and uses variable sequence length + API, the remaining (num_layers-1) layers are unidirectional. Residual + connections are enabled after third LSTM layer, dropout is applied on + inputs to LSTM layers. + """ + def __init__(self, vocab_size, hidden_size=1024, num_layers=4, dropout=0.2, + batch_first=False, embedder=None, init_weight=0.1): + """ + Constructor for the ResidualRecurrentEncoder. 
+ :param vocab_size: size of vocabulary + :param hidden_size: hidden size for LSTM layers + :param num_layers: number of LSTM layers, 1st layer is bidirectional + :param dropout: probability of dropout (on input to LSTM layers) + :param batch_first: if True the model uses (batch,seq,feature) tensors, + if false the model uses (seq, batch, feature) + :param embedder: instance of nn.Embedding, if None constructor will + create new embedding layer + :param init_weight: range for the uniform initializer + """ super(ResidualRecurrentEncoder, self).__init__() self.batch_first = batch_first self.rnn_layers = nn.ModuleList() + # 1st LSTM layer, bidirectional self.rnn_layers.append( - nn.LSTM(hidden_size, hidden_size, num_layers=1, bias=bias, + nn.LSTM(hidden_size, hidden_size, num_layers=1, bias=True, batch_first=batch_first, bidirectional=True)) + # 2nd LSTM layer, with 2x larger input_size self.rnn_layers.append( - nn.LSTM((2 * hidden_size), hidden_size, num_layers=1, bias=bias, + nn.LSTM((2 * hidden_size), hidden_size, num_layers=1, bias=True, batch_first=batch_first)) + # Remaining LSTM layers for _ in range(num_layers - 2): self.rnn_layers.append( - nn.LSTM(hidden_size, hidden_size, num_layers=1, bias=bias, + nn.LSTM(hidden_size, hidden_size, num_layers=1, bias=True, batch_first=batch_first)) + for lstm in self.rnn_layers: + init_lstm_(lstm, init_weight) + self.dropout = nn.Dropout(p=dropout) if embedder is not None: self.embedder = embedder else: self.embedder = nn.Embedding(vocab_size, hidden_size, - padding_idx=config.PAD) + padding_idx=config.PAD) + nn.init.uniform_(self.embedder.weight.data, -init_weight, init_weight) def forward(self, inputs, lengths): + """ + Execute the encoder. + + :param inputs: tensor with indices from the vocabulary + :param lengths: vector with sequence lengths (excluding padding) + + returns: tensor with encoded sequences + """ x = self.embedder(inputs) # bidirectional layer + x = self.dropout(x) x = pack_padded_sequence(x, lengths.cpu().numpy(), batch_first=self.batch_first) x, _ = self.rnn_layers[0](x) diff --git a/rnn_translator/pytorch/seq2seq/models/gnmt.py b/rnn_translator/pytorch/seq2seq/models/gnmt.py index ff949abd2..bd691f4a5 100644 --- a/rnn_translator/pytorch/seq2seq/models/gnmt.py +++ b/rnn_translator/pytorch/seq2seq/models/gnmt.py @@ -1,41 +1,56 @@ import torch.nn as nn - from mlperf_compliance import mlperf_log import seq2seq.data.config as config -from .seq2seq_base import Seq2Seq -from .decoder import ResidualRecurrentDecoder -from .encoder import ResidualRecurrentEncoder +from seq2seq.models.decoder import ResidualRecurrentDecoder +from seq2seq.models.encoder import ResidualRecurrentEncoder +from seq2seq.models.seq2seq_base import Seq2Seq +from seq2seq.utils import gnmt_print class GNMT(Seq2Seq): - def __init__(self, vocab_size, hidden_size=512, num_layers=8, bias=True, - dropout=0.2, batch_first=False, math='fp32', - share_embedding=False): + """ + GNMT v2 model + """ + def __init__(self, vocab_size, hidden_size=1024, num_layers=4, dropout=0.2, + batch_first=False, share_embedding=True): + """ + Constructor for the GNMT v2 model. 
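The encoder forward above packs the padded batch before the bidirectional first layer, so the LSTM skips pad positions and the reverse direction starts at each sequence's true last token. A minimal standalone sketch of that packing step (toy sizes, default (seq, batch, feature) layout):

    import torch
    import torch.nn as nn
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    batch, max_len, hidden = 3, 6, 8
    x = torch.randn(max_len, batch, hidden)      # (seq, batch, feature) layout
    lengths = torch.tensor([6, 4, 2])            # true lengths, sorted descending

    rnn = nn.LSTM(hidden, hidden, num_layers=1, bidirectional=True)
    packed = pack_padded_sequence(x, lengths)    # pad positions are skipped by the LSTM
    out, _ = rnn(packed)
    out, _ = pad_packed_sequence(out)            # back to (seq, batch, 2 * hidden)
    print(out.shape)                             # torch.Size([6, 3, 16])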
+ + :param vocab_size: size of vocabulary (number of tokens) + :param hidden_size: internal hidden size of the model + :param num_layers: number of layers, applies to both encoder and + decoder + :param dropout: probability of dropout (in encoder and decoder) + :param batch_first: if True the model uses (batch,seq,feature) tensors, + if false the model uses (seq, batch, feature) + :param share_embedding: if True embeddings are shared between encoder + and decoder + """ super(GNMT, self).__init__(batch_first=batch_first) - mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_NUM_LAYERS, - value=num_layers) - mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_HIDDEN_SIZE, - value=hidden_size) - mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_DROPOUT, - value=dropout) + gnmt_print(key=mlperf_log.MODEL_HP_NUM_LAYERS, + value=num_layers, sync=False) + gnmt_print(key=mlperf_log.MODEL_HP_HIDDEN_SIZE, + value=hidden_size, sync=False) + gnmt_print(key=mlperf_log.MODEL_HP_DROPOUT, + value=dropout, sync=False) if share_embedding: - embedder = nn.Embedding(vocab_size, hidden_size, padding_idx=config.PAD) + embedder = nn.Embedding(vocab_size, hidden_size, + padding_idx=config.PAD) + nn.init.uniform_(embedder.weight.data, -0.1, 0.1) else: embedder = None self.encoder = ResidualRecurrentEncoder(vocab_size, hidden_size, - num_layers, bias, dropout, + num_layers, dropout, batch_first, embedder) self.decoder = ResidualRecurrentDecoder(vocab_size, hidden_size, - num_layers, bias, dropout, - batch_first, math, embedder) - - + num_layers, dropout, + batch_first, embedder) def forward(self, input_encoder, input_enc_len, input_decoder): context = self.encode(input_encoder, input_enc_len) diff --git a/rnn_translator/pytorch/seq2seq/models/seq2seq_base.py b/rnn_translator/pytorch/seq2seq/models/seq2seq_base.py index 844a3c473..4b0631719 100644 --- a/rnn_translator/pytorch/seq2seq/models/seq2seq_base.py +++ b/rnn_translator/pytorch/seq2seq/models/seq2seq_base.py @@ -3,19 +3,61 @@ class Seq2Seq(nn.Module): + """ + Generic Seq2Seq module, with an encoder and a decoder. + """ def __init__(self, encoder=None, decoder=None, batch_first=False): + """ + Constructor for the Seq2Seq module. + + :param encoder: encoder module + :param decoder: decoder module + :param batch_first: if True the model uses (batch, seq, feature) + tensors, if false the model uses (seq, batch, feature) tensors + """ super(Seq2Seq, self).__init__() self.encoder = encoder self.decoder = decoder self.batch_first = batch_first def encode(self, inputs, lengths): + """ + Applies the encoder to inputs with a given input sequence lengths. + + :param inputs: tensor with inputs (batch, seq_len) if 'batch_first' + else (seq_len, batch) + :param lengths: vector with sequence lengths (excluding padding) + """ return self.encoder(inputs, lengths) def decode(self, inputs, context, inference=False): + """ + Applies the decoder to inputs, given the context from the encoder. + + :param inputs: tensor with inputs (batch, seq_len) if 'batch_first' + else (seq_len, batch) + :param context: context from the encoder + :param inference: if True inference mode, if False training mode + """ return self.decoder(inputs, context, inference) def generate(self, inputs, context, beam_size): + """ + Autoregressive generator, works with SequenceGenerator class. + Executes decoder (in inference mode), applies log_softmax and topK for + inference with beam search decoding. 
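The Seq2Seq.generate() method described here turns decoder logits into log-probabilities and keeps the top beam_size candidates, which is exactly the interface the SequenceGenerator consumes. A tiny self-contained illustration of the log_softmax + topk step with made-up logits:

    import torch
    import torch.nn.functional as F

    beam_size, vocab = 3, 10
    logits = torch.randn(2, 1, vocab)            # (batch, step, vocab) toy logits

    logprobs = F.log_softmax(logits, dim=-1)
    topk_logprobs, topk_words = logprobs.topk(beam_size, dim=-1)
    # topk_words: indices of the beam_size best tokens per position,
    # topk_logprobs: their log-probabilities, added to the running beam scores
    print(topk_words.shape)                      # torch.Size([2, 1, 3])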
+ + :param inputs: tensor with inputs to the decoder + :param context: context from the encoder + :param beam_size: beam size for the generator + + returns: (words, logprobs, scores, new_context) + words: indices of topK tokens + logprobs: log probabilities of topK tokens + scores: scores from the attention module (for coverage penalty) + new_context: new decoder context, includes new hidden states for + decoder RNN cells + """ logits, scores, new_context = self.decode(inputs, context, True) logprobs = log_softmax(logits, dim=-1) logprobs, words = logprobs.topk(beam_size, dim=-1) diff --git a/rnn_translator/pytorch/seq2seq/train/distributed.py b/rnn_translator/pytorch/seq2seq/train/distributed.py deleted file mode 100644 index 19b8d85ad..000000000 --- a/rnn_translator/pytorch/seq2seq/train/distributed.py +++ /dev/null @@ -1,222 +0,0 @@ -import torch -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -import torch.distributed as dist -from torch.nn.modules import Module -from torch.autograd import Variable -from collections import OrderedDict - - -def flat_dist_call(tensors, call, extra_args=None): - flat_dist_call.warn_on_half = True - buckets = OrderedDict() - for tensor in tensors: - tp = tensor.type() - if tp not in buckets: - buckets[tp] = [] - buckets[tp].append(tensor) - - if flat_dist_call.warn_on_half: - if torch.cuda.HalfTensor in buckets: - print("WARNING: gloo dist backend for half parameters may be extremely slow." + - " It is recommended to use the NCCL backend in this case.") - flat_dist_call.warn_on_half = False - - for tp in buckets: - bucket = buckets[tp] - coalesced = _flatten_dense_tensors(bucket) - if extra_args is not None: - call(coalesced, *extra_args) - else: - call(coalesced) - if call is dist.all_reduce: - coalesced /= dist.get_world_size() - - for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)): - buf.copy_(synced) - -class DistributedDataParallel(Module): - """ - :class:`apex.parallel.DistributedDataParallel` is a module wrapper that enables - easy multiprocess distributed data parallel training, similar to ``torch.nn.parallel.DistributedDataParallel``. - - :class:`DistributedDataParallel` is designed to work with - the launch utility script ``apex.parallel.multiproc.py``. - When used with ``multiproc.py``, :class:`DistributedDataParallel` - assigns 1 process to each of the available (visible) GPUs on the node. - Parameters are broadcast across participating processes on initialization, and gradients are - allreduced and averaged over processes during ``backward()``. - - :class:`DistributedDataParallel` is optimized for use with NCCL. It achieves high performance by - overlapping communication with computation during ``backward()`` and bucketing smaller gradient - transfers to reduce the total number of transfers required. - - :class:`DistributedDataParallel` assumes that your script accepts the command line - arguments "rank" and "world-size." It also assumes that your script calls - ``torch.cuda.set_device(args.rank)`` before creating the model. - - https://github.com/NVIDIA/apex/tree/master/examples/distributed shows detailed usage. - https://github.com/NVIDIA/apex/tree/master/examples/imagenet shows another example - that combines :class:`DistributedDataParallel` with mixed precision training. - - Args: - module: Network definition to be run in multi-gpu/distributed mode. - message_size (Default = 1e7): Minimum number of elements in a communication bucket. 
- shared_param (Default = False): If your model uses shared parameters this must be True. It will disable bucketing of parameters to avoid race conditions. - - """ - - def __init__(self, module, message_size=10000000, shared_param=False): - super(DistributedDataParallel, self).__init__() - self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False - self.shared_param = shared_param - self.message_size = message_size - - #reference to last iterations parameters to see if anything has changed - self.param_refs = [] - - self.reduction_stream = torch.cuda.Stream() - - self.module = module - self.param_list = list(self.module.parameters()) - - if dist._backend == dist.dist_backend.NCCL: - for param in self.param_list: - assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU." - - self.record = [] - self.create_hooks() - - flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) ) - - def create_hooks(self): - #all reduce gradient hook - def allreduce_params(): - if not self.needs_reduction: - return - self.needs_reduction = False - - #parameter ordering refresh - if self.needs_refresh and not self.shared_param: - t_record = torch.cuda.IntTensor(self.record) - dist.broadcast(t_record, 0) - self.record = [int(entry) for entry in t_record] - self.needs_refresh = False - - grads = [param.grad.data for param in self.module.parameters() if param.grad is not None] - flat_dist_call(grads, dist.all_reduce) - - def flush_buckets(): - if not self.needs_reduction: - return - self.needs_reduction = False - - grads = [] - for i in range(self.ready_end, len(self.param_state)): - param = self.param_refs[self.record[i]] - if param.grad is not None: - grads.append(param.grad.data) - grads = [param.grad.data for param in self.ready_params] + grads - - if(len(grads)>0): - orig_stream = torch.cuda.current_stream() - with torch.cuda.stream(self.reduction_stream): - self.reduction_stream.wait_stream(orig_stream) - flat_dist_call(grads, dist.all_reduce) - - torch.cuda.current_stream().wait_stream(self.reduction_stream) - - for param_i, param in enumerate(list(self.module.parameters())): - def wrapper(param_i): - - def allreduce_hook(*unused): - if self.needs_refresh: - self.record.append(param_i) - Variable._execution_engine.queue_callback(allreduce_params) - else: - Variable._execution_engine.queue_callback(flush_buckets) - self.comm_ready_buckets(self.record.index(param_i)) - - - if param.requires_grad: - param.register_hook(allreduce_hook) - wrapper(param_i) - - - def comm_ready_buckets(self, param_ind): - - if self.param_state[param_ind] != 0: - raise RuntimeError("Error: Your model uses shared parameters, DDP flag shared_params must be set to True in initialization.") - - - if self.param_state[self.ready_end] == 0: - self.param_state[param_ind] = 1 - return - - - while self.ready_end < len(self.param_state) and self.param_state[self.ready_end] == 1: - self.ready_params.append(self.param_refs[self.record[self.ready_end]]) - self.ready_numel += self.ready_params[-1].numel() - self.ready_end += 1 - - - if self.ready_numel < self.message_size: - self.param_state[param_ind] = 1 - return - - grads = [param.grad.data for param in self.ready_params] - - bucket = [] - bucket_inds = [] - while grads: - bucket.append(grads.pop(0)) - - cumm_size = 0 - for ten in bucket: - cumm_size += ten.numel() - - if cumm_size < self.message_size: - continue - - evt = torch.cuda.Event() - evt.record(torch.cuda.current_stream()) - 
evt.wait(stream=self.reduction_stream) - - with torch.cuda.stream(self.reduction_stream): - flat_dist_call(bucket, dist.all_reduce) - - for i in range(self.ready_start, self.ready_start+len(bucket)): - self.param_state[i] = 2 - self.ready_params.pop(0) - - self.param_state[param_ind] = 1 - - def forward(self, *inputs, **kwargs): - - param_list = [param for param in list(self.module.parameters()) if param.requires_grad] - - - #Force needs_refresh to True if there are shared params - #this will force it to always, only call flush_buckets which is safe - #for shared parameters in the model. - #Parentheses are not necessary for correct order of operations, but make the intent clearer. - if (not self.param_refs) or self.shared_param: - self.needs_refresh = True - else: - self.needs_refresh = ( - (len(param_list) != len(self.param_refs)) or any( - [param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)])) - - if self.needs_refresh: - self.record = [] - - - self.param_state = [0 for i in range(len(param_list))] - self.param_refs = param_list - self.needs_reduction = True - - self.ready_start = 0 - self.ready_end = 0 - self.ready_params = [] - self.ready_numel = 0 - - return self.module(*inputs, **kwargs) diff --git a/rnn_translator/pytorch/seq2seq/train/fp_optimizers.py b/rnn_translator/pytorch/seq2seq/train/fp_optimizers.py index ba57060ea..a3397c0f2 100644 --- a/rnn_translator/pytorch/seq2seq/train/fp_optimizers.py +++ b/rnn_translator/pytorch/seq2seq/train/fp_optimizers.py @@ -6,9 +6,18 @@ class Fp16Optimizer: - + """ + Mixed precision optimizer with dynamic loss scaling and backoff. + https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html#scalefactor + """ @staticmethod def set_grads(params, params_with_grad): + """ + Copies gradients from param_with_grad to params + + :param params: dst parameters + :param params_with_grad: src parameters + """ for param, param_w_grad in zip(params, params_with_grad): if param.grad is None: param.grad = torch.nn.Parameter(torch.empty_like(param)) @@ -16,11 +25,31 @@ def set_grads(params, params_with_grad): @staticmethod def set_weights(params, new_params): + """ + Copies parameters from new_params to params + + :param params: dst parameters + :param new_params: src parameters + """ for param, new_param in zip(params, new_params): param.data.copy_(new_param.data) def __init__(self, fp16_model, grad_clip=float('inf'), loss_scale=8192, - dls_downscale=2, dls_upscale=2, dls_upscale_interval=2048): + dls_downscale=2, dls_upscale=2, dls_upscale_interval=128): + """ + Constructor for the Fp16Optimizer. + + :param fp16_model: model (previously casted to half) + :param grad_clip: coefficient for gradient clipping, max L2 norm of the + gradients + :param loss_scale: initial loss scale + :param dls_downscale: loss downscale factor, loss scale is divided by + this factor when NaN/INF occurs in the gradients + :param dls_upscale: loss upscale factor, loss scale is multiplied by + this factor if previous dls_upscale_interval batches finished + successfully + :param dls_upscale_interval: interval for loss scale upscaling + """ logging.info('Initializing fp16 optimizer') self.initialize_model(fp16_model) @@ -32,6 +61,11 @@ def __init__(self, fp16_model, grad_clip=float('inf'), loss_scale=8192, self.grad_clip = grad_clip def initialize_model(self, model): + """ + Initializes internal state and build fp32 master copy of weights. 
+ + :param model: fp16 model + """ logging.info('Initializing fp32 clone weights') self.fp16_model = model self.fp16_model.zero_grad() @@ -41,23 +75,36 @@ def initialize_model(self, model): for param in self.fp32_params: param.requires_grad = True - def step(self, loss, optimizer, update=True): + def step(self, loss, optimizer, scheduler, update=True): + """ + Performs one step of the optimizer. + Applies loss scaling, computes gradients in fp16, converts gradients to + fp32, inverts scaling and applies optional gradient norm clipping. + If gradients are finite, it applies update to fp32 master weights and + copies updated parameters to fp16 model for the next iteration. If + gradients are not finite, it skips the batch and adjusts scaling factor + for the next iteration. + + :param loss: value of loss function + :param optimizer: optimizer + :param update: if True executes weight update + """ loss *= self.loss_scale - - self.fp16_model.zero_grad() loss.backward() - self.set_grads(self.fp32_params, self.fp16_model.parameters()) - if self.loss_scale != 1.0: - for param in self.fp32_params: - param.grad.data /= self.loss_scale + if update: + self.set_grads(self.fp32_params, self.fp16_model.parameters()) + if self.loss_scale != 1.0: + for param in self.fp32_params: + param.grad.data /= self.loss_scale - norm = clip_grad_norm_(self.fp32_params, self.grad_clip) + norm = clip_grad_norm_(self.fp32_params, self.grad_clip) - if update: if math.isfinite(norm): + scheduler.step() optimizer.step() - self.set_weights(self.fp16_model.parameters(), self.fp32_params) + self.set_weights(self.fp16_model.parameters(), + self.fp32_params) self.since_last_invalid += 1 else: self.loss_scale /= self.dls_downscale @@ -71,22 +118,46 @@ def step(self, loss, optimizer, update=True): logging.info(f'Upscaling, new scale: {self.loss_scale}') self.since_last_invalid = 0 + self.fp16_model.zero_grad() -class Fp32Optimizer: +class Fp32Optimizer: + """ + Standard optimizer, computes backward and applies weight update. + """ def __init__(self, model, grad_clip=None): + """ + Constructor for the Fp32Optimizer + + :param model: model + :param grad_clip: coefficient for gradient clipping, max L2 norm of the + gradients + """ logging.info('Initializing fp32 optimizer') self.initialize_model(model) self.grad_clip = grad_clip def initialize_model(self, model): + """ + Initializes state of the model. + + :param model: model + """ self.model = model self.model.zero_grad() - def step(self, loss, optimizer, update=True): + def step(self, loss, optimizer, scheduler, update=True): + """ + Performs one step of the optimizer. 
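+        Computes gradients with loss.backward(); when 'update' is True it
+        optionally clips the gradients, advances the LR scheduler, applies the
+        optimizer step and zeroes the gradients afterwards.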
+ + :param loss: value of loss function + :param optimizer: optimizer + :param update: if True executes weight update + """ loss.backward() - if self.grad_clip != float('inf'): - clip_grad_norm_(self.model.parameters(), self.grad_clip) if update: + if self.grad_clip != float('inf'): + clip_grad_norm_(self.model.parameters(), self.grad_clip) + scheduler.step() optimizer.step() - self.model.zero_grad() + self.model.zero_grad() diff --git a/rnn_translator/pytorch/seq2seq/train/lr_scheduler.py b/rnn_translator/pytorch/seq2seq/train/lr_scheduler.py new file mode 100644 index 000000000..4df62b6a0 --- /dev/null +++ b/rnn_translator/pytorch/seq2seq/train/lr_scheduler.py @@ -0,0 +1,104 @@ +import logging +import math + +import torch +from mlperf_compliance import mlperf_log + +from seq2seq.utils import gnmt_print + + +def perhaps_convert_float(param, total): + if isinstance(param, float): + param = int(param * total) + return param + + +class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): + """ + Learning rate scheduler with exponential warmup and step decay. + """ + def __init__(self, optimizer, iterations, warmup_steps=0, + remain_steps=1.0, decay_interval=None, decay_steps=4, + decay_factor=0.5, last_epoch=-1): + """ + Constructor of WarmupMultiStepLR. + + Parameters: warmup_steps, remain_steps and decay_interval accept both + integers and floats as an input. Integer input is interpreted as + absolute index of iteration, float input is interpreted as a fraction + of total training iterations (epochs * steps_per_epoch). + + If decay_interval is None then the decay will happen at regulary spaced + intervals ('decay_steps' decays between iteration indices + 'remain_steps' and 'iterations'). + + :param optimizer: instance of optimizer + :param iterations: total number of training iterations + :param warmup_steps: number of warmup iterations + :param remain_steps: start decay at 'remain_steps' iteration + :param decay_interval: interval between LR decay steps + :param decay_steps: max number of decay steps + :param decay_factor: decay factor + :param last_epoch: the index of last iteration + """ + + # iterations before learning rate reaches base LR + self.warmup_steps = perhaps_convert_float(warmup_steps, iterations) + logging.info(f'Scheduler warmup steps: {self.warmup_steps}') + + # iteration at which decay starts + self.remain_steps = perhaps_convert_float(remain_steps, iterations) + logging.info(f'Scheduler remain steps: {self.remain_steps}') + + # number of steps between each decay + if decay_interval is None: + # decay at regulary spaced intervals + decay_iterations = iterations - self.remain_steps + self.decay_interval = decay_iterations // (decay_steps) + self.decay_interval = max(self.decay_interval, 1) + else: + self.decay_interval = perhaps_convert_float(decay_interval, + iterations) + logging.info(f'Scheduler decay interval: {self.decay_interval}') + + # multiplicative decay factor + self.decay_factor = decay_factor + logging.info(f'Scheduler decay factor: {self.decay_factor}') + + # max number of decay steps + self.decay_steps = decay_steps + logging.info(f'Scheduler max decay steps: {self.decay_steps}') + + if self.warmup_steps > self.remain_steps: + logging.warn(f'warmup_steps should not be larger than ' + f'remain_steps, setting warmup_steps=remain_steps') + self.warmup_steps = self.remain_steps + + gnmt_print(key=mlperf_log.OPT_LR_WARMUP_STEPS, value=self.warmup_steps, + sync=False) + + super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + 
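+        # Sketch of the schedule implemented by the branches below: during
+        # warmup the LR grows exponentially from ~0.01 * base_lr to base_lr,
+        #   lr = base_lr * warmup_factor ** (warmup_steps - last_epoch)
+        # with warmup_factor = exp(log(0.01) / warmup_steps); after
+        # 'remain_steps' iterations it decays in discrete steps,
+        #   lr = base_lr * decay_factor ** num_decay_steps
+        # (at most 'decay_steps' decays); in between it stays at base_lr.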
if self.last_epoch <= self.warmup_steps: + # exponential lr warmup + if self.warmup_steps != 0: + warmup_factor = math.exp(math.log(0.01) / self.warmup_steps) + else: + warmup_factor = 1.0 + inv_decay = warmup_factor ** (self.warmup_steps - self.last_epoch) + lr = [base_lr * inv_decay for base_lr in self.base_lrs] + + elif self.last_epoch >= self.remain_steps: + # step decay + decay_iter = self.last_epoch - self.remain_steps + num_decay_steps = decay_iter // self.decay_interval + 1 + num_decay_steps = min(num_decay_steps, self.decay_steps) + lr = [ + base_lr * (self.decay_factor ** num_decay_steps) + for base_lr in self.base_lrs + ] + else: + # base lr + lr = [base_lr for base_lr in self.base_lrs] + return lr diff --git a/rnn_translator/pytorch/seq2seq/train/smoothing.py b/rnn_translator/pytorch/seq2seq/train/smoothing.py index a2a6c4d38..d5b60b27f 100644 --- a/rnn_translator/pytorch/seq2seq/train/smoothing.py +++ b/rnn_translator/pytorch/seq2seq/train/smoothing.py @@ -1,15 +1,26 @@ import torch import torch.nn as nn + class LabelSmoothing(nn.Module): + """ + NLL loss with label smoothing. + """ def __init__(self, padding_idx, smoothing=0.0): + """ + Constructor for the LabelSmoothing module. + + :param padding_idx: index of the PAD token + :param smoothing: label smoothing factor + """ super(LabelSmoothing, self).__init__() self.padding_idx = padding_idx self.confidence = 1.0 - smoothing self.smoothing = smoothing def forward(self, x, target): - logprobs = torch.nn.functional.log_softmax(x, dim=-1) + logprobs = torch.nn.functional.log_softmax(x, dim=-1, + dtype=torch.float32) non_pad_mask = (target != self.padding_idx) nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) diff --git a/rnn_translator/pytorch/seq2seq/train/trainer.py b/rnn_translator/pytorch/seq2seq/train/trainer.py index 8c6dd2cbd..c7a2b7539 100644 --- a/rnn_translator/pytorch/seq2seq/train/trainer.py +++ b/rnn_translator/pytorch/seq2seq/train/trainer.py @@ -1,36 +1,75 @@ import logging -import math -import time import os +import time from itertools import cycle +import numpy as np import torch import torch.optim import torch.utils.data - +from apex.parallel import DistributedDataParallel as DDP from mlperf_compliance import mlperf_log -from seq2seq.train.distributed import DistributedDataParallel as DDP -from seq2seq.train.fp_optimizers import Fp16Optimizer, Fp32Optimizer +from seq2seq.train.fp_optimizers import Fp16Optimizer +from seq2seq.train.fp_optimizers import Fp32Optimizer +from seq2seq.train.lr_scheduler import WarmupMultiStepLR from seq2seq.utils import AverageMeter +from seq2seq.utils import gnmt_print from seq2seq.utils import sync_workers class Seq2SeqTrainer: - - def __init__(self, model, criterion, opt_config, + """ + Seq2SeqTrainer + """ + def __init__(self, + model, + criterion, + opt_config, + scheduler_config, print_freq=10, save_freq=1000, grad_clip=float('inf'), batch_first=False, save_info={}, save_path='.', + train_iterations=0, checkpoint_filename='checkpoint%s.pth', keep_checkpoints=5, math='fp32', cuda=True, distributed=False, + intra_epoch_eval=0, + iter_size=1, + translator=None, verbose=False): + """ + Constructor for the Seq2SeqTrainer. 
+ + :param model: model to train + :param criterion: criterion (loss function) + :param opt_config: dictionary with options for the optimizer + :param scheduler_config: dictionary with options for the learning rate + scheduler + :param print_freq: prints short summary every 'print_freq' iterations + :param save_freq: saves checkpoint every 'save_freq' iterations + :param grad_clip: coefficient for gradient clipping + :param batch_first: if True the model uses (batch,seq,feature) tensors, + if false the model uses (seq, batch, feature) + :param save_info: dict with additional state stored in each checkpoint + :param save_path: path to the directiory for checkpoints + :param train_iterations: total number of training iterations to execute + :param checkpoint_filename: name of files with checkpoints + :param keep_checkpoints: max number of checkpoints to keep + :param math: arithmetic type + :param cuda: if True use cuda, if False train on cpu + :param distributed: if True run distributed training + :param intra_epoch_eval: number of additional eval runs within each + training epoch + :param iter_size: number of iterations between weight updates + :param translator: instance of Translator, runs inference on test set + :param verbose: enables verbose logging + """ super(Seq2SeqTrainer, self).__init__() self.model = model self.criterion = criterion @@ -48,37 +87,53 @@ def __init__(self, model, criterion, opt_config, self.batch_first = batch_first self.verbose = verbose self.loss = None + self.translator = translator + self.intra_epoch_eval = intra_epoch_eval + self.iter_size = iter_size if cuda: self.model = self.model.cuda() self.criterion = self.criterion.cuda() + if math == 'fp16': + self.model = self.model.half() + if distributed: self.model = DDP(self.model) if math == 'fp16': - self.model = self.model.half() self.fp_optimizer = Fp16Optimizer(self.model, grad_clip) params = self.fp_optimizer.fp32_params elif math == 'fp32': self.fp_optimizer = Fp32Optimizer(self.model, grad_clip) params = self.model.parameters() - opt_name = opt_config['optimizer'] - lr = opt_config['lr'] - self.optimizer = torch.optim.__dict__[opt_name](params, lr=lr) - mlperf_log.gnmt_print(key=mlperf_log.OPT_NAME, - value=opt_name) - mlperf_log.gnmt_print(key=mlperf_log.OPT_LR, - value=lr) - mlperf_log.gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA1, - value=self.optimizer.defaults['betas'][0]) - mlperf_log.gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA2, - value=self.optimizer.defaults['betas'][1]) - mlperf_log.gnmt_print(key=mlperf_log.OPT_HP_ADAM_EPSILON, - value=self.optimizer.defaults['eps']) + opt_name = opt_config.pop('optimizer') + self.optimizer = torch.optim.__dict__[opt_name](params, **opt_config) + logging.info(f'Using optimizer: {self.optimizer}') + gnmt_print(key=mlperf_log.OPT_NAME, + value=mlperf_log.ADAM, sync=False) + gnmt_print(key=mlperf_log.OPT_LR, + value=opt_config['lr'], sync=False) + gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA1, + value=self.optimizer.defaults['betas'][0], sync=False) + gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA2, + value=self.optimizer.defaults['betas'][1], sync=False) + gnmt_print(key=mlperf_log.OPT_HP_ADAM_EPSILON, + value=self.optimizer.defaults['eps'], sync=False) + + self.scheduler = WarmupMultiStepLR(self.optimizer, train_iterations, + **scheduler_config) def iterate(self, src, tgt, update=True, training=True): + """ + Performs one iteration of the training/validation. 
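+        Runs the forward pass and computes the loss (normalized by batch size
+        and by 'iter_size'); in training mode it also calls the fp16/fp32
+        optimizer wrapper, which applies the weight update only when 'update'
+        is True, so gradients can be accumulated over several batches.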
+ + :param src: batch of examples from the source language + :param tgt: batch of examples from the target language + :param update: if True: optimizer does update of the weights + :param training: if True: executes optimizer + """ src, src_length = src tgt, tgt_length = tgt src_length = torch.LongTensor(src_length) @@ -102,14 +157,15 @@ def iterate(self, src, tgt, update=True, training=True): tgt_labels = tgt[1:] T, B = output.size(0), output.size(1) - loss = self.criterion(output.view(T * B, -1).float(), + loss = self.criterion(output.view(T * B, -1), tgt_labels.contiguous().view(-1)) loss_per_batch = loss.item() - loss /= B + loss /= (B * self.iter_size) if training: - self.fp_optimizer.step(loss, self.optimizer, update) + self.fp_optimizer.step(loss, self.optimizer, self.scheduler, + update) loss_per_token = loss_per_batch / num_toks['tgt'] loss_per_sentence = loss_per_batch / B @@ -117,12 +173,24 @@ def iterate(self, src, tgt, update=True, training=True): return loss_per_token, loss_per_sentence, num_toks def feed_data(self, data_loader, training=True): + """ + Runs training or validation on batches from data_loader. + + :param data_loader: data loader + :param training: if True runs training else runs validation + """ if training: assert self.optimizer is not None + eval_fractions = np.linspace(0, 1, self.intra_epoch_eval+2)[1:-1] + iters_with_update = len(data_loader) // self.iter_size + eval_iters = (eval_fractions * iters_with_update).astype(int) + eval_iters = eval_iters * self.iter_size + eval_iters = set(eval_iters) + batch_time = AverageMeter() data_time = AverageMeter() - losses_per_token = AverageMeter() - losses_per_sentence = AverageMeter() + losses_per_token = AverageMeter(skip_first=False) + losses_per_sentence = AverageMeter(skip_first=False) tot_tok_time = AverageMeter() src_tok_time = AverageMeter() @@ -131,13 +199,17 @@ def feed_data(self, data_loader, training=True): batch_size = data_loader.batch_size end = time.time() - for i, (src, tgt, _) in enumerate(data_loader): + for i, (src, tgt) in enumerate(data_loader): self.save_counter += 1 # measure data loading time data_time.update(time.time() - end) + update = False + if i % self.iter_size == self.iter_size - 1: + update = True + # do a train/evaluate iteration - stats = self.iterate(src, tgt, training=training) + stats = self.iterate(src, tgt, update, training=training) loss_per_token, loss_per_sentence, num_toks = stats # measure accuracy and record loss @@ -152,20 +224,36 @@ def feed_data(self, data_loader, training=True): tot_num_toks = num_toks['tgt'] + num_toks['src'] tot_tok_time.update(tot_num_toks / elapsed) self.loss = losses_per_token.avg - end = time.time() + + if training and i in eval_iters: + test_bleu, _ = self.translator.run(calc_bleu=True, + epoch=self.epoch, + iteration=i) + + log = [] + log += [f'TRAIN [{self.epoch}][{i}/{len(data_loader)}]'] + log += [f'BLEU: {test_bleu:.2f}'] + log = '\t'.join(log) + logging.info(log) + + self.model.train() + self.preallocate(data_loader, training=True) if i % self.print_freq == 0: - phase = 'TRAIN' if training else 'EVAL' + phase = 'TRAIN' if training else 'VALIDATION' log = [] log += [f'{phase} [{self.epoch}][{i}/{len(data_loader)}]'] log += [f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'] - log += [f'Data {data_time.val:.3f} ({data_time.avg:.3f})'] + log += [f'Data {data_time.val:.2e} ({data_time.avg:.2e})'] log += [f'Tok/s {tot_tok_time.val:.0f} ({tot_tok_time.avg:.0f})'] if self.verbose: log += [f'Src tok/s {src_tok_time.val:.0f} 
({src_tok_time.avg:.0f})'] log += [f'Tgt tok/s {tgt_tok_time.val:.0f} ({tgt_tok_time.avg:.0f})'] log += [f'Loss/sentence {losses_per_sentence.val:.1f} ({losses_per_sentence.avg:.1f})'] - log += [f'Loss/tok {losses_per_token.val:.8f} ({losses_per_token.avg:.8f})'] + log += [f'Loss/tok {losses_per_token.val:.4f} ({losses_per_token.avg:.4f})'] + if training: + lr = self.optimizer.param_groups[0]['lr'] + log += [f'LR {lr:.3e}'] log = '\t'.join(log) logging.info(log) @@ -179,9 +267,21 @@ def feed_data(self, data_loader, training=True): if rank == 0: self.save(identifier=identifier) - return losses_per_token.avg + end = time.time() + + tot_tok_time.reduce('sum') + losses_per_token.reduce('mean') + + return losses_per_token.avg, tot_tok_time.avg def preallocate(self, data_loader, training): + """ + Generates maximum sequence length batch and runs forward and backward + pass without updating model parameters. + + :param data_loader: data loader + :param training: if True preallocates memory for backward pass + """ batch_size = data_loader.batch_size max_len = data_loader.dataset.max_len @@ -198,48 +298,86 @@ def preallocate(self, data_loader, training): src = src, src_length tgt = tgt, tgt_length self.iterate(src, tgt, update=False, training=training) + self.model.zero_grad() def optimize(self, data_loader): + """ + Sets model in training mode, preallocates memory and runs training on + data provided by data_loader. + + :param data_loader: data loader + """ torch.set_grad_enabled(True) self.model.train() torch.cuda.empty_cache() self.preallocate(data_loader, training=True) output = self.feed_data(data_loader, training=True) + self.model.zero_grad() torch.cuda.empty_cache() return output def evaluate(self, data_loader): + """ + Sets model in eval mode, disables gradients, preallocates memory and + runs validation on data provided by data_loader. + + :param data_loader: data loader + """ torch.set_grad_enabled(False) self.model.eval() torch.cuda.empty_cache() self.preallocate(data_loader, training=False) output = self.feed_data(data_loader, training=False) + self.model.zero_grad() torch.cuda.empty_cache() return output def load(self, filename): + """ + Loads checkpoint from filename. + + :param filename: path to the checkpoint file + """ if os.path.isfile(filename): checkpoint = torch.load(filename, map_location={'cuda:0': 'cpu'}) - self.model.load_state_dict(checkpoint['state_dict']) + if self.distributed: + self.model.module.load_state_dict(checkpoint['state_dict']) + else: + self.model.load_state_dict(checkpoint['state_dict']) self.fp_optimizer.initialize_model(self.model) self.optimizer.load_state_dict(checkpoint['optimizer']) + self.scheduler.load_state_dict(checkpoint['scheduler']) self.epoch = checkpoint['epoch'] self.loss = checkpoint['loss'] - logging.info(f'loaded checkpoint {filename} (epoch {self.epoch})') + logging.info(f'Loaded checkpoint {filename} (epoch {self.epoch})') else: - logging.error(f'invalid checkpoint: {filename}') + logging.error(f'Invalid checkpoint: {filename}') def save(self, identifier=None, is_best=False, save_all=False): + """ + Stores checkpoint to a file. 
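+        The checkpoint contains the model weights (unwrapped from
+        DistributedDataParallel when running distributed), optimizer and
+        scheduler state, current epoch and loss, plus any extra entries from
+        'save_info'.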
+ + :param identifier: identifier for periodic checkpoint + :param is_best: if True stores checkpoint to 'model_best.pth' + :param save_all: if True stores checkpoint after completed training + epoch + """ def write_checkpoint(state, filename): filename = os.path.join(self.save_path, filename) - logging.info(f'saving model to {filename}') + logging.info(f'Saving model to {filename}') torch.save(state, filename) + if self.distributed: + model_state = self.model.module.state_dict() + else: + model_state = self.model.state_dict() + state = { 'epoch': self.epoch, - 'state_dict': self.model.state_dict(), + 'state_dict': model_state, 'optimizer': self.optimizer.state_dict(), + 'scheduler': self.scheduler.state_dict(), 'loss': getattr(self, 'loss', None), } state = dict(list(state.items()) + list(self.save_info.items())) diff --git a/rnn_translator/pytorch/seq2seq/utils.py b/rnn_translator/pytorch/seq2seq/utils.py index a654ff872..4f3665813 100644 --- a/rnn_translator/pytorch/seq2seq/utils.py +++ b/rnn_translator/pytorch/seq2seq/utils.py @@ -1,30 +1,160 @@ -from contextlib import contextmanager -import os import logging.config +import os +import random +import sys +import time +from contextlib import contextmanager import numpy as np import torch -from torch.nn.utils.rnn import pack_padded_sequence +import torch.distributed as dist +import torch.nn.init as init +import torch.utils.collect_env +from mlperf_compliance import mlperf_log + + +def gnmt_print(*args, **kwargs): + """ + Wrapper for MLPerf compliance logging calls. + All arguments but 'sync' are passed to mlperf_log.gnmt_print function. + If 'sync' is set to True then the wrapper will synchronize all distributed + workers. 'sync' should be set to True for all compliance tags that require + accurate timing (RUN_START, RUN_STOP etc.) + """ + if kwargs.pop('sync'): + barrier() + if get_rank() == 0: + kwargs['stack_offset'] = 2 + mlperf_log.gnmt_print(*args, **kwargs) + + +def init_lstm_(lstm, init_weight=0.1): + """ + Initializes weights of LSTM layer. + Weights and biases are initialized with uniform(-init_weight, init_weight) + distribution. + + :param lstm: instance of torch.nn.LSTM + :param init_weight: range for the uniform initializer + """ + # Initialize hidden-hidden weights + init.uniform_(lstm.weight_hh_l0.data, -init_weight, init_weight) + # Initialize input-hidden weights: + init.uniform_(lstm.weight_ih_l0.data, -init_weight, init_weight) + + # Initialize bias. PyTorch LSTM has two biases, one for input-hidden GEMM + # and the other for hidden-hidden GEMM. Here input-hidden bias is + # initialized with uniform distribution and hidden-hidden bias is + # initialized with zeros. + init.uniform_(lstm.bias_ih_l0.data, -init_weight, init_weight) + init.zeros_(lstm.bias_hh_l0.data) + + if lstm.bidirectional: + init.uniform_(lstm.weight_hh_l0_reverse.data, -init_weight, init_weight) + init.uniform_(lstm.weight_ih_l0_reverse.data, -init_weight, init_weight) -import seq2seq.data.config as config + init.uniform_(lstm.bias_ih_l0_reverse.data, -init_weight, init_weight) + init.zeros_(lstm.bias_hh_l0_reverse.data) + + +def generate_seeds(rng, size): + """ + Generate list of random seeds + + :param rng: random number generator + :param size: length of the returned list + """ + seeds = [rng.randint(0, 2**32 - 1) for _ in range(size)] + return seeds + + +def broadcast_seeds(seeds, device): + """ + Broadcasts random seeds to all distributed workers. + Returns list of random seeds (broadcasted from workers with rank 0). 
+ + :param seeds: list of seeds (integers) + :param device: torch.device + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + seeds_tensor = torch.LongTensor(seeds).to(device) + torch.distributed.broadcast(seeds_tensor, 0) + seeds = seeds_tensor.tolist() + return seeds + + +def setup_seeds(master_seed, epochs, device): + """ + Generates seeds from one master_seed. + Function returns (worker_seeds, shuffling_seeds), worker_seeds are later + used to initialize per-worker random number generators (mostly for + dropouts), shuffling_seeds are for RNGs resposible for reshuffling the + dataset before each epoch. + Seeds are generated on worker with rank 0 and broadcasted to all other + workers. + + :param master_seed: master RNG seed used to initialize other generators + :param epochs: number of epochs + :param device: torch.device (used for distributed.broadcast) + """ + if master_seed is None: + # random master seed, random.SystemRandom() uses /dev/urandom on Unix + master_seed = random.SystemRandom().randint(0, 2**32 - 1) + if get_rank() == 0: + # master seed is reported only from rank=0 worker, it's to avoid + # confusion, seeds from rank=0 are later broadcasted to other + # workers + logging.info(f'Using random master seed: {master_seed}') + else: + # master seed was specified from command line + logging.info(f'Using master seed from command line: {master_seed}') + + gnmt_print(key=mlperf_log.RUN_SET_RANDOM_SEED, value=master_seed, + sync=False) + + # initialize seeding RNG + seeding_rng = random.Random(master_seed) + + # generate worker seeds, one seed for every distributed worker + worker_seeds = generate_seeds(seeding_rng, get_world_size()) + + # generate seeds for data shuffling, one seed for every epoch + shuffling_seeds = generate_seeds(seeding_rng, epochs) + + # broadcast seeds from rank=0 to other workers + worker_seeds = broadcast_seeds(worker_seeds, device) + shuffling_seeds = broadcast_seeds(shuffling_seeds, device) + return worker_seeds, shuffling_seeds def barrier(): - """ Calls all_reduce on dummy tensor.""" - if torch.distributed.is_initialized(): + """ + Works as a temporary distributed barrier, currently pytorch + doesn't implement barrier for NCCL backend. + Calls all_reduce on dummy tensor and synchronizes with GPU. + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): torch.distributed.all_reduce(torch.cuda.FloatTensor(1)) torch.cuda.synchronize() def get_rank(): - if torch.distributed.is_initialized(): + """ + Gets distributed rank or returns zero if distributed is not initialized. + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): rank = torch.distributed.get_rank() else: rank = 0 return rank + def get_world_size(): - if torch.distributed.is_initialized(): + """ + Gets total number of distributed workers or returns one if distributed is + not initialized. + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() else: world_size = 1 @@ -33,14 +163,34 @@ def get_world_size(): @contextmanager def sync_workers(): - """ Gets distributed rank and synchronizes workers at exit""" + """ + Yields distributed rank and synchronizes all workers on exit. 
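+    Typically used as 'with sync_workers() as rank:' so that rank-0-only side
+    effects (e.g. checkpoint saving in the trainer) complete before the other
+    workers continue.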
+ """ rank = get_rank() yield rank barrier() -def setup_logging(log_file='log.log'): - """Setup logging configuration +@contextmanager +def timer(name, ndigits=2, sync_gpu=True): + if sync_gpu: + torch.cuda.synchronize() + start = time.time() + yield + if sync_gpu: + torch.cuda.synchronize() + stop = time.time() + elapsed = round(stop - start, ndigits) + logging.info(f'TIMER {name} {elapsed}') + + +def setup_logging(log_file=os.devnull): + """ + Configures logging. + By default logs from all workers are printed to the console, entries are + prefixed with "N: " where N is the rank of the worker. Logs printed to the + console don't include timestaps. + Full logs with timestamps are saved to the log_file file. """ class RankFilter(logging.Filter): def __init__(self, rank): @@ -53,12 +203,13 @@ def filter(self, record): rank = get_rank() rank_filter = RankFilter(rank) + logging_format = "%(asctime)s - %(levelname)s - %(rank)s - %(message)s" logging.basicConfig(level=logging.DEBUG, - format="%(asctime)s - %(levelname)s - %(rank)s - %(message)s", + format=logging_format, datefmt="%Y-%m-%d %H:%M:%S", filename=log_file, filemode='w') - console = logging.StreamHandler() + console = logging.StreamHandler(sys.stdout) console.setLevel(logging.INFO) formatter = logging.Formatter('%(rank)s: %(message)s') console.setFormatter(formatter) @@ -66,9 +217,59 @@ def filter(self, record): logging.getLogger('').addFilter(rank_filter) -class AverageMeter(object): - """Computes and stores the average and current value""" +def set_device(cuda, local_rank): + """ + Sets device based on local_rank and returns instance of torch.device. + + :param cuda: if True: use cuda + :param local_rank: local rank of the worker + """ + if cuda: + torch.cuda.set_device(local_rank) + device = torch.device('cuda') + else: + device = torch.device('cpu') + return device + + +def init_distributed(cuda): + """ + Initializes distributed backend. + + :param cuda: (bool) if True initializes nccl backend, if False initializes + gloo backend + """ + world_size = int(os.environ.get('WORLD_SIZE', 1)) + distributed = (world_size > 1) + if distributed: + backend = 'nccl' if cuda else 'gloo' + dist.init_process_group(backend=backend, + init_method='env://') + assert dist.is_initialized() + return distributed + + +def log_env_info(): + """ + Prints information about execution environment. + """ + logging.info('Collecting environment information...') + env_info = torch.utils.collect_env.get_pretty_env_info() + logging.info(f'{env_info}') + + +def pad_vocabulary(math): + if math == 'fp16': + pad_vocab = 8 + elif math == 'fp32': + pad_vocab = 1 + return pad_vocab + +class AverageMeter: + """ + Computes and stores the average and current value + """ def __init__(self, skip_first=True): self.reset() self.skip = skip_first @@ -89,28 +290,62 @@ def update(self, val, n=1): self.count += n self.avg = self.sum / self.count + def reduce(self, op): + """ + Reduces average value over all workers. 
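+        With op='sum' the per-worker values are combined with all_reduce; with
+        op='mean' the result is additionally divided by the world size. For
+        single-process runs this is effectively a no-op.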
-def batch_padded_sequences(seq, batch_first=False, sort=False): - if sort: - key = lambda item: len(item[1]) - indices, seq = zip(*sorted(enumerate(seq), key=key, reverse=True)) - else: - indices = range(len(seq)) + :param op: 'sum' or 'mean', reduction operator + """ + if op not in ('sum', 'mean'): + raise NotImplementedError + + distributed = (get_world_size() > 1) + if distributed: + # Backward/forward compatibility around + # https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36 and + # https://github.com/pytorch/pytorch/commit/044d00516ccd6572c0d6ab6d54587155b02a3b86 + # To accomodate change in Pytorch's distributed API + if hasattr(dist, "get_backend"): + _backend = dist.get_backend() + if hasattr(dist, "DistBackend"): + backend_enum_holder = dist.DistBackend + else: + backend_enum_holder = dist.Backend + else: + _backend = dist._backend + backend_enum_holder = dist.dist_backend + + cuda = _backend == backend_enum_holder.NCCL + + if cuda: + avg = torch.cuda.FloatTensor([self.avg]) + _sum = torch.cuda.FloatTensor([self.sum]) + else: + avg = torch.FloatTensor([self.avg]) + _sum = torch.FloatTensor([self.sum]) - lengths = [len(sentence) for sentence in seq] - batch_length = max(lengths) - seq_tensor = torch.LongTensor(batch_length, len(seq)).fill_(config.PAD) - for idx, sentence in enumerate(seq): - end_seq = lengths[idx] - seq_tensor[:end_seq, idx].copy_(sentence[:end_seq]) - if batch_first: - seq_tensor = seq_tensor.t() - return seq_tensor, lengths, indices + _reduce_op = dist.reduce_op if hasattr(dist, "reduce_op") else dist.ReduceOp + dist.all_reduce(avg, op=_reduce_op.SUM) + dist.all_reduce(_sum, op=_reduce_op.SUM) + self.avg = avg.item() + self.sum = _sum.item() + + if op == 'mean': + self.avg /= get_world_size() + self.sum /= get_world_size() def debug_tensor(tensor, name): + """ + Simple utility which helps with debugging. + Takes a tensor and outputs: min, max, avg, std, number of NaNs, number of + INFs. 
+ + :param tensor: torch tensor + :param name: name of the tensor (only for logging) + """ logging.info(name) - tensor = tensor.float().cpu().numpy() + tensor = tensor.detach().float().cpu().numpy() logging.info(f'MIN: {tensor.min()} MAX: {tensor.max()} ' f'AVG: {tensor.mean()} STD: {tensor.std()} ' f'NAN: {np.isnan(tensor).sum()} INF: {np.isinf(tensor).sum()}') diff --git a/rnn_translator/pytorch/train.py b/rnn_translator/pytorch/train.py index daf7b9fd4..a4906c9ee 100644 --- a/rnn_translator/pytorch/train.py +++ b/rnn_translator/pytorch/train.py @@ -1,224 +1,319 @@ #!/usr/bin/env python import argparse -import os import logging +import os +import sys from ast import literal_eval -import subprocess import torch.nn as nn import torch.nn.parallel -import torch.utils.data.distributed -import torch.distributed as dist import torch.optim - +import torch.utils.data.distributed from mlperf_compliance import mlperf_log -from seq2seq import models -from seq2seq.train.smoothing import LabelSmoothing -from seq2seq.data.dataset import ParallelDataset -from seq2seq.data.tokenizer import Tokenizer -from seq2seq.utils import setup_logging import seq2seq.data.config as config import seq2seq.train.trainer as trainers +import seq2seq.utils as utils +from seq2seq.data.dataset import LazyParallelDataset +from seq2seq.data.dataset import ParallelDataset +from seq2seq.data.dataset import TextDataset +from seq2seq.data.tokenizer import Tokenizer from seq2seq.inference.inference import Translator +from seq2seq.models.gnmt import GNMT +from seq2seq.train.smoothing import LabelSmoothing +from seq2seq.utils import gnmt_print def parse_args(): - parser = argparse.ArgumentParser(description='GNMT training', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) + """ + Parse commandline arguments. 
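+    Boolean features are exposed as paired '--foo' / '--no-foo' flags built by
+    the exclusive_group helper defined below.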
+ """ + def exclusive_group(group, name, default, help): + destname = name.replace('-', '_') + subgroup = group.add_mutually_exclusive_group(required=False) + subgroup.add_argument(f'--{name}', dest=f'{destname}', + action='store_true', + help=f'{help} (use \'--no-{name}\' to disable)') + subgroup.add_argument(f'--no-{name}', dest=f'{destname}', + action='store_false', help=argparse.SUPPRESS) + subgroup.set_defaults(**{destname: default}) + + parser = argparse.ArgumentParser( + description='GNMT training', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) # dataset dataset = parser.add_argument_group('dataset setup') - dataset.add_argument('--dataset-dir', default=None, required=True, - help='path to directory with training/validation data') + dataset.add_argument('--dataset-dir', default='data/wmt16_de_en', + help='path to the directory with training/test data') dataset.add_argument('--max-size', default=None, type=int, help='use at most MAX_SIZE elements from training \ - dataset (useful for benchmarking), by default \ - uses entire dataset') + dataset (useful for benchmarking), by default \ + uses entire dataset') # results results = parser.add_argument_group('results setup') - results.add_argument('--results-dir', default='../results', - help='path to directory with results, it it will be \ - automatically created if does not exist') - results.add_argument('--save', default='gnmt_wmt16', + results.add_argument('--results-dir', default='results', + help='path to directory with results, it will be \ + automatically created if it does not exist') + results.add_argument('--save', default='gnmt', help='defines subdirectory within RESULTS_DIR for \ - results from this training run') + results from this training run') results.add_argument('--print-freq', default=10, type=int, help='print log every PRINT_FREQ batches') # model model = parser.add_argument_group('model setup') - model.add_argument('--model-config', - default="{'hidden_size': 1024,'num_layers': 4, \ - 'dropout': 0.2, 'share_embedding': True}", - help='GNMT architecture configuration') + model.add_argument('--hidden-size', default=1024, type=int, + help='model hidden size') + model.add_argument('--num-layers', default=4, type=int, + help='number of RNN layers in encoder and in decoder') + model.add_argument('--dropout', default=0.2, type=float, + help='dropout applied to input of RNN cells') + + exclusive_group(group=model, name='share-embedding', default=True, + help='use shared embeddings for encoder and decoder') + model.add_argument('--smoothing', default=0.1, type=float, help='label smoothing, if equal to zero model will use \ - CrossEntropyLoss, if not zero model will be trained \ - with label smoothing loss based on KLDivLoss') + CrossEntropyLoss, if not zero model will be trained \ + with label smoothing loss') # setup general = parser.add_argument_group('general setup') - general.add_argument('--math', default='fp32', choices=['fp32', 'fp16'], + general.add_argument('--math', default='fp32', choices=['fp16', 'fp32'], help='arithmetic type') general.add_argument('--seed', default=None, type=int, - help='set random number generator seed') - general.add_argument('--disable-eval', action='store_true', default=False, - help='disables validation after every epoch') - general.add_argument('--workers', default=0, type=int, - help='number of workers for data loading') - - cuda_parser = general.add_mutually_exclusive_group(required=False) - cuda_parser.add_argument('--cuda', dest='cuda', action='store_true', - help='enables 
cuda (use \'--no-cuda\' to disable)') - cuda_parser.add_argument('--no-cuda', dest='cuda', action='store_false', - help=argparse.SUPPRESS) - cuda_parser.set_defaults(cuda=True) - - cudnn_parser = general.add_mutually_exclusive_group(required=False) - cudnn_parser.add_argument('--cudnn', dest='cudnn', action='store_true', - help='enables cudnn (use \'--no-cudnn\' to disable)') - cudnn_parser.add_argument('--no-cudnn', dest='cudnn', action='store_false', - help=argparse.SUPPRESS) - cudnn_parser.set_defaults(cudnn=True) + help='master seed for random number generators, if \ + "seed" is undefined then the master seed will be \ + sampled from random.SystemRandom()') + + exclusive_group(group=general, name='eval', default=True, + help='run validation and test after every epoch') + exclusive_group(group=general, name='env', default=False, + help='print info about execution env') + exclusive_group(group=general, name='cuda', default=True, + help='enables cuda') + exclusive_group(group=general, name='cudnn', default=True, + help='enables cudnn') # training training = parser.add_argument_group('training setup') - training.add_argument('--batch-size', default=128, type=int, - help='batch size for training') + training.add_argument('--train-batch-size', default=128, type=int, + help='training batch size per worker') + training.add_argument('--train-global-batch-size', default=None, type=int, + help='global training batch size, this argument \ + does not have to be defined, if it is defined it \ + will be used to automatically \ + compute train_iter_size \ + using the equation: train_iter_size = \ + train_global_batch_size // (train_batch_size * \ + world_size)') + training.add_argument('--train-iter-size', metavar='N', default=1, + type=int, + help='training iter size, training loop will \ + accumulate gradients over N iterations and execute \ + optimizer every N steps') training.add_argument('--epochs', default=8, type=int, - help='number of total epochs to run') - training.add_argument('--optimization-config', - default="{'optimizer': 'Adam', 'lr': 5e-4}", type=str, - help='optimizer config') + help='max number of training epochs') + training.add_argument('--grad-clip', default=5.0, type=float, - help='enabled gradient clipping and sets maximum \ - gradient norm value') + help='enables gradient clipping and sets maximum \ + norm of gradients') training.add_argument('--max-length-train', default=50, type=int, - help='maximum sequence length for training') + help='maximum sequence length for training \ + (including special BOS and EOS tokens)') training.add_argument('--min-length-train', default=0, type=int, - help='minimum sequence length for training') - training.add_argument('--target-bleu', default=None, type=float, - help='target accuracy') - - bucketing_parser = training.add_mutually_exclusive_group(required=False) - bucketing_parser.add_argument('--bucketing', dest='bucketing', action='store_true', - help='enables bucketing (use \'--no-bucketing\' to disable)') - bucketing_parser.add_argument('--no-bucketing', dest='bucketing', action='store_false', - help=argparse.SUPPRESS) - bucketing_parser.set_defaults(bucketing=True) + help='minimum sequence length for training \ + (including special BOS and EOS tokens)') + training.add_argument('--train-loader-workers', default=2, type=int, + help='number of workers for training data loading') + training.add_argument('--batching', default='bucketing', type=str, + choices=['random', 'sharding', 'bucketing'], + help='select batching algorithm') + 
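+    # options consumed only by the 'sharding' and 'bucketing' batching
+    # algorithms selected via --batching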
training.add_argument('--shard-size', default=80, type=int, + help='shard size for "sharding" batching algorithm, \ + in multiples of global batch size') + training.add_argument('--num-buckets', default=5, type=int, + help='number of buckets for "bucketing" batching \ + algorithm') + + # optimizer + optimizer = parser.add_argument_group('optimizer setup') + optimizer.add_argument('--optimizer', type=str, default='Adam', + help='training optimizer') + optimizer.add_argument('--lr', type=float, default=1.00e-3, + help='learning rate') + optimizer.add_argument('--optimizer-extra', type=str, + default="{}", + help='extra options for the optimizer') + + # scheduler + scheduler = parser.add_argument_group('learning rate scheduler setup') + scheduler.add_argument('--warmup-steps', type=str, default='200', + help='number of learning rate warmup iterations') + scheduler.add_argument('--remain-steps', type=str, default='0.666', + help='starting iteration for learning rate decay') + scheduler.add_argument('--decay-interval', type=str, default='None', + help='interval between learning rate decay steps') + scheduler.add_argument('--decay-steps', type=int, default=4, + help='max number of learning rate decay steps') + scheduler.add_argument('--decay-factor', type=float, default=0.5, + help='learning rate decay factor') # validation - validation = parser.add_argument_group('validation setup') - validation.add_argument('--eval-batch-size', default=32, type=int, - help='batch size for validation') - validation.add_argument('--max-length-val', default=150, type=int, - help='maximum sequence length for validation') - validation.add_argument('--min-length-val', default=0, type=int, - help='minimum sequence length for validation') - - validation.add_argument('--beam-size', default=5, type=int, - help='beam size') - validation.add_argument('--len-norm-factor', default=0.6, type=float, - help='length normalization factor') - validation.add_argument('--cov-penalty-factor', default=0.1, type=float, - help='coverage penalty factor') - validation.add_argument('--len-norm-const', default=5.0, type=float, - help='length normalization constant') - + val = parser.add_argument_group('validation setup') + val.add_argument('--val-batch-size', default=64, type=int, + help='batch size for validation') + val.add_argument('--max-length-val', default=125, type=int, + help='maximum sequence length for validation \ + (including special BOS and EOS tokens)') + val.add_argument('--min-length-val', default=0, type=int, + help='minimum sequence length for validation \ + (including special BOS and EOS tokens)') + val.add_argument('--val-loader-workers', default=0, type=int, + help='number of workers for validation data loading') + + # test + test = parser.add_argument_group('test setup') + test.add_argument('--test-batch-size', default=128, type=int, + help='batch size for test') + test.add_argument('--max-length-test', default=150, type=int, + help='maximum sequence length for test \ + (including special BOS and EOS tokens)') + test.add_argument('--min-length-test', default=0, type=int, + help='minimum sequence length for test \ + (including special BOS and EOS tokens)') + test.add_argument('--beam-size', default=5, type=int, + help='beam size') + test.add_argument('--len-norm-factor', default=0.6, type=float, + help='length normalization factor') + test.add_argument('--cov-penalty-factor', default=0.1, type=float, + help='coverage penalty factor') + test.add_argument('--len-norm-const', default=5.0, type=float, + help='length 
normalization constant') + test.add_argument('--intra-epoch-eval', metavar='N', default=0, type=int, + help='evaluate within training epoch, this option will \ + enable extra N equally spaced evaluations executed \ + during each training epoch') + test.add_argument('--test-loader-workers', default=0, type=int, + help='number of workers for test data loading') # checkpointing - checkpoint = parser.add_argument_group('checkpointing setup') - checkpoint.add_argument('--start-epoch', default=0, type=int, - help='manually set initial epoch counter') - checkpoint.add_argument('--resume', default=None, type=str, metavar='PATH', - help='resumes training from checkpoint from PATH') - checkpoint.add_argument('--save-all', action='store_true', default=False, - help='saves checkpoint after every epoch') - checkpoint.add_argument('--save-freq', default=5000, type=int, - help='save checkpoint every SAVE_FREQ batches') - checkpoint.add_argument('--keep-checkpoints', default=0, type=int, - help='keep only last KEEP_CHECKPOINTS checkpoints, \ - affects only checkpoints controlled by --save-freq \ - option') - - # distributed support + chkpt = parser.add_argument_group('checkpointing setup') + chkpt.add_argument('--start-epoch', default=0, type=int, + help='manually set initial epoch counter') + chkpt.add_argument('--resume', default=None, type=str, metavar='PATH', + help='resumes training from checkpoint from PATH') + chkpt.add_argument('--save-all', action='store_true', default=False, + help='saves checkpoint after every epoch') + chkpt.add_argument('--save-freq', default=5000, type=int, + help='save checkpoint every SAVE_FREQ batches') + chkpt.add_argument('--keep-checkpoints', default=0, type=int, + help='keep only last KEEP_CHECKPOINTS checkpoints, \ + affects only checkpoints controlled by --save-freq \ + option') + + # benchmarking + benchmark = parser.add_argument_group('benchmark setup') + benchmark.add_argument('--target-bleu', default=24.0, type=float, + help='target accuracy, training will be stopped \ + when the target is achieved') + + # distributed distributed = parser.add_argument_group('distributed setup') distributed.add_argument('--rank', default=0, type=int, - help='rank of the process, do not set! Done by multiproc module') - distributed.add_argument('--world-size', default=1, type=int, - help='number of processes, do not set! 
Done by multiproc module') - distributed.add_argument('--dist-url', default='tcp://localhost:23456', type=str, - help='url used to set up distributed training') + help='global rank of the process, do not set!') + distributed.add_argument('--local_rank', default=0, type=int, + help='local rank of the process, do not set!') + + args = parser.parse_args() + + args.warmup_steps = literal_eval(args.warmup_steps) + args.remain_steps = literal_eval(args.remain_steps) + args.decay_interval = literal_eval(args.decay_interval) - return parser.parse_args() + return args def build_criterion(vocab_size, padding_idx, smoothing): if smoothing == 0.: - logging.info(f'building CrossEntropyLoss') + logging.info(f'Building CrossEntropyLoss') loss_weight = torch.ones(vocab_size) loss_weight[padding_idx] = 0 criterion = nn.CrossEntropyLoss(weight=loss_weight, size_average=False) - mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_LOSS_FN, - value='Cross Entropy') + gnmt_print(key=mlperf_log.MODEL_HP_LOSS_FN, + value='Cross Entropy', sync=False) else: - logging.info(f'building SmoothingLoss (smoothing: {smoothing})') + logging.info(f'Building LabelSmoothingLoss (smoothing: {smoothing})') criterion = LabelSmoothing(padding_idx, smoothing) - mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_LOSS_FN, - value='Cross Entropy with label smoothing') - mlperf_log.gnmt_print(key=mlperf_log.MODEL_HP_LOSS_SMOOTHING, - value=smoothing) + gnmt_print(key=mlperf_log.MODEL_HP_LOSS_FN, + value='Cross Entropy with label smoothing', sync=False) + gnmt_print(key=mlperf_log.MODEL_HP_LOSS_SMOOTHING, + value=smoothing, sync=False) return criterion def main(): + """ + Launches data-parallel multi-gpu training. + """ mlperf_log.ROOT_DIR_GNMT = os.path.dirname(os.path.abspath(__file__)) mlperf_log.LOGGER.propagate = False - mlperf_log.gnmt_print(key=mlperf_log.RUN_START) args = parse_args() - print(args) + device = utils.set_device(args.cuda, args.local_rank) + distributed = utils.init_distributed(args.cuda) + gnmt_print(key=mlperf_log.RUN_START, sync=True) + args.rank = utils.get_rank() if not args.cudnn: torch.backends.cudnn.enabled = False - mlperf_log.gnmt_print(key=mlperf_log.RUN_SET_RANDOM_SEED) - if args.seed: - torch.manual_seed(args.seed + args.rank) - - # initialize distributed backend - distributed = args.world_size > 1 - if distributed: - backend = 'nccl' if args.cuda else 'gloo' - dist.init_process_group(backend=backend, rank=args.rank, - init_method=args.dist_url, - world_size=args.world_size) # create directory for results save_path = os.path.join(args.results_dir, args.save) + args.save_path = save_path os.makedirs(save_path, exist_ok=True) # setup logging - log_filename = f'log_gpu_{args.rank}.log' - setup_logging(os.path.join(save_path, log_filename)) + log_filename = f'log_rank_{utils.get_rank()}.log' + utils.setup_logging(os.path.join(save_path, log_filename)) + + if args.env: + utils.log_env_info() logging.info(f'Saving results to: {save_path}') logging.info(f'Run arguments: {args}') - if args.cuda: - torch.cuda.set_device(args.rank) + # automatically set train_iter_size based on train_global_batch_size, + # world_size and per-worker train_batch_size + if args.train_global_batch_size is not None: + global_bs = args.train_global_batch_size + bs = args.train_batch_size + world_size = utils.get_world_size() + assert global_bs % (bs * world_size) == 0 + args.train_iter_size = global_bs // (bs * world_size) + logging.info(f'Global batch size was set in the config, ' + f'Setting train_iter_size to {args.train_iter_size}') + + 
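+    # Worked example (hypothetical numbers): with --train-global-batch-size
+    # 2048, --train-batch-size 128 and 8 workers, train_iter_size becomes
+    # 2048 // (128 * 8) = 2, i.e. every worker accumulates gradients over two
+    # batches before each optimizer step.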
worker_seeds, shuffling_seeds = utils.setup_seeds(args.seed, args.epochs, + device) + worker_seed = worker_seeds[args.rank] + logging.info(f'Worker {args.rank} is using worker seed: {worker_seed}') + torch.manual_seed(worker_seed) # build tokenizer - tokenizer = Tokenizer(os.path.join(args.dataset_dir, config.VOCAB_FNAME)) + pad_vocab = utils.pad_vocabulary(args.math) + tokenizer = Tokenizer(os.path.join(args.dataset_dir, config.VOCAB_FNAME), + pad_vocab) # build datasets - mlperf_log.gnmt_print(key=mlperf_log.PREPROC_TOKENIZE_TRAINING) - mlperf_log.gnmt_print(key=mlperf_log.TRAIN_HP_MAX_SEQ_LEN, - value=args.max_length_train) + gnmt_print(key=mlperf_log.PREPROC_TOKENIZE_TRAINING, sync=False) + gnmt_print(key=mlperf_log.TRAIN_HP_MAX_SEQ_LEN, + value=args.max_length_train, sync=False) - train_data = ParallelDataset( + train_data = LazyParallelDataset( src_fname=os.path.join(args.dataset_dir, config.SRC_TRAIN_FNAME), tgt_fname=os.path.join(args.dataset_dir, config.TGT_TRAIN_FNAME), tokenizer=tokenizer, @@ -227,8 +322,8 @@ def main(): sort=False, max_size=args.max_size) - mlperf_log.gnmt_print(key=mlperf_log.PREPROC_NUM_TRAIN_EXAMPLES, - value=len(train_data)) + gnmt_print(key=mlperf_log.PREPROC_NUM_TRAIN_EXAMPLES, + value=len(train_data), sync=False) val_data = ParallelDataset( src_fname=os.path.join(args.dataset_dir, config.SRC_VAL_FNAME), @@ -238,65 +333,121 @@ def main(): max_len=args.max_length_val, sort=True) - mlperf_log.gnmt_print(key=mlperf_log.PREPROC_TOKENIZE_EVAL) + gnmt_print(key=mlperf_log.PREPROC_TOKENIZE_EVAL, sync=False) - test_data = ParallelDataset( + test_data = TextDataset( src_fname=os.path.join(args.dataset_dir, config.SRC_TEST_FNAME), - tgt_fname=os.path.join(args.dataset_dir, config.TGT_TEST_FNAME), tokenizer=tokenizer, - min_len=args.min_length_val, - max_len=args.max_length_val, - sort=False) + min_len=args.min_length_test, + max_len=args.max_length_test, + sort=True) - mlperf_log.gnmt_print(key=mlperf_log.PREPROC_NUM_EVAL_EXAMPLES, - value=len(test_data)) + gnmt_print(key=mlperf_log.PREPROC_NUM_EVAL_EXAMPLES, + value=len(test_data), sync=False) vocab_size = tokenizer.vocab_size - mlperf_log.gnmt_print(key=mlperf_log.PREPROC_VOCAB_SIZE, value=vocab_size) + gnmt_print(key=mlperf_log.PREPROC_VOCAB_SIZE, + value=vocab_size, sync=False) # build GNMT model - model_config = dict(vocab_size=vocab_size, math=args.math, - **literal_eval(args.model_config)) - model = models.GNMT(**model_config) + model_config = {'hidden_size': args.hidden_size, + 'num_layers': args.num_layers, + 'dropout': args.dropout, 'batch_first': False, + 'share_embedding': args.share_embedding} + model = GNMT(vocab_size=vocab_size, **model_config) logging.info(model) batch_first = model.batch_first # define loss function (criterion) and optimizer criterion = build_criterion(vocab_size, config.PAD, args.smoothing) - opt_config = literal_eval(args.optimization_config) - logging.info(f'Training optimizer: {opt_config}') + + opt_config = {'optimizer': args.optimizer, 'lr': args.lr} + opt_config.update(literal_eval(args.optimizer_extra)) + logging.info(f'Training optimizer config: {opt_config}') + + scheduler_config = {'warmup_steps': args.warmup_steps, + 'remain_steps': args.remain_steps, + 'decay_interval': args.decay_interval, + 'decay_steps': args.decay_steps, + 'decay_factor': args.decay_factor} + + logging.info(f'Training LR schedule config: {scheduler_config}') + + num_parameters = sum([l.nelement() for l in model.parameters()]) + logging.info(f'Number of parameters: {num_parameters}') + + 
batching_opt = {'shard_size': args.shard_size, + 'num_buckets': args.num_buckets} + # get data loaders + train_loader = train_data.get_loader(batch_size=args.train_batch_size, + seeds=shuffling_seeds, + batch_first=batch_first, + shuffle=True, + batching=args.batching, + batching_opt=batching_opt, + num_workers=args.train_loader_workers) + + gnmt_print(key=mlperf_log.INPUT_BATCH_SIZE, + value=args.train_batch_size * utils.get_world_size(), + sync=False) + gnmt_print(key=mlperf_log.INPUT_SIZE, + value=train_loader.sampler.num_samples, sync=False) + + val_loader = val_data.get_loader(batch_size=args.val_batch_size, + batch_first=batch_first, + shuffle=False, + num_workers=args.val_loader_workers) + + test_loader = test_data.get_loader(batch_size=args.test_batch_size, + batch_first=batch_first, + shuffle=False, + pad=True, + num_workers=args.test_loader_workers) + + gnmt_print(key=mlperf_log.EVAL_SIZE, + value=len(test_loader.dataset), sync=False) + + translator = Translator(model=model, + tokenizer=tokenizer, + loader=test_loader, + beam_size=args.beam_size, + max_seq_len=args.max_length_test, + len_norm_factor=args.len_norm_factor, + len_norm_const=args.len_norm_const, + cov_penalty_factor=args.cov_penalty_factor, + cuda=args.cuda, + print_freq=args.print_freq, + dataset_dir=args.dataset_dir, + target_bleu=args.target_bleu, + save_path=args.save_path) # create trainer + total_train_iters = len(train_loader) // args.train_iter_size * args.epochs + save_info = {'model_config': model_config, 'config': args, 'tokenizer': + tokenizer.get_state()} trainer_options = dict( criterion=criterion, grad_clip=args.grad_clip, + iter_size=args.train_iter_size, save_path=save_path, save_freq=args.save_freq, - save_info={'config': args, 'tokenizer': tokenizer}, + save_info=save_info, opt_config=opt_config, + scheduler_config=scheduler_config, + train_iterations=total_train_iters, batch_first=batch_first, keep_checkpoints=args.keep_checkpoints, math=args.math, print_freq=args.print_freq, cuda=args.cuda, - distributed=distributed) + distributed=distributed, + intra_epoch_eval=args.intra_epoch_eval, + translator=translator) trainer_options['model'] = model trainer = trainers.Seq2SeqTrainer(**trainer_options) - translator = Translator(model, - tokenizer, - beam_size=args.beam_size, - max_seq_len=args.max_length_val, - len_norm_factor=args.len_norm_factor, - len_norm_const=args.len_norm_const, - cov_penalty_factor=args.cov_penalty_factor, - cuda=args.cuda) - - num_parameters = sum([l.nelement() for l in model.parameters()]) - logging.info(f'Number of parameters: {num_parameters}') - # optionally resume from a checkpoint if args.resume: checkpoint_file = args.resume @@ -308,177 +459,69 @@ def main(): else: logging.error(f'No checkpoint found at {args.resume}') - # get data loaders - train_loader = train_data.get_loader(batch_size=args.batch_size, - batch_first=batch_first, - shuffle=True, - bucket=args.bucketing, - num_workers=args.workers, - drop_last=True, - distributed=distributed) - - mlperf_log.gnmt_print(key=mlperf_log.INPUT_BATCH_SIZE, - value=args.batch_size * args.world_size) - mlperf_log.gnmt_print(key=mlperf_log.INPUT_SIZE, - value=train_loader.sampler.num_samples) - - - val_loader = val_data.get_loader(batch_size=args.eval_batch_size, - batch_first=batch_first, - shuffle=False, - num_workers=args.workers, - drop_last=False, - distributed=False) - - test_loader = test_data.get_loader(batch_size=args.eval_batch_size, - batch_first=batch_first, - shuffle=False, - num_workers=0, - drop_last=False, - 
distributed=False) - - mlperf_log.gnmt_print(key=mlperf_log.EVAL_SIZE, - value=len(test_loader.sampler)) - # training loop best_loss = float('inf') - mlperf_log.gnmt_print(key=mlperf_log.TRAIN_LOOP) + break_training = False + test_bleu = None + gnmt_print(key=mlperf_log.TRAIN_LOOP, sync=True) for epoch in range(args.start_epoch, args.epochs): - mlperf_log.gnmt_print(key=mlperf_log.TRAIN_EPOCH, - value=epoch) logging.info(f'Starting epoch {epoch}') + gnmt_print(key=mlperf_log.TRAIN_EPOCH, + value=epoch, sync=True) - if distributed: - train_loader.sampler.set_epoch(epoch) + train_loader.sampler.set_epoch(epoch) trainer.epoch = epoch - train_loss = trainer.optimize(train_loader) + train_loss, train_perf = trainer.optimize(train_loader) # evaluate on validation set - if args.rank == 0 and not args.disable_eval: + if args.eval: logging.info(f'Running validation on dev set') - val_loss = trainer.evaluate(val_loader) + val_loss, val_perf = trainer.evaluate(val_loader) # remember best prec@1 and save checkpoint - is_best = val_loss < best_loss - best_loss = min(val_loss, best_loss) - - mlperf_log.gnmt_print(key=mlperf_log.TRAIN_CHECKPOINT) - trainer.save(save_all=args.save_all, is_best=is_best) - - logging.info(f'Epoch: {epoch}\t' - f'Training Loss {train_loss:.4f}\t' - f'Validation Loss {val_loss:.4f}') - else: - logging.info(f'Epoch: {epoch}\t' - f'Training Loss {train_loss:.4f}') - - if args.cuda: - break_training = torch.cuda.LongTensor([0]) - else: - break_training = torch.LongTensor([0]) - - if args.rank == 0 and not args.disable_eval: - logging.info(f'Running evaluation on test set') - mlperf_log.gnmt_print(key=mlperf_log.EVAL_START, value=epoch) - - model.eval() - torch.cuda.empty_cache() - - eval_path = os.path.join(save_path, f'eval_epoch_{epoch}') - eval_file = open(eval_path, 'w') - - for i, (src, tgt, indices) in enumerate(test_loader): - src, src_length = src - - if translator.batch_first: - batch_size = src.size(0) - else: - batch_size = src.size(1) - beam_size = args.beam_size - - bos = [translator.insert_target_start] * (batch_size * beam_size) - bos = torch.LongTensor(bos) - if translator.batch_first: - bos = bos.view(-1, 1) - else: - bos = bos.view(1, -1) - - src_length = torch.LongTensor(src_length) - - if args.cuda: - src = src.cuda() - src_length = src_length.cuda() - bos = bos.cuda() - - with torch.no_grad(): - context = translator.model.encode(src, src_length) - context = [context, src_length, None] - - if beam_size == 1: - generator = translator.generator.greedy_search - else: - generator = translator.generator.beam_search - preds, lengths, counter = generator(batch_size, bos, context) - - preds = preds.cpu() - lengths = lengths.cpu() - - output = [] - for idx, pred in enumerate(preds): - end = lengths[idx] - 1 - pred = pred[1: end] - pred = pred.tolist() - out = translator.tok.detokenize(pred) - output.append(out) - - output = [output[indices.index(i)] for i in range(len(output))] - for line in output: - eval_file.write(line) - eval_file.write('\n') - - eval_file.close() - - # run moses detokenizer - detok_path = os.path.join(args.dataset_dir, config.DETOKENIZER) - detok_eval_path = eval_path + '.detok' - - with open(detok_eval_path, 'w') as detok_eval_file, \ - open(eval_path, 'r') as eval_file: - subprocess.run(['perl', f'{detok_path}'], stdin=eval_file, - stdout=detok_eval_file, stderr=subprocess.DEVNULL) - - # run sacrebleu - reference_path = os.path.join(args.dataset_dir, config.TGT_TEST_TARGET_FNAME) - sacrebleu = subprocess.run([f'sacrebleu --input 
{detok_eval_path} \ - {reference_path} --score-only -lc --tokenize intl'], - stdout=subprocess.PIPE, shell=True) - bleu = float(sacrebleu.stdout.strip()) - logging.info(f'Finished evaluation on test set') - logging.info(f'BLEU on test dataset: {bleu}') - - if args.target_bleu: - if bleu >= args.target_bleu: - logging.info(f'Target accuracy reached') - break_training[0] = 1 - - torch.cuda.empty_cache() - mlperf_log.gnmt_print(key=mlperf_log.EVAL_ACCURACY, - value={"epoch": epoch, "value": bleu}) - mlperf_log.gnmt_print(key=mlperf_log.EVAL_TARGET, - value=args.target_bleu) - mlperf_log.gnmt_print(key=mlperf_log.EVAL_STOP) - - if distributed: - dist.broadcast(break_training, 0) + gnmt_print(key=mlperf_log.TRAIN_CHECKPOINT, sync=False) + if args.rank == 0: + is_best = val_loss < best_loss + best_loss = min(val_loss, best_loss) + trainer.save(save_all=args.save_all, is_best=is_best) + + if args.eval: + gnmt_print(key=mlperf_log.EVAL_START, value=epoch, sync=True) + test_bleu, break_training = translator.run(calc_bleu=True, + epoch=epoch) + gnmt_print(key=mlperf_log.EVAL_ACCURACY, + value={"epoch": epoch, "value": round(test_bleu, 2)}, + sync=False) + gnmt_print(key=mlperf_log.EVAL_TARGET, + value=args.target_bleu, sync=False) + gnmt_print(key=mlperf_log.EVAL_STOP, sync=True) + + acc_log = [] + acc_log += [f'Summary: Epoch: {epoch}'] + acc_log += [f'Training Loss: {train_loss:.4f}'] + if args.eval: + acc_log += [f'Validation Loss: {val_loss:.4f}'] + acc_log += [f'Test BLEU: {test_bleu:.2f}'] + + perf_log = [] + perf_log += [f'Performance: Epoch: {epoch}'] + perf_log += [f'Training: {train_perf:.0f} Tok/s'] + if args.eval: + perf_log += [f'Validation: {val_perf:.0f} Tok/s'] + + if args.rank == 0: + logging.info('\t'.join(acc_log)) + logging.info('\t'.join(perf_log)) logging.info(f'Finished epoch {epoch}') if break_training: break - mlperf_log.gnmt_print(key=mlperf_log.RUN_STOP, - value={"success": bool(break_training)}) - mlperf_log.gnmt_print(key=mlperf_log.RUN_FINAL) + gnmt_print(key=mlperf_log.RUN_STOP, + value={"success": bool(break_training)}, sync=True) + gnmt_print(key=mlperf_log.RUN_FINAL, sync=False) + if __name__ == '__main__': main() diff --git a/rnn_translator/pytorch/translate.py b/rnn_translator/pytorch/translate.py index 47bc374ea..ae4c384d3 100644 --- a/rnn_translator/pytorch/translate.py +++ b/rnn_translator/pytorch/translate.py @@ -1,37 +1,64 @@ #!/usr/bin/env python import argparse -import codecs -import time +import logging +import os import warnings from ast import literal_eval -from itertools import zip_longest +from itertools import product import torch +import torch.distributed as dist -from seq2seq import models +import seq2seq.utils as utils +from seq2seq.data.dataset import TextDataset +from seq2seq.data.tokenizer import Tokenizer from seq2seq.inference.inference import Translator -from seq2seq.utils import AverageMeter +from seq2seq.models.gnmt import GNMT +from seq2seq.utils import setup_logging def parse_args(): - parser = argparse.ArgumentParser(description='GNMT Translate', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - # data + """ + Parse commandline arguments. 
+ """ + def exclusive_group(group, name, default, help): + destname = name.replace('-', '_') + subgroup = group.add_mutually_exclusive_group(required=False) + subgroup.add_argument(f'--{name}', dest=f'{destname}', + action='store_true', + help=f'{help} (use \'--no-{name}\' to disable)') + subgroup.add_argument(f'--no-{name}', dest=f'{destname}', + action='store_false', help=argparse.SUPPRESS) + subgroup.set_defaults(**{destname: default}) + + parser = argparse.ArgumentParser( + description='GNMT Translate', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + # dataset dataset = parser.add_argument_group('data setup') + dataset.add_argument('--dataset-dir', default='data/wmt16_de_en/', + help='path to directory with training/test data') dataset.add_argument('-i', '--input', required=True, - help='input file (tokenized)') + help='full path to the input file (tokenized)') dataset.add_argument('-o', '--output', required=True, - help='output file (tokenized)') + help='full path to the output file (tokenized)') + dataset.add_argument('-r', '--reference', default=None, + help='full path to the file with reference \ + translations (for sacrebleu)') dataset.add_argument('-m', '--model', required=True, - help='model checkpoint file') + help='full path to the model checkpoint file') + exclusive_group(group=dataset, name='sort', default=True, + help='sorts dataset by sequence length') + # parameters params = parser.add_argument_group('inference setup') - params.add_argument('--batch-size', default=128, type=int, - help='batch size') - params.add_argument('--beam-size', default=5, type=int, + params.add_argument('--batch-size', nargs='+', default=[128], type=int, + help='batch size per GPU') + params.add_argument('--beam-size', nargs='+', default=[5], type=int, help='beam size') params.add_argument('--max-seq-len', default=80, type=int, - help='maximum prediciton sequence length') + help='maximum generated sequence length') params.add_argument('--len-norm-factor', default=0.6, type=float, help='length normalization factor') params.add_argument('--cov-penalty-factor', default=0.1, type=float, @@ -40,8 +67,18 @@ def parse_args(): help='length normalization constant') # general setup general = parser.add_argument_group('general setup') - general.add_argument('--math', default='fp16', choices=['fp32', 'fp16'], - help='arithmetic type') + general.add_argument('--math', nargs='+', default=['fp16'], + choices=['fp16', 'fp32'], help='arithmetic type') + + exclusive_group(group=general, name='env', default=True, + help='print info about execution env') + exclusive_group(group=general, name='bleu', default=True, + help='compares with reference translation and computes \ + BLEU') + exclusive_group(group=general, name='cuda', default=True, + help='enables cuda') + exclusive_group(group=general, name='cudnn', default=True, + help='enables cudnn') batch_first_parser = general.add_mutually_exclusive_group(required=False) batch_first_parser.add_argument('--batch-first', dest='batch_first', @@ -54,170 +91,103 @@ def parse_args(): format for RNNs') batch_first_parser.set_defaults(batch_first=True) - cuda_parser = general.add_mutually_exclusive_group(required=False) - cuda_parser.add_argument('--cuda', dest='cuda', action='store_true', - help='enables cuda (use \'--no-cuda\' to disable)') - cuda_parser.add_argument('--no-cuda', dest='cuda', action='store_false', - help=argparse.SUPPRESS) - cuda_parser.set_defaults(cuda=True) - - cudnn_parser = general.add_mutually_exclusive_group(required=False) - 
cudnn_parser.add_argument('--cudnn', dest='cudnn', action='store_true', - help='enables cudnn (use \'--no-cudnn\' to disable)') - cudnn_parser.add_argument('--no-cudnn', dest='cudnn', action='store_false', - help=argparse.SUPPRESS) - cudnn_parser.set_defaults(cudnn=True) - general.add_argument('--print-freq', '-p', default=1, type=int, help='print log every PRINT_FREQ batches') - return parser.parse_args() - - -def grouper(iterable, size, fillvalue=None): - args = [iter(iterable)] * size - return zip_longest(*args, fillvalue=fillvalue) - + # distributed + distributed = parser.add_argument_group('distributed setup') + distributed.add_argument('--rank', default=0, type=int, + help='global rank of the process, do not set!') + distributed.add_argument('--local_rank', default=0, type=int, + help='local rank of the process, do not set!') -def write_output(output_file, lines): - for line in lines: - output_file.write(line) - output_file.write('\n') + args = parser.parse_args() + if args.bleu and args.reference is None: + parser.error('--bleu requires --reference') -def checkpoint_from_distributed(state_dict): - ret = False - for key, _ in state_dict.items(): - if key.find('module.') != -1: - ret = True - break - return ret + if 'fp16' in args.math and not args.cuda: + parser.error('--math fp16 requires --cuda') - -def unwrap_distributed(state_dict): - new_state_dict = {} - for key, value in state_dict.items(): - new_key = key.replace('module.', '') - new_state_dict[new_key] = value - - return new_state_dict + return args def main(): + """ + Launches translation (inference). + Inference is executed on a single GPU, implementation supports beam search + with length normalization and coverage penalty. + """ args = parse_args() - print(args) + utils.set_device(args.cuda, args.local_rank) + utils.init_distributed(args.cuda) + setup_logging() + + if args.env: + utils.log_env_info() + + logging.info(f'Run arguments: {args}') - if args.cuda: - torch.cuda.set_device(0) if not args.cuda and torch.cuda.is_available(): warnings.warn('cuda is available but not enabled') - if args.math == 'fp16' and not args.cuda: - raise RuntimeError('fp16 requires cuda') if not args.cudnn: torch.backends.cudnn.enabled = False + # load checkpoint and deserialize to CPU (to save GPU memory) checkpoint = torch.load(args.model, map_location={'cuda:0': 'cpu'}) - vocab_size = checkpoint['tokenizer'].vocab_size - model_config = dict(vocab_size=vocab_size, math=checkpoint['config'].math, - **literal_eval(checkpoint['config'].model_config)) + # build GNMT model + tokenizer = Tokenizer() + tokenizer.set_state(checkpoint['tokenizer']) + vocab_size = tokenizer.vocab_size + model_config = checkpoint['model_config'] model_config['batch_first'] = args.batch_first - model = models.GNMT(**model_config) - - state_dict = checkpoint['state_dict'] - if checkpoint_from_distributed(state_dict): - state_dict = unwrap_distributed(state_dict) - - model.load_state_dict(state_dict) - - if args.math == 'fp32': - dtype = torch.FloatTensor - if args.math == 'fp16': - dtype = torch.HalfTensor - - model.type(dtype) - if args.cuda: - model = model.cuda() - model.eval() - - tokenizer = checkpoint['tokenizer'] - - translation_model = Translator(model, - tokenizer, - beam_size=args.beam_size, - max_seq_len=args.max_seq_len, - len_norm_factor=args.len_norm_factor, - len_norm_const=args.len_norm_const, - cov_penalty_factor=args.cov_penalty_factor, - cuda=args.cuda) - - output_file = codecs.open(args.output, 'w', encoding='UTF-8') - - # run model on generated 
data, for accurate timings starting from 1st batch - dummy_data = ['abc ' * (args.max_seq_len // 4)] * args.batch_size - translation_model.translate(dummy_data) - - if args.cuda: - torch.cuda.synchronize() - - batch_time = AverageMeter(False) - enc_tok_per_sec = AverageMeter(False) - dec_tok_per_sec = AverageMeter(False) - tot_tok_per_sec = AverageMeter(False) - - enc_seq_len = AverageMeter(False) - dec_seq_len = AverageMeter(False) - - total_lines = 0 - total_iters = 0 - with codecs.open(args.input, encoding='UTF-8') as input_file: - for idx, lines in enumerate(grouper(input_file, args.batch_size)): - lines = [l for l in lines if l] - n_lines = len(lines) - total_lines += n_lines - - translate_timer = time.time() - translated_lines, stats = translation_model.translate(lines) - elapsed = time.time() - translate_timer - - batch_time.update(elapsed, n_lines) - etps = stats['total_enc_len'] / elapsed - dtps = stats['total_dec_len'] / elapsed - enc_seq_len.update(stats['total_enc_len'] / n_lines, n_lines) - dec_seq_len.update(stats['total_dec_len'] / n_lines, n_lines) - enc_tok_per_sec.update(etps, n_lines) - dec_tok_per_sec.update(dtps, n_lines) - - tot_tok = stats['total_dec_len'] + stats['total_enc_len'] - ttps = tot_tok / elapsed - tot_tok_per_sec.update(ttps, n_lines) - - n_iterations = stats['iters'] - total_iters += n_iterations - - write_output(output_file, translated_lines) - - if idx % args.print_freq == args.print_freq - 1: - print(f'TRANSLATION: ' - f'Batch {idx} ' - f'Iters {n_iterations}\t' - f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - f'Tot tok/s {tot_tok_per_sec.val:.0f} ({tot_tok_per_sec.avg:.0f})\t' - f'Enc tok/s {enc_tok_per_sec.val:.0f} ({enc_tok_per_sec.avg:.0f})\t' - f'Dec tok/s {dec_tok_per_sec.val:.0f} ({dec_tok_per_sec.avg:.0f})') - - output_file.close() - - print(f'TRANSLATION SUMMARY:\n' - f'Lines translated: {total_lines}\t' - f'Avg time per batch: {batch_time.avg:.3f} s\t' - f'Avg time per sentence: {1000*(batch_time.avg / args.batch_size):.3f} ms\n' - f'Avg enc seq len: {enc_seq_len.avg:.2f}\t' - f'Avg dec seq len: {dec_seq_len.avg:.2f}\t' - f'Total iterations: {total_iters}\t\n' - f'Avg tot tok/s: {tot_tok_per_sec.avg:.0f}\t' - f'Avg enc tok/s: {enc_tok_per_sec.avg:.0f}\t' - f'Avg dec tok/s: {dec_tok_per_sec.avg:.0f}') + model = GNMT(vocab_size=vocab_size, **model_config) + model.load_state_dict(checkpoint['state_dict']) + + for (math, batch_size, beam_size) in product(args.math, args.batch_size, + args.beam_size): + logging.info(f'math: {math}, batch size: {batch_size}, ' + f'beam size: {beam_size}') + if math == 'fp32': + dtype = torch.FloatTensor + if math == 'fp16': + dtype = torch.HalfTensor + model.type(dtype) + + if args.cuda: + model = model.cuda() + model.eval() + + # construct the dataset + test_data = TextDataset(src_fname=args.input, + tokenizer=tokenizer, + sort=args.sort) + + # build the data loader + test_loader = test_data.get_loader(batch_size=batch_size, + batch_first=args.batch_first, + shuffle=False, + pad=True, + num_workers=0) + + # build the translator object + translator = Translator(model=model, + tokenizer=tokenizer, + loader=test_loader, + beam_size=beam_size, + max_seq_len=args.max_seq_len, + len_norm_factor=args.len_norm_factor, + len_norm_const=args.len_norm_const, + cov_penalty_factor=args.cov_penalty_factor, + cuda=args.cuda, + print_freq=args.print_freq, + dataset_dir=args.dataset_dir) + + # execute the inference + translator.run(calc_bleu=args.bleu, eval_path=args.output, + reference_path=args.reference, summary=True) 
+ if __name__ == '__main__': main()
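
The automatic derivation of train_iter_size from train_global_batch_size in train.py above is plain gradient-accumulation arithmetic: the requested global batch size must be an exact multiple of the per-worker batch size times the world size, and the quotient is the number of forward/backward passes accumulated before each optimizer step. A minimal standalone sketch of that check follows; the helper name derive_iter_size is illustrative only and is not part of the patch.

    def derive_iter_size(global_batch_size, per_worker_batch_size, world_size):
        """Gradient-accumulation steps needed to reach the global batch size.

        Illustrative helper, not part of the patch; mirrors the divisibility
        check performed in main() when --train-global-batch-size is set.
        """
        denominator = per_worker_batch_size * world_size
        if global_batch_size % denominator != 0:
            raise ValueError(
                f'global batch size {global_batch_size} is not divisible by '
                f'{per_worker_batch_size} (per-worker batch) * {world_size} (workers)')
        return global_batch_size // denominator


    if __name__ == '__main__':
        # e.g. 2048 sentences per update with 128 per GPU on 8 GPUs
        # -> accumulate gradients over 2 iterations per optimizer step
        print(derive_iter_size(2048, 128, 8))  # prints 2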
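
Both train.py and translate.py build their boolean options through a small exclusive_group helper that registers a paired --NAME / --no-NAME switch on a mutually exclusive argparse group. A self-contained sketch of the same pattern, assuming the illustrative function name add_toggle and demo option set, which are not taken from the patch:

    import argparse


    def add_toggle(parser, name, default, help_text):
        """Register --NAME and --no-NAME as a mutually exclusive on/off pair.

        Illustrative re-statement of the exclusive_group helper's pattern.
        """
        dest = name.replace('-', '_')
        group = parser.add_mutually_exclusive_group(required=False)
        group.add_argument(f'--{name}', dest=dest, action='store_true',
                           help=f"{help_text} (use '--no-{name}' to disable)")
        group.add_argument(f'--no-{name}', dest=dest, action='store_false',
                           help=argparse.SUPPRESS)
        group.set_defaults(**{dest: default})


    if __name__ == '__main__':
        parser = argparse.ArgumentParser(description='toggle flag demo')
        add_toggle(parser, 'cuda', default=True, help_text='enables cuda')
        add_toggle(parser, 'env', default=True,
                   help_text='print info about execution env')
        args = parser.parse_args()
        # 'python demo.py --no-cuda' -> Namespace(cuda=False, env=True)
        print(args)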