diff --git a/.gitignore b/.gitignore
index 98b0764cdc..a141737386 100644
--- a/.gitignore
+++ b/.gitignore
@@ -110,3 +110,8 @@ ENV/
 test/assets/sinewave.wav
 torchaudio/version.py
 gen.yml
+
+# Examples
+examples/interactive_asr/data/*.txt
+examples/interactive_asr/data/*.model
+examples/interactive_asr/data/*.pt
diff --git a/.travis.yml b/.travis.yml
index d202003b5f..e9649ed7f7 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -8,15 +8,19 @@ cache:
   directories:
     - /home/travis/download
 
-# This matrix tests that the code works on Python 3.5, 3.6, and passes lint.
+# This matrix tests that the code works on Python 2.7, 3.5, 3.6, 3.7, passes
+# lint and example tests.
 matrix:
   fast_finish: true
   include:
-    - env: PYTHON_VERSION="3.7"
-    - env: PYTHON_VERSION="3.6"
-    # TODO add this back in when there is a pytorch 1.2 for python 3.5
-    - env: PYTHON_VERSION="3.5" RUN_FLAKE8="true" SKIP_TESTS="true"
     - env: PYTHON_VERSION="2.7"
+    - env: PYTHON_VERSION="3.5"
+    - env: PYTHON_VERSION="3.6"
+    - env: PYTHON_VERSION="3.7"
+    - env: PYTHON_VERSION="3.5" RUN_FLAKE8="true" SKIP_INSTALL="true" SKIP_TESTS="true"
+    - env: PYTHON_VERSION="3.5" RUN_EXAMPLE_TESTS="true" SKIP_TESTS="true"
+  allow_failures:
+    - env: PYTHON_VERSION="3.5" RUN_EXAMPLE_TESTS="true" SKIP_TESTS="true"
 
 addons:
   apt:
@@ -24,6 +28,7 @@ addons:
     sox
     libsox-dev
     libsox-fmt-all
+    portaudio19-dev
 
 notifications:
   email: false
diff --git a/build_tools/travis/install.sh b/build_tools/travis/install.sh
index 13a739ca0c..61c564a5bf 100644
--- a/build_tools/travis/install.sh
+++ b/build_tools/travis/install.sh
@@ -51,7 +51,30 @@ source activate testenv
 pip install -r requirements.txt
 
 # Install the following only if running tests
-if [[ "$SKIP_TESTS" != "true" ]]; then
+if [[ "$SKIP_INSTALL" != "true" ]]; then
     # TorchAudio CPP Extensions
     python setup.py install
 fi
+
+if [[ "$RUN_EXAMPLE_TESTS" == "true" ]]; then
+    # Install dependencies
+    pip install sentencepiece PyAudio
+
+    if [[ ! -d $HOME/download/fairseq ]]; then
+        # Install fairseq from source
+        git clone https://github.com/pytorch/fairseq $HOME/download/fairseq
+    fi
+
+    pushd $HOME/download/fairseq
+    pip install --editable .
+    popd
+
+    mkdir -p $HOME/download/data
+    # Install dictionary, sentence piece model, and model
+    # These are cached so they are not downloaded if they already exist
+    wget -nc -O $HOME/download/data/dict.txt https://download.pytorch.org/models/audio/dict.txt || true
+    wget -nc -O $HOME/download/data/spm.model https://download.pytorch.org/models/audio/spm.model || true
+    wget -nc -O $HOME/download/data/model.pt https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt || true
+fi
+
+echo "Finished installation"
diff --git a/build_tools/travis/test_script.sh b/build_tools/travis/test_script.sh
index f1b45f63a9..06d70956b8 100755
--- a/build_tools/travis/test_script.sh
+++ b/build_tools/travis/test_script.sh
@@ -32,5 +32,17 @@ if [[ "$RUN_FLAKE8" == "true" ]]; then
 fi
 
 if [[ "$SKIP_TESTS" != "true" ]]; then
+    echo "run_tests"
     run_tests
 fi
+
+if [[ "$RUN_EXAMPLE_TESTS" == "true" ]]; then
+    echo "run_example_tests"
+    pushd examples
+    ASR_MODEL_PATH=$HOME/download/data/model.pt \
+    ASR_INPUT_FILE=interactive_asr/data/sample.wav \
+    ASR_DATA_PATH=$HOME/download/data \
+    ASR_USER_DIR=$HOME/download/fairseq/examples/speech_recognition \
+    python -m unittest test/test_interactive_asr.py
+    popd
+fi
diff --git a/examples/interactive_asr/README.md b/examples/interactive_asr/README.md
index 0bcef07e6d..39a5c53b75 100644
--- a/examples/interactive_asr/README.md
+++ b/examples/interactive_asr/README.md
@@ -16,6 +16,9 @@ and the following models
 We recommend that you use [conda](https://docs.conda.io/en/latest/miniconda.html) to install the dependencies when available.
 
 ```bash
+# Assume that all commands are from the examples folder
+cd examples
+
 # Install dependencies
 conda install -c pytorch torchaudio
 conda install -c conda-forge librosa
@@ -23,26 +26,38 @@ conda install pyaudio
 pip install sentencepiece
 
 # Install fairseq from source
-git clone https://github.com/pytorch/fairseq
-cd fairseq
+git clone https://github.com/pytorch/fairseq interactive_asr/fairseq
+pushd interactive_asr/fairseq
 export CFLAGS='-stdlib=libc++' # For Mac only
 pip install --editable .
-cd ..
+popd
 
 # Install dictionary, sentence piece model, and model
-wget -O ./data/dict.txt https://download.pytorch.org/models/audio/dict.txt
-wget -O ./data/spm.model https://download.pytorch.org/models/audio/spm.model
-wget -O ./data/model.pt https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt
+wget -O interactive_asr/data/dict.txt https://download.pytorch.org/models/audio/dict.txt
+wget -O interactive_asr/data/spm.model https://download.pytorch.org/models/audio/spm.model
+wget -O interactive_asr/data/model.pt https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt
 ```
 
 ## Run
 On a file
 ```bash
-INPUT_FILE=./data/sample.wav
-python asr.py ./data --input_file $INPUT_FILE --max-tokens 10000000 --nbest 1 --path ./data/model.pt --beam 40 --task speech_recognition --user-dir ./fairseq/examples/speech_recognition
+INPUT_FILE=interactive_asr/data/sample.wav
+python -m interactive_asr.asr interactive_asr/data --input_file $INPUT_FILE --max-tokens 10000000 --nbest 1 \
+  --path interactive_asr/data/model.pt --beam 40 --task speech_recognition \
+  --user-dir interactive_asr/fairseq/examples/speech_recognition
 ```
 
 As a microphone
 ```bash
-python asr.py ./data --max-tokens 10000000 --nbest 1 --path ./data/model.pt --beam 40 --task speech_recognition --user-dir ./fairseq/examples/speech_recognition
+python -m interactive_asr.asr interactive_asr/data --max-tokens 10000000 --nbest 1 \
+  --path interactive_asr/data/model.pt --beam 40 --task speech_recognition \
+  --user-dir interactive_asr/fairseq/examples/speech_recognition
+```
+To run the testcase associated with this example
+```bash
+ASR_MODEL_PATH=interactive_asr/data/model.pt \
+ASR_INPUT_FILE=interactive_asr/data/sample.wav \
+ASR_DATA_PATH=interactive_asr/data \
+ASR_USER_DIR=interactive_asr/fairseq/examples/speech_recognition \
+python -m unittest test/test_interactive_asr.py
 ```
diff --git a/examples/interactive_asr/__init__.py b/examples/interactive_asr/__init__.py
new file mode 100644
index 0000000000..505cb7fc4e
--- /dev/null
+++ b/examples/interactive_asr/__init__.py
@@ -0,0 +1 @@
+from . import utils, vad
diff --git a/examples/interactive_asr/asr.py b/examples/interactive_asr/asr.py
index 7a28a917cf..8e2510f347 100644
--- a/examples/interactive_asr/asr.py
+++ b/examples/interactive_asr/asr.py
@@ -11,187 +11,24 @@
 import datetime as dt
 import logging
-import os
-import sys
-import time
 
-import torch
+from fairseq import options
 
-import sentencepiece as spm
-import torchaudio
-from fairseq import options, tasks, utils
-from fairseq.meters import StopwatchMeter, TimeMeter
-from fairseq.utils import import_user_module
-from vad import get_microphone_chunks
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-def add_asr_eval_argument(parser):
-    parser.add_argument("--input_file", help="input file")
-    parser.add_argument("--ctc", action="store_true", help="decode a ctc model")
-    parser.add_argument("--rnnt", default=False, help="decode a rnnt model")
-    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
-    parser.add_argument(
-        "--wfstlm", default=None, help="wfstlm on dictonary output units"
-    )
-    parser.add_argument(
-        "--rnnt_decoding_type",
-        default="greedy",
-        help="wfstlm on dictonary output units",
-    )
-    parser.add_argument(
-        "--lm_weight",
-        default=0.2,
-        help="weight for wfstlm while interpolating with neural score",
-    )
-    parser.add_argument(
-        "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
-    )
-    return parser
-
-
-def check_args(args):
-    assert args.path is not None, "--path required for generation!"
-    assert (
-        not args.sampling or args.nbest == args.beam
-    ), "--sampling requires --nbest to be equal to --beam"
-    assert (
-        args.replace_unk is None or args.raw_text
-    ), "--replace-unk requires a raw text dataset (--raw-text)"
-
-
-def process_predictions(args, hypos, sp, tgt_dict):
-    res = []
-    for hypo in hypos[: min(len(hypos), args.nbest)]:
-        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
-        hyp_words = sp.DecodePieces(hyp_pieces.split())
-        res.append(hyp_words)
-    return res
-
-
-def optimize_models(args, use_cuda, models):
-    """Optimize ensemble for generation
-    """
-    for model in models:
-        model.make_generation_fast_(
-            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
-            need_attn=args.print_alignment,
-        )
-        if args.fp16:
-            model.half()
-        if use_cuda:
-            model.cuda()
-
-
-def calc_mean_invstddev(feature):
-    if len(feature.shape) != 2:
-        raise ValueError("We expect the input feature to be 2-D tensor")
-    mean = torch.mean(feature, dim=0)
-    var = torch.var(feature, dim=0)
-    # avoid division by ~zero
-    if (var < sys.float_info.epsilon).any():
-        return mean, 1.0 / (torch.sqrt(var) + sys.float_info.epsilon)
-    return mean, 1.0 / torch.sqrt(var)
-
-
-def calcMN(features):
-    mean, invstddev = calc_mean_invstddev(features)
-    res = (features - mean) * invstddev
-    return res
-
-
-def transcribe(waveform, args, task, generator, models, sp, tgt_dict):
-    num_features = 80
-    output = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=num_features)
-    output_cmvn = calcMN(output.cpu().detach())
-
-    # size (m, n)
-    source = torch.tensor(output_cmvn)
-    frames_lengths = torch.LongTensor([source.size(0)])
-
-    # size (1, m, n). In general, if source is (x, m, n), then hypos is (x, ...)
-    source.unsqueeze_(0)
-    sample = {"net_input": {"src_tokens": source, "src_lengths": frames_lengths}}
-
-    hypos = task.inference_step(generator, models, sample)
-
-    assert len(hypos) == 1
-    transcription = []
-    for i in range(len(hypos)):
-        # Process top predictions
-        hyp_words = process_predictions(args, hypos[i], sp, tgt_dict)
-        transcription.append(hyp_words)
-
-    return transcription
+from interactive_asr.utils import add_asr_eval_argument, setup_asr, get_microphone_transcription, transcribe_file
 
 
 def main(args):
-    check_args(args)
-    import_user_module(args)
-
-    if args.max_tokens is None and args.max_sentences is None:
-        args.max_tokens = 30000
-    logger.info(args)
-
-    use_cuda = torch.cuda.is_available() and not args.cpu
-
-    # Load dataset splits
-    task = tasks.setup_task(args)
-
-    # Set dictionary
-    tgt_dict = task.target_dictionary
-
-    if args.ctc or args.rnnt:
-        tgt_dict.add_symbol("")
-        if args.ctc:
-            logger.info("| decoding a ctc model")
-        if args.rnnt:
-            logger.info("| decoding a rnnt model")
-
-    # Load ensemble
-    logger.info("| loading model(s) from {}".format(args.path))
-    models, _model_args = utils.load_ensemble_for_inference(
-        args.path.split(":"),
-        task,
-        model_arg_overrides=eval(args.model_overrides),  # noqa
-    )
-    optimize_models(args, use_cuda, models)
-
-    # Initialize generator
-    generator = task.build_generator(args)
-
-    sp = spm.SentencePieceProcessor()
-    sp.Load(os.path.join(args.data, "spm.model"))
+    logger = logging.getLogger(__name__)
+    logger.setLevel(logging.INFO)
+    task, generator, models, sp, tgt_dict = setup_asr(args, logger)
+    print("READY!")
 
     if args.input_file:
-        path = args.input_file
-        if not os.path.exists(path):
-            raise FileNotFoundError("Audio file not found: {}".format(path))
-        waveform, sample_rate = torchaudio.load_wav(path)
-        waveform = waveform.mean(0, True)
-        waveform = torchaudio.transforms.Resample(
-            orig_freq=sample_rate, new_freq=16000
-        )(waveform)
-
-        print(sample_rate, waveform.shape)
-        start = time.time()
-        transcription = transcribe(
-            waveform, args, task, generator, models, sp, tgt_dict
-        )
-        end = time.time()
+        transcription_time, transcription = transcribe_file(args, task, generator, models, sp, tgt_dict)
         print("transcription:", transcription)
-        print(end - start)
+        print("transcription_time:", transcription_time)
     else:
-        print("READY!")
-        for (waveform, sample_rate) in get_microphone_chunks():
-            waveform = torchaudio.transforms.Resample(
-                orig_freq=sample_rate, new_freq=16000
-            )(waveform.reshape(1, -1))
-            transcription = transcribe(
-                waveform, args, task, generator, models, sp, tgt_dict
-            )
+        for transcription in get_microphone_transcription(args, task, generator, models, sp, tgt_dict):
             print(
                 "{}: {}".format(
                     dt.datetime.now().strftime("%H:%M:%S"), transcription[0][0]
diff --git a/examples/interactive_asr/utils.py b/examples/interactive_asr/utils.py
new file mode 100644
index 0000000000..713954029a
--- /dev/null
+++ b/examples/interactive_asr/utils.py
@@ -0,0 +1,187 @@
+#!/usr/bin/env python3
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+import os
+import sys
+import time
+
+import torch
+import torchaudio
+import sentencepiece as spm
+
+from fairseq import tasks
+from fairseq.utils import load_ensemble_for_inference, import_user_module
+
+from interactive_asr.vad import get_microphone_chunks
+
+
+def add_asr_eval_argument(parser):
+    parser.add_argument("--input_file", help="input file")
+    parser.add_argument("--ctc", action="store_true", help="decode a ctc model")
+    parser.add_argument("--rnnt", default=False, help="decode a rnnt model")
+    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
+    parser.add_argument(
+        "--wfstlm", default=None, help="wfstlm on dictonary output units"
+    )
+    parser.add_argument(
+        "--rnnt_decoding_type",
+        default="greedy",
+        help="wfstlm on dictonary output units",
+    )
+    parser.add_argument(
+        "--lm_weight",
+        default=0.2,
+        help="weight for wfstlm while interpolating with neural score",
+    )
+    parser.add_argument(
+        "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
+    )
+    return parser
+
+
+def check_args(args):
+    assert args.path is not None, "--path required for generation!"
+    assert (
+        not args.sampling or args.nbest == args.beam
+    ), "--sampling requires --nbest to be equal to --beam"
+    assert (
+        args.replace_unk is None or args.raw_text
+    ), "--replace-unk requires a raw text dataset (--raw-text)"
+
+
+def process_predictions(args, hypos, sp, tgt_dict):
+    res = []
+    for hypo in hypos[: min(len(hypos), args.nbest)]:
+        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
+        hyp_words = sp.DecodePieces(hyp_pieces.split())
+        res.append(hyp_words)
+    return res
+
+
+def optimize_models(args, use_cuda, models):
+    """Optimize ensemble for generation
+    """
+    for model in models:
+        model.make_generation_fast_(
+            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
+            need_attn=args.print_alignment,
+        )
+        if args.fp16:
+            model.half()
+        if use_cuda:
+            model.cuda()
+
+
+def calc_mean_invstddev(feature):
+    if len(feature.shape) != 2:
+        raise ValueError("We expect the input feature to be 2-D tensor")
+    mean = torch.mean(feature, dim=0)
+    var = torch.var(feature, dim=0)
+    # avoid division by ~zero
+    if (var < sys.float_info.epsilon).any():
+        return mean, 1.0 / (torch.sqrt(var) + sys.float_info.epsilon)
+    return mean, 1.0 / torch.sqrt(var)
+
+
+def calcMN(features):
+    mean, invstddev = calc_mean_invstddev(features)
+    res = (features - mean) * invstddev
+    return res
+
+
+def transcribe(waveform, args, task, generator, models, sp, tgt_dict):
+    num_features = 80
+    output = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=num_features)
+    output_cmvn = calcMN(output.cpu().detach())
+
+    # size (m, n)
+    source = output_cmvn
+    frames_lengths = torch.LongTensor([source.size(0)])
+
+    # size (1, m, n). In general, if source is (x, m, n), then hypos is (x, ...)
+    source.unsqueeze_(0)
+    sample = {"net_input": {"src_tokens": source, "src_lengths": frames_lengths}}
+
+    hypos = task.inference_step(generator, models, sample)
+
+    assert len(hypos) == 1
+    transcription = []
+    for i in range(len(hypos)):
+        # Process top predictions
+        hyp_words = process_predictions(args, hypos[i], sp, tgt_dict)
+        transcription.append(hyp_words)
+
+    return transcription
+
+
+def setup_asr(args, logger):
+    check_args(args)
+    import_user_module(args)
+
+    if args.max_tokens is None and args.max_sentences is None:
+        args.max_tokens = 30000
+    logger.info(args)
+
+    use_cuda = torch.cuda.is_available() and not args.cpu
+
+    # Load dataset splits
+    task = tasks.setup_task(args)
+
+    # Set dictionary
+    tgt_dict = task.target_dictionary
+
+    if args.ctc or args.rnnt:
+        tgt_dict.add_symbol("")
+        if args.ctc:
+            logger.info("| decoding a ctc model")
+        if args.rnnt:
+            logger.info("| decoding a rnnt model")
+
+    # Load ensemble
+    logger.info("| loading model(s) from {}".format(args.path))
+    models, _model_args = load_ensemble_for_inference(
+        args.path.split(":"),
+        task,
+        model_arg_overrides=eval(args.model_overrides),  # noqa
+    )
+    optimize_models(args, use_cuda, models)
+
+    # Initialize generator
+    generator = task.build_generator(args)
+
+    sp = spm.SentencePieceProcessor()
+    sp.Load(os.path.join(args.data, "spm.model"))
+    return task, generator, models, sp, tgt_dict
+
+
+def transcribe_file(args, task, generator, models, sp, tgt_dict):
+    path = args.input_file
+    if not os.path.exists(path):
+        raise FileNotFoundError("Audio file not found: {}".format(path))
+    waveform, sample_rate = torchaudio.load_wav(path)
+    waveform = waveform.mean(0, True)
+    waveform = torchaudio.transforms.Resample(
+        orig_freq=sample_rate, new_freq=16000
+    )(waveform)
+
+    start = time.time()
+    transcription = transcribe(
+        waveform, args, task, generator, models, sp, tgt_dict
+    )
+    transcription_time = time.time() - start
+    return transcription_time, transcription
+
+
+def get_microphone_transcription(args, task, generator, models, sp, tgt_dict):
+    for (waveform, sample_rate) in get_microphone_chunks():
+        waveform = torchaudio.transforms.Resample(
+            orig_freq=sample_rate, new_freq=16000
+        )(waveform.reshape(1, -1))
+        transcription = transcribe(
+            waveform, args, task, generator, models, sp, tgt_dict
+        )
+        yield transcription
diff --git a/examples/test/__init__.py b/examples/test/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/test/test_interactive_asr.py b/examples/test/test_interactive_asr.py
new file mode 100644
index 0000000000..aa48d0a89f
--- /dev/null
+++ b/examples/test/test_interactive_asr.py
@@ -0,0 +1,56 @@
+import argparse
+import logging
+import os
+import unittest
+
+from interactive_asr.utils import setup_asr, transcribe_file
+
+
+class ASRTest(unittest.TestCase):
+    logger = logging.getLogger(__name__)
+    logger.setLevel(logging.INFO)
+
+    arguments_dict = {
+        'path': '/scratch/jamarshon/downloads/model.pt',
+        'input_file': '/scratch/jamarshon/audio/examples/interactive_asr/data/sample.wav',
+        'data': '/scratch/jamarshon/downloads',
+        'user_dir': '/scratch/jamarshon/fairseq-py/examples/speech_recognition',
+        'no_progress_bar': False, 'log_interval': 1000, 'log_format': None,
+        'tensorboard_logdir': '', 'tbmf_wrapper': False, 'seed': 1, 'cpu': True,
+        'fp16': False, 'memory_efficient_fp16': False, 'fp16_init_scale': 128,
+        'fp16_scale_window': None, 'fp16_scale_tolerance': 0.0,
+        'min_loss_scale': 0.0001, 'threshold_loss_scale': None,
+        'criterion': 'cross_entropy', 'tokenizer': None, 'bpe': None, 'optimizer':
+        'nag', 'lr_scheduler': 'fixed', 'task': 'speech_recognition', 'num_workers': 0,
+        'skip_invalid_size_inputs_valid_test': False, 'max_tokens': 10000000,
+        'max_sentences': None, 'required_batch_size_multiple': 8, 'dataset_impl': None,
+        'gen_subset': 'test', 'num_shards': 1, 'shard_id': 0,
+        'remove_bpe': None, 'quiet': False, 'model_overrides': '{}',
+        'results_path': None, 'beam': 40, 'nbest': 1, 'max_len_a': 0,
+        'max_len_b': 200, 'min_len': 1, 'match_source_len': False,
+        'no_early_stop': False, 'unnormalized': False, 'no_beamable_mm': False,
+        'lenpen': 1, 'unkpen': 0, 'replace_unk': None, 'sacrebleu': False,
+        'score_reference': False, 'prefix_size': 0, 'no_repeat_ngram_size': 0,
+        'sampling': False, 'sampling_topk': -1, 'sampling_topp': -1.0,
+        'temperature': 1.0, 'diverse_beam_groups': -1, 'diverse_beam_strength': 0.5,
+        'print_alignment': False, 'ctc': False,
+        'rnnt': False, 'kspmodel': None, 'wfstlm': None, 'rnnt_decoding_type': 'greedy',
+        'lm_weight': 0.2, 'rnnt_len_penalty': -0.5, 'momentum': 0.99, 'weight_decay': 0.0,
+        'force_anneal': None, 'lr_shrink': 0.1, 'warmup_updates': 0}
+
+    arguments_dict['path'] = os.environ.get('ASR_MODEL_PATH', None)
+    arguments_dict['input_file'] = os.environ.get('ASR_INPUT_FILE', None)
+    arguments_dict['data'] = os.environ.get('ASR_DATA_PATH', None)
+    arguments_dict['user_dir'] = os.environ.get('ASR_USER_DIR', None)
+    args = argparse.Namespace(**arguments_dict)
+
+    def test_transcribe_file(self):
+        task, generator, models, sp, tgt_dict = setup_asr(self.args, self.logger)
+        _, transcription = transcribe_file(self.args, task, generator, models, sp, tgt_dict)
+
+        expected_transcription = [['THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG']]
+        self.assertEqual(transcription, expected_transcription, msg=str(transcription))
+
+
+if __name__ == "__main__":
+    unittest.main()