forked from facebookresearch/fairseq
Commit
Summary: Wav2vec-U implementation

Pull Request resolved: fairinternal/fairseq-py#1889
Reviewed By: michaelauli
Differential Revision: D28596815
Pulled By: alexeib
fbshipit-source-id: bb09d081d167d5d10968acc6e056044bf96679ac
1 parent f9edd9f · commit 649af63
Showing 82 changed files with 6,625 additions and 255 deletions.
Empty file.
@@ -0,0 +1,94 @@
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <iostream>
#include <set>
#include <vector>

#include "fstext/fstext-lib.h" // @manual
#include "util/common-utils.h" // @manual

/*
 * This program modifies an FST without self-loops: for each incoming arc
 * with a non-epsilon input symbol, it adds a self-loop arc with that
 * non-epsilon symbol as input and epsilon as output.
 *
 * This ensures that the resulting FST can deduplicate repeated symbols,
 * which are very common in acoustic-model output.
 */
namespace {
int32 AddSelfLoopsSimple(fst::StdVectorFst* fst) {
  typedef fst::MutableArcIterator<fst::StdVectorFst> IterType;

  int32 num_states_before = fst->NumStates();
  fst::MakePrecedingInputSymbolsSame(false, fst);
  int32 num_states_after = fst->NumStates();
  KALDI_LOG << "There are " << num_states_before
            << " states in the original FST; "
            << " after MakePrecedingInputSymbolsSame, there are "
            << num_states_after << " states " << std::endl;

  auto weight_one = fst::StdArc::Weight::One();

  int32 num_arc_added = 0;

  fst::StdArc self_loop_arc;
  self_loop_arc.weight = weight_one;

  int32 num_states = fst->NumStates();
  std::vector<std::set<int32>> incoming_non_eps_label_per_state(num_states);

  // First pass: for each state, collect the set of non-epsilon input
  // labels on its incoming arcs.
  for (int32 state = 0; state < num_states; state++) {
    for (IterType aiter(fst, state); !aiter.Done(); aiter.Next()) {
      fst::StdArc arc(aiter.Value());
      if (arc.ilabel != 0) {
        incoming_non_eps_label_per_state[arc.nextstate].insert(arc.ilabel);
      }
    }
  }

  // Second pass: for each collected label, add a self-loop that consumes
  // the label and outputs epsilon.
  for (int32 state = 0; state < num_states; state++) {
    if (!incoming_non_eps_label_per_state[state].empty()) {
      auto& ilabel_set = incoming_non_eps_label_per_state[state];
      for (auto it = ilabel_set.begin(); it != ilabel_set.end(); it++) {
        self_loop_arc.ilabel = *it;
        self_loop_arc.olabel = 0;
        self_loop_arc.nextstate = state;
        fst->AddArc(state, self_loop_arc);
        num_arc_added++;
      }
    }
  }
  return num_arc_added;
}

void print_usage() {
  std::cout << "add-self-loop-simple usage:\n"
               "\tadd-self-loop-simple <in-fst> <out-fst> \n";
}
}  // namespace

int main(int argc, char** argv) {
  if (argc != 3) {
    print_usage();
    exit(1);
  }

  auto input = argv[1];
  auto output = argv[2];

  auto fst = fst::ReadFstKaldi(input);
  auto num_states = fst->NumStates();
  KALDI_LOG << "Loaded FST from " << input << " with " << num_states
            << " states." << std::endl;

  int32 num_arc_added = AddSelfLoopsSimple(fst);
  KALDI_LOG << "Added " << num_arc_added << " self-loop arcs " << std::endl;

  fst::WriteFstKaldi(*fst, std::string(output));
  KALDI_LOG << "Wrote FST to " << output << std::endl;

  delete fst;
}
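The two-pass structure above is easy to check against a toy case. Below is a minimal Python sketch of the same idea on a hypothetical arc-list FST representation; the Arc tuple and add_self_loops helper are illustrative only and not part of this commit or of Kaldi, which operates on a StdVectorFst.

from collections import namedtuple

# Toy stand-in for an FST: one list of outgoing arcs per state.
Arc = namedtuple("Arc", ["ilabel", "olabel", "weight", "nextstate"])
EPS = 0  # epsilon label, as in OpenFst

def add_self_loops(arcs_per_state):
    """Add, at each state, one self-loop per distinct non-epsilon input
    label arriving on an incoming arc. The loop consumes the label and
    emits epsilon, so runs of repeated symbols collapse to one output."""
    num_states = len(arcs_per_state)
    incoming = [set() for _ in range(num_states)]
    # First pass: collect non-epsilon input labels entering each state.
    for arcs in arcs_per_state:
        for arc in arcs:
            if arc.ilabel != EPS:
                incoming[arc.nextstate].add(arc.ilabel)
    # Second pass: add the epsilon-output self-loops.
    num_added = 0
    for state, labels in enumerate(incoming):
        for ilabel in sorted(labels):
            arcs_per_state[state].append(Arc(ilabel, EPS, 0.0, state))
            num_added += 1
    return num_added

# Two states, one arc 0 -> 1 consuming label 5: state 1 gets a (5:eps) loop.
fst = [[Arc(5, 5, 0.0, 1)], []]
print(add_self_loops(fst))  # prints 1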
@@ -0,0 +1,244 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import warnings
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional

import torch
from omegaconf import MISSING

from fairseq.dataclass import FairseqDataclass

from .kaldi_initializer import KaldiInitializerConfig, initalize_kaldi


logger = logging.getLogger(__name__)


@dataclass
class KaldiDecoderConfig(FairseqDataclass):
    hlg_graph_path: Optional[str] = None
    output_dict: str = MISSING

    kaldi_initializer_config: Optional[KaldiInitializerConfig] = None

    acoustic_scale: float = 0.5
    max_active: int = 10000
    beam_delta: float = 0.5
    hash_ratio: float = 2.0

    is_lattice: bool = False
    lattice_beam: float = 10.0
    prune_interval: int = 25
    determinize_lattice: bool = True
    prune_scale: float = 0.1
    max_mem: int = 0
    phone_determinize: bool = True
    word_determinize: bool = True
    minimize: bool = True

    num_threads: int = 1


class KaldiDecoder(object):
    def __init__(
        self,
        cfg: KaldiDecoderConfig,
        beam: int,
        nbest: int = 1,
    ):
        try:
            from kaldi.asr import FasterRecognizer, LatticeFasterRecognizer
            from kaldi.base import set_verbose_level
            from kaldi.decoder import (
                FasterDecoder,
                FasterDecoderOptions,
                LatticeFasterDecoder,
                LatticeFasterDecoderOptions,
            )
            from kaldi.lat.functions import DeterminizeLatticePhonePrunedOptions
            from kaldi.fstext import read_fst_kaldi, SymbolTable
        except ImportError:
            warnings.warn(
                "pykaldi is required for this functionality. "
                "Please install from https://github.com/pykaldi/pykaldi"
            )
            raise

        # set_verbose_level(2)

        self.acoustic_scale = cfg.acoustic_scale
        self.nbest = nbest

        if cfg.hlg_graph_path is None:
            assert (
                cfg.kaldi_initializer_config is not None
            ), "Must provide hlg graph path or kaldi initializer config"
            cfg.hlg_graph_path = initalize_kaldi(cfg.kaldi_initializer_config)

        assert os.path.exists(cfg.hlg_graph_path), cfg.hlg_graph_path

        if cfg.is_lattice:
            self.dec_cls = LatticeFasterDecoder
            opt_cls = LatticeFasterDecoderOptions
            self.rec_cls = LatticeFasterRecognizer
        else:
            assert self.nbest == 1, "nbest > 1 requires lattice decoder"
            self.dec_cls = FasterDecoder
            opt_cls = FasterDecoderOptions
            self.rec_cls = FasterRecognizer

        self.decoder_options = opt_cls()
        self.decoder_options.beam = beam
        self.decoder_options.max_active = cfg.max_active
        self.decoder_options.beam_delta = cfg.beam_delta
        self.decoder_options.hash_ratio = cfg.hash_ratio

        if cfg.is_lattice:
            self.decoder_options.lattice_beam = cfg.lattice_beam
            self.decoder_options.prune_interval = cfg.prune_interval
            self.decoder_options.determinize_lattice = cfg.determinize_lattice
            self.decoder_options.prune_scale = cfg.prune_scale
            det_opts = DeterminizeLatticePhonePrunedOptions()
            det_opts.max_mem = cfg.max_mem
            det_opts.phone_determinize = cfg.phone_determinize
            det_opts.word_determinize = cfg.word_determinize
            det_opts.minimize = cfg.minimize
            self.decoder_options.det_opts = det_opts

        # Map output ids back to symbols; one "<symbol> <id>" pair per line.
        self.output_symbols = {}
        with open(cfg.output_dict, "r") as f:
            for line in f:
                items = line.rstrip().split()
                assert len(items) == 2
                self.output_symbols[int(items[1])] = items[0]

        logger.info(f"Loading FST from {cfg.hlg_graph_path}")
        self.fst = read_fst_kaldi(cfg.hlg_graph_path)
        self.symbol_table = SymbolTable.read_text(cfg.output_dict)

        self.executor = ThreadPoolExecutor(max_workers=cfg.num_threads)

    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions, padding = self.get_emissions(models, encoder_input)
        return self.decode(emissions, padding)

    def get_emissions(self, models, encoder_input):
        """Run the encoder(s) and normalize the emissions."""
        model = models[0]

        all_encoder_out = [m(**encoder_input) for m in models]

        if len(all_encoder_out) > 1:
            # Ensemble: average the raw outputs of all models.
            if "encoder_out" in all_encoder_out[0]:
                encoder_out = {
                    "encoder_out": sum(e["encoder_out"] for e in all_encoder_out)
                    / len(all_encoder_out),
                    "encoder_padding_mask": all_encoder_out[0]["encoder_padding_mask"],
                }
                padding = encoder_out["encoder_padding_mask"]
            else:
                encoder_out = {
                    "logits": sum(e["logits"] for e in all_encoder_out)
                    / len(all_encoder_out),
                    "padding_mask": all_encoder_out[0]["padding_mask"],
                }
                padding = encoder_out["padding_mask"]
        else:
            encoder_out = all_encoder_out[0]
            padding = (
                encoder_out["padding_mask"]
                if "padding_mask" in encoder_out
                else encoder_out["encoder_padding_mask"]
            )

        if hasattr(model, "get_logits"):
            emissions = model.get_logits(encoder_out, normalize=True)
        else:
            emissions = model.get_normalized_probs(encoder_out, log_probs=True)

        return (
            emissions.cpu().float().transpose(0, 1),
            padding.cpu() if padding is not None and padding.any() else None,
        )

    def decode_one(self, logits, padding):
        from kaldi.matrix import Matrix

        decoder = self.dec_cls(self.fst, self.decoder_options)
        asr = self.rec_cls(
            decoder, self.symbol_table, acoustic_scale=self.acoustic_scale
        )

        # Drop padded frames before handing the utterance to Kaldi.
        if padding is not None:
            logits = logits[~padding]

        mat = Matrix(logits.numpy())

        out = asr.decode(mat)

        if self.nbest > 1:
            from kaldi.fstext import shortestpath
            from kaldi.fstext.utils import (
                convert_compact_lattice_to_lattice,
                convert_lattice_to_std,
                convert_nbest_to_list,
                get_linear_symbol_sequence,
            )

            lat = out["lattice"]

            sp = shortestpath(lat, nshortest=self.nbest)

            sp = convert_compact_lattice_to_lattice(sp)
            sp = convert_lattice_to_std(sp)
            seq = convert_nbest_to_list(sp)

            results = []
            for s in seq:
                _, o, w = get_linear_symbol_sequence(s)
                words = [self.output_symbols[z] for z in o]
                results.append(
                    {
                        "tokens": words,
                        "words": words,
                        "score": w.value,
                        "emissions": logits,
                    }
                )
            return results
        else:
            words = out["text"].split()
            return [
                {
                    "tokens": words,
                    "words": words,
                    "score": out["likelihood"],
                    "emissions": logits,
                }
            ]

    def decode(self, emissions, padding):
        # Each utterance is decoded on the thread pool; the caller receives
        # a list of futures, one per utterance in the batch.
        if padding is None:
            padding = [None] * len(emissions)

        return [
            self.executor.submit(self.decode_one, e, p)
            for e, p in zip(emissions, padding)
        ]
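Because decode submits each utterance to the thread pool, callers get back concurrent.futures.Future objects rather than finished hypotheses and must collect results explicitly. A minimal usage sketch, assuming pykaldi is installed and that graph/HLG.fst and graph/words.txt (a "<symbol> <id>" file) exist; the random emissions tensor is only a stand-in for the output of get_emissions:

import torch

cfg = KaldiDecoderConfig(
    hlg_graph_path="graph/HLG.fst",  # assumed decoding graph
    output_dict="graph/words.txt",   # assumed "<symbol> <id>" mapping
)
decoder = KaldiDecoder(cfg, beam=15)

# Stand-in for get_emissions(): a batch of 1 utterance, 50 frames,
# one log-prob per output symbol.
emissions = torch.randn(1, 50, len(decoder.output_symbols))

futures = decoder.decode(emissions, padding=None)
for hypos in (f.result() for f in futures):
    best = hypos[0]  # nbest=1, so one hypothesis per utterance
    print(best["words"], best["score"])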