Commit

Increasing test coverage (ASR demo) (pytorch#248)

jamarshon authored and cpuhrsch committed Aug 21, 2019
1 parent 42a705d commit ed17513
Showing 10 changed files with 328 additions and 187 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -110,3 +110,8 @@ ENV/
 test/assets/sinewave.wav
 torchaudio/version.py
 gen.yml
+
+# Examples
+examples/interactive_asr/data/*.txt
+examples/interactive_asr/data/*.model
+examples/interactive_asr/data/*.pt
15 changes: 10 additions & 5 deletions .travis.yml
@@ -8,22 +8,27 @@ cache:
   directories:
   - /home/travis/download
 
-# This matrix tests that the code works on Python 3.5, 3.6, and passes lint.
+# This matrix tests that the code works on Python 2.7, 3.5, 3.6, 3.7, passes
+# lint and example tests.
 matrix:
   fast_finish: true
   include:
-    - env: PYTHON_VERSION="3.7"
-    - env: PYTHON_VERSION="3.6"
-      # TODO add this back in when there is a pytorch 1.2 for python 3.5
-    - env: PYTHON_VERSION="3.5" RUN_FLAKE8="true" SKIP_TESTS="true"
+    - env: PYTHON_VERSION="2.7"
+    - env: PYTHON_VERSION="3.5"
+    - env: PYTHON_VERSION="3.6"
+    - env: PYTHON_VERSION="3.7"
+    - env: PYTHON_VERSION="3.5" RUN_FLAKE8="true" SKIP_INSTALL="true" SKIP_TESTS="true"
+    - env: PYTHON_VERSION="3.5" RUN_EXAMPLE_TESTS="true" SKIP_TESTS="true"
+  allow_failures:
+    - env: PYTHON_VERSION="3.5" RUN_EXAMPLE_TESTS="true" SKIP_TESTS="true"
 
 addons:
   apt:
     packages:
     sox
     libsox-dev
     libsox-fmt-all
     portaudio19-dev
 
 notifications:
   email: false
25 changes: 24 additions & 1 deletion build_tools/travis/install.sh
@@ -51,7 +51,30 @@ source activate testenv
 pip install -r requirements.txt
 
 # Install the following only if running tests
-if [[ "$SKIP_TESTS" != "true" ]]; then
+if [[ "$SKIP_INSTALL" != "true" ]]; then
+    # TorchAudio CPP Extensions
     python setup.py install
 fi
+
+if [[ "$RUN_EXAMPLE_TESTS" == "true" ]]; then
+    # Install dependencies
+    pip install sentencepiece PyAudio
+
+    if [[ ! -d $HOME/download/fairseq ]]; then
+        # Install fairseq from source
+        git clone https://github.com/pytorch/fairseq $HOME/download/fairseq
+    fi
+
+    pushd $HOME/download/fairseq
+    pip install --editable .
+    popd
+
+    mkdir -p $HOME/download/data
+    # Install dictionary, sentence piece model, and model
+    # These are cached so they are not downloaded if they already exist
+    wget -nc -O $HOME/download/data/dict.txt https://download.pytorch.org/models/audio/dict.txt || true
+    wget -nc -O $HOME/download/data/spm.model https://download.pytorch.org/models/audio/spm.model || true
+    wget -nc -O $HOME/download/data/model.pt https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt || true
+fi
 
 echo "Finished installation"
12 changes: 12 additions & 0 deletions build_tools/travis/test_script.sh
@@ -32,5 +32,17 @@ if [[ "$RUN_FLAKE8" == "true" ]]; then
 fi
 
 if [[ "$SKIP_TESTS" != "true" ]]; then
+    echo "run_tests"
     run_tests
 fi
+
+if [[ "$RUN_EXAMPLE_TESTS" == "true" ]]; then
+    echo "run_example_tests"
+    pushd examples
+    ASR_MODEL_PATH=$HOME/download/data/model.pt \
+    ASR_INPUT_FILE=interactive_asr/data/sample.wav \
+    ASR_DATA_PATH=$HOME/download/data \
+    ASR_USER_DIR=$HOME/download/fairseq/examples/speech_recognition \
+    python -m unittest test/test_interactive_asr.py
+    popd
+fi
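Note: the new test/test_interactive_asr.py itself is not visible in this view. A minimal sketch of the environment-variable pattern the CI invocation above relies on, assuming hypothetical class and test names (this is not the actual test file):

```python
# Illustrative sketch only: the real test/test_interactive_asr.py is not
# shown in this diff. CI exports ASR_* variables pointing at the cached
# model, data directory, and sample audio, and the unittest reads them
# instead of hard-coding paths.
import os
import unittest


class InteractiveASRSmokeTest(unittest.TestCase):
    def test_paths_from_environment(self):
        # Provided by build_tools/travis/test_script.sh (see above).
        required = [
            "ASR_MODEL_PATH",   # e.g. $HOME/download/data/model.pt
            "ASR_INPUT_FILE",   # e.g. interactive_asr/data/sample.wav
            "ASR_DATA_PATH",    # directory holding dict.txt and spm.model
            "ASR_USER_DIR",     # fairseq speech_recognition example dir
        ]
        for name in required:
            self.assertIn(name, os.environ)
            self.assertTrue(os.path.exists(os.environ[name]))


if __name__ == "__main__":
    unittest.main()
```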
33 changes: 24 additions & 9 deletions examples/interactive_asr/README.md
@@ -16,33 +16,48 @@ and the following models
 
 We recommend that you use [conda](https://docs.conda.io/en/latest/miniconda.html) to install the dependencies when available.
 ```bash
+# Assume that all commands are from the examples folder
+cd examples
 
 # Install dependencies
 conda install -c pytorch torchaudio
 conda install -c conda-forge librosa
 conda install pyaudio
 pip install sentencepiece
 
 # Install fairseq from source
-git clone https://github.com/pytorch/fairseq
-cd fairseq
+git clone https://github.com/pytorch/fairseq interactive_asr/fairseq
+pushd interactive_asr/fairseq
 export CFLAGS='-stdlib=libc++' # For Mac only
 pip install --editable .
-cd ..
+popd
 
 # Install dictionary, sentence piece model, and model
-wget -O ./data/dict.txt https://download.pytorch.org/models/audio/dict.txt
-wget -O ./data/spm.model https://download.pytorch.org/models/audio/spm.model
-wget -O ./data/model.pt https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt
+wget -O interactive_asr/data/dict.txt https://download.pytorch.org/models/audio/dict.txt
+wget -O interactive_asr/data/spm.model https://download.pytorch.org/models/audio/spm.model
+wget -O interactive_asr/data/model.pt https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt
 ```
 
 ## Run
 On a file
 ```bash
-INPUT_FILE=./data/sample.wav
-python asr.py ./data --input_file $INPUT_FILE --max-tokens 10000000 --nbest 1 --path ./data/model.pt --beam 40 --task speech_recognition --user-dir ./fairseq/examples/speech_recognition
+INPUT_FILE=interactive_asr/data/sample.wav
+python -m interactive_asr.asr interactive_asr/data --input_file $INPUT_FILE --max-tokens 10000000 --nbest 1 \
+  --path interactive_asr/data/model.pt --beam 40 --task speech_recognition \
+  --user-dir interactive_asr/fairseq/examples/speech_recognition
 ```
 
 As a microphone
 ```bash
-python asr.py ./data --max-tokens 10000000 --nbest 1 --path ./data/model.pt --beam 40 --task speech_recognition --user-dir ./fairseq/examples/speech_recognition
+python -m interactive_asr.asr interactive_asr/data --max-tokens 10000000 --nbest 1 \
+  --path interactive_asr/data/model.pt --beam 40 --task speech_recognition \
+  --user-dir interactive_asr/fairseq/examples/speech_recognition
 ```
+To run the testcase associated with this example
+```bash
+ASR_MODEL_PATH=interactive_asr/data/model.pt \
+ASR_INPUT_FILE=interactive_asr/data/sample.wav \
+ASR_DATA_PATH=interactive_asr/data \
+ASR_USER_DIR=interactive_asr/fairseq/examples/speech_recognition \
+python -m unittest test/test_interactive_asr.py
+```
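Note: after this refactor the example can also be driven as a library. A minimal sketch using the helpers the commit factors out of asr.py (setup_asr, transcribe_file, add_asr_eval_argument; see the asr.py diff below). The use of fairseq's generation parser here is an assumption, since the commit's actual argument wiring is not shown in this view:

```python
# Sketch, not the shipped entry point: build args the way the README
# commands do, then transcribe one file with the refactored helpers.
import logging

from fairseq import options  # assumes fairseq is installed as above
from interactive_asr.utils import add_asr_eval_argument, setup_asr, transcribe_file

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Assumption: fairseq's standard generation parser accepts the flags the
# README commands pass (--path, --beam, --task, --user-dir, ...).
parser = options.get_generation_parser(interactive=True)
parser = add_asr_eval_argument(parser)
args = options.parse_args_and_arch(parser)

task, generator, models, sp, tgt_dict = setup_asr(args, logger)

# transcribe_file returns (elapsed_seconds, transcription) per the new main()
transcription_time, transcription = transcribe_file(
    args, task, generator, models, sp, tgt_dict
)
print("transcription:", transcription)
print("transcription_time:", transcription_time)
```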
1 change: 1 addition & 0 deletions examples/interactive_asr/__init__.py
@@ -0,0 +1 @@
+from . import utils, vad
181 changes: 9 additions & 172 deletions examples/interactive_asr/asr.py
@@ -11,187 +11,24 @@
 
 import datetime as dt
 import logging
-import os
-import sys
-import time
 
-import torch
+from fairseq import options
 
-import sentencepiece as spm
-import torchaudio
-from fairseq import options, tasks, utils
-from fairseq.meters import StopwatchMeter, TimeMeter
-from fairseq.utils import import_user_module
-from vad import get_microphone_chunks
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-def add_asr_eval_argument(parser):
-    parser.add_argument("--input_file", help="input file")
-    parser.add_argument("--ctc", action="store_true", help="decode a ctc model")
-    parser.add_argument("--rnnt", default=False, help="decode a rnnt model")
-    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
-    parser.add_argument(
-        "--wfstlm", default=None, help="wfstlm on dictonary output units"
-    )
-    parser.add_argument(
-        "--rnnt_decoding_type",
-        default="greedy",
-        help="wfstlm on dictonary output units",
-    )
-    parser.add_argument(
-        "--lm_weight",
-        default=0.2,
-        help="weight for wfstlm while interpolating with neural score",
-    )
-    parser.add_argument(
-        "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
-    )
-    return parser
-
-
-def check_args(args):
-    assert args.path is not None, "--path required for generation!"
-    assert (
-        not args.sampling or args.nbest == args.beam
-    ), "--sampling requires --nbest to be equal to --beam"
-    assert (
-        args.replace_unk is None or args.raw_text
-    ), "--replace-unk requires a raw text dataset (--raw-text)"
-
-
-def process_predictions(args, hypos, sp, tgt_dict):
-    res = []
-    for hypo in hypos[: min(len(hypos), args.nbest)]:
-        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
-        hyp_words = sp.DecodePieces(hyp_pieces.split())
-        res.append(hyp_words)
-    return res
-
-
-def optimize_models(args, use_cuda, models):
-    """Optimize ensemble for generation
-    """
-    for model in models:
-        model.make_generation_fast_(
-            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
-            need_attn=args.print_alignment,
-        )
-        if args.fp16:
-            model.half()
-        if use_cuda:
-            model.cuda()
-
-
-def calc_mean_invstddev(feature):
-    if len(feature.shape) != 2:
-        raise ValueError("We expect the input feature to be 2-D tensor")
-    mean = torch.mean(feature, dim=0)
-    var = torch.var(feature, dim=0)
-    # avoid division by ~zero
-    if (var < sys.float_info.epsilon).any():
-        return mean, 1.0 / (torch.sqrt(var) + sys.float_info.epsilon)
-    return mean, 1.0 / torch.sqrt(var)
-
-
-def calcMN(features):
-    mean, invstddev = calc_mean_invstddev(features)
-    res = (features - mean) * invstddev
-    return res
-
-
-def transcribe(waveform, args, task, generator, models, sp, tgt_dict):
-    num_features = 80
-    output = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=num_features)
-    output_cmvn = calcMN(output.cpu().detach())
-
-    # size (m, n)
-    source = torch.tensor(output_cmvn)
-    frames_lengths = torch.LongTensor([source.size(0)])
-
-    # size (1, m, n). In general, if source is (x, m, n), then hypos is (x, ...)
-    source.unsqueeze_(0)
-    sample = {"net_input": {"src_tokens": source, "src_lengths": frames_lengths}}
-
-    hypos = task.inference_step(generator, models, sample)
-
-    assert len(hypos) == 1
-    transcription = []
-    for i in range(len(hypos)):
-        # Process top predictions
-        hyp_words = process_predictions(args, hypos[i], sp, tgt_dict)
-        transcription.append(hyp_words)
-
-    return transcription
+from interactive_asr.utils import add_asr_eval_argument, setup_asr, get_microphone_transcription, transcribe_file
 
 
 def main(args):
-    check_args(args)
-    import_user_module(args)
-
-    if args.max_tokens is None and args.max_sentences is None:
-        args.max_tokens = 30000
-    logger.info(args)
-
-    use_cuda = torch.cuda.is_available() and not args.cpu
-
-    # Load dataset splits
-    task = tasks.setup_task(args)
-
-    # Set dictionary
-    tgt_dict = task.target_dictionary
-
-    if args.ctc or args.rnnt:
-        tgt_dict.add_symbol("<ctc_blank>")
-        if args.ctc:
-            logger.info("| decoding a ctc model")
-        if args.rnnt:
-            logger.info("| decoding a rnnt model")
-
-    # Load ensemble
-    logger.info("| loading model(s) from {}".format(args.path))
-    models, _model_args = utils.load_ensemble_for_inference(
-        args.path.split(":"),
-        task,
-        model_arg_overrides=eval(args.model_overrides),  # noqa
-    )
-    optimize_models(args, use_cuda, models)
-
-    # Initialize generator
-    generator = task.build_generator(args)
-
-    sp = spm.SentencePieceProcessor()
-    sp.Load(os.path.join(args.data, "spm.model"))
+    logger = logging.getLogger(__name__)
+    logger.setLevel(logging.INFO)
+    task, generator, models, sp, tgt_dict = setup_asr(args, logger)
 
+    print("READY!")
     if args.input_file:
-        path = args.input_file
-        if not os.path.exists(path):
-            raise FileNotFoundError("Audio file not found: {}".format(path))
-        waveform, sample_rate = torchaudio.load_wav(path)
-        waveform = waveform.mean(0, True)
-        waveform = torchaudio.transforms.Resample(
-            orig_freq=sample_rate, new_freq=16000
-        )(waveform)
-
-        print(sample_rate, waveform.shape)
-        start = time.time()
-        transcription = transcribe(
-            waveform, args, task, generator, models, sp, tgt_dict
-        )
-        end = time.time()
+        transcription_time, transcription = transcribe_file(args, task, generator, models, sp, tgt_dict)
         print("transcription:", transcription)
-        print(end - start)
+        print("transcription_time:", transcription_time)
     else:
-        print("READY!")
-        for (waveform, sample_rate) in get_microphone_chunks():
-            waveform = torchaudio.transforms.Resample(
-                orig_freq=sample_rate, new_freq=16000
-            )(waveform.reshape(1, -1))
-            transcription = transcribe(
-                waveform, args, task, generator, models, sp, tgt_dict
-            )
+        for transcription in get_microphone_transcription(args, task, generator, models, sp, tgt_dict):
             print(
                 "{}: {}".format(
                     dt.datetime.now().strftime("%H:%M:%S"), transcription[0][0]
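Note: the feature-normalization helpers deleted above move into interactive_asr/utils.py along with the rest of the pipeline (per the new import in asr.py). A small worked sketch of the CMVN step they implement, reusing the deleted calc_mean_invstddev verbatim; the random input matrix here is a stand-in for a real fbank output:

```python
# CMVN (mean and variance normalization) as in the deleted
# calc_mean_invstddev/calcMN helpers: normalize each of the 80 mel bins
# to roughly zero mean and unit variance across frames.
import sys

import torch


def calc_mean_invstddev(feature):
    if len(feature.shape) != 2:
        raise ValueError("We expect the input feature to be 2-D tensor")
    mean = torch.mean(feature, dim=0)
    var = torch.var(feature, dim=0)
    # avoid division by ~zero
    if (var < sys.float_info.epsilon).any():
        return mean, 1.0 / (torch.sqrt(var) + sys.float_info.epsilon)
    return mean, 1.0 / torch.sqrt(var)


# Stand-in for a (frames, mel_bins) fbank matrix from torchaudio.
features = 3.0 * torch.randn(100, 80) + 5.0
mean, invstddev = calc_mean_invstddev(features)
normalized = (features - mean) * invstddev  # this is what calcMN did

print(normalized.mean(dim=0).abs().max())  # ~0: per-bin mean removed
print(normalized.std(dim=0).mean())        # ~1: per-bin variance scaled
```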
