W2v u update (facebookresearch#1954)
Summary:
updating the scripts and examples to be easier to follow

Pull Request resolved: fairinternal/fairseq-py#1954

Reviewed By: wnhsu

Differential Revision: D29041166

Pulled By: alexeib

fbshipit-source-id: d9410c6e925337b810e92b393e226869ef9e1733
alexeib authored and facebook-github-bot committed Jun 11, 2021
1 parent 50158da commit f8a7c93
Showing 23 changed files with 372 additions and 226 deletions.
@@ -0,0 +1,8 @@
+# @package _group_
+
+data_dir: ???
+fst_dir: ???
+in_labels: ???
+kaldi_root: ???
+lm_arpa: ???
+blank_symbol: <s>
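An aside for readers new to Hydra/OmegaConf configs like the one above: `???` marks a mandatory value that must be supplied at runtime. A minimal sketch of that behavior (illustrative only, not part of this commit):

```python
from omegaconf import OmegaConf

# "???" is OmegaConf/Hydra's marker for a mandatory value.
cfg = OmegaConf.create({"kaldi_root": "???", "blank_symbol": "<s>"})

print(OmegaConf.is_missing(cfg, "kaldi_root"))  # True: must be overridden, e.g. kaldi_root=/opt/kaldi
print(cfg.blank_symbol)  # <s>
```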
51 changes: 17 additions & 34 deletions examples/speech_recognition/w2l_decoder.py
@@ -52,29 +52,19 @@ def __init__(self, args, tgt_dict):
        self.nbest = args.nbest

        # criterion-specific init
-        if args.criterion == "ctc":
-            self.criterion_type = CriterionType.CTC
-            self.blank = (
-                tgt_dict.index("<ctc_blank>")
-                if "<ctc_blank>" in tgt_dict.indices
-                else tgt_dict.bos()
-            )
-            if "<sep>" in tgt_dict.indices:
-                self.silence = tgt_dict.index("<sep>")
-            elif "|" in tgt_dict.indices:
-                self.silence = tgt_dict.index("|")
-            else:
-                self.silence = tgt_dict.eos()
-            self.asg_transitions = None
-        elif args.criterion == "asg_loss":
-            self.criterion_type = CriterionType.ASG
-            self.blank = -1
-            self.silence = -1
-            self.asg_transitions = args.asg_transitions
-            self.max_replabel = args.max_replabel
-            assert len(self.asg_transitions) == self.vocab_size ** 2
-        else:
-            raise RuntimeError(f"unknown criterion: {args.criterion}")
+        self.criterion_type = CriterionType.CTC
+        self.blank = (
+            tgt_dict.index("<ctc_blank>")
+            if "<ctc_blank>" in tgt_dict.indices
+            else tgt_dict.bos()
+        )
+        if "<sep>" in tgt_dict.indices:
+            self.silence = tgt_dict.index("<sep>")
+        elif "|" in tgt_dict.indices:
+            self.silence = tgt_dict.index("|")
+        else:
+            self.silence = tgt_dict.eos()
+        self.asg_transitions = None

    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
@@ -90,23 +80,16 @@ def get_emissions(self, models, encoder_input):
"""Run encoder and normalize emissions"""
model = models[0]
encoder_out = model(**encoder_input)
if self.criterion_type == CriterionType.CTC:
if hasattr(model, "get_logits"):
emissions = model.get_logits(encoder_out) # no need to normalize emissions
else:
emissions = model.get_normalized_probs(encoder_out, log_probs=True)
elif self.criterion_type == CriterionType.ASG:
emissions = encoder_out["encoder_out"]
if hasattr(model, "get_logits"):
emissions = model.get_logits(encoder_out) # no need to normalize emissions
else:
emissions = model.get_normalized_probs(encoder_out, log_probs=True)
return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs):
        """Normalize tokens by handling CTC blank, ASG replabels, etc."""
        idxs = (g[0] for g in it.groupby(idxs))
-        if self.criterion_type == CriterionType.CTC:
-            idxs = filter(lambda x: x != self.blank, idxs)
-        elif self.criterion_type == CriterionType.ASG:
-            idxs = filter(lambda x: x >= 0, idxs)
-            idxs = unpack_replabels(list(idxs), self.tgt_dict, self.max_replabel)
+        idxs = filter(lambda x: x != self.blank, idxs)
        return torch.LongTensor(list(idxs))
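An aside on `get_tokens` above: a self-contained sketch of the CTC post-processing it now always applies, collapse repeats then drop blanks (toy indices; blank index 0 is an assumption for illustration):

```python
import itertools as it

import torch

BLANK = 0  # assumed blank index for this toy example
idxs = [0, 3, 3, 0, 0, 5, 5, 5, 0, 4]

collapsed = (g[0] for g in it.groupby(idxs))  # merge consecutive repeats
tokens = torch.LongTensor([x for x in collapsed if x != BLANK])  # drop blanks
print(tokens)  # tensor([3, 5, 4])
```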


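And one more aside on `get_emissions`: the two emission paths differ only by a per-frame `log_softmax`, which does not change greedy/Viterbi decisions made without an LM. A minimal sketch (illustration only):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(10, 2, 5)  # toy (time, batch, vocab) emissions
log_probs = F.log_softmax(logits, dim=-1)  # what get_normalized_probs(..., log_probs=True) yields

# log_softmax subtracts a per-frame constant, so the per-frame argmax -- and hence
# the greedy/Viterbi best path without an LM -- is unchanged.
assert torch.equal(logits.argmax(-1), log_probs.argmax(-1))
```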
60 changes: 39 additions & 21 deletions examples/wav2vec/unsupervised/README.md
@@ -12,31 +12,41 @@ Similar to [wav2vec 2.0](https://github.com/pytorch/fairseq/blob/master/examples

In **/path/to/data/with_silence** you need a *train.tsv* file as well as (optionally) *{valid,test}.{tsv,wrd,phn}*. It is nice to have *10h.{tsv,phn}* files there too for reproducing the ablation study on layer selection. In **/path/to/data/without_silence** you have the same files, except the *.tsv* files point to audio with silence removed using rVAD.

-Here is how you can create new audio files without silences from a list of input audio files:
+Pre-requisites:
+* set the FAIRSEQ_ROOT environment variable to your fairseq installation
+* set the RVAD_ROOT environment variable to a checkout of [rVADfast](https://github.com/zhenghuatan/rVADfast)
+* set the KENLM_ROOT environment variable to the location of [KenLM](https://github.com/kpu/kenlm) binaries
+* install [PyKaldi](https://github.com/pykaldi/pykaldi) and set the KALDI_ROOT environment variable to the location of your kaldi installation. To use the version bundled with PyKaldi, you can use /path/to/pykaldi/tools/kaldi
+
+Create new audio files without silences:
```shell
-python scripts/vads.py < /path/to/train.tsv > train.vads
+# create a manifest file for the original set of audio files
+python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py /dir/to/save/audio/files --ext wav --dest /path/to/new/train.tsv --valid-percent 0
+
+python scripts/vads.py -r $RVAD_ROOT < /path/to/train.tsv > train.vads
+
python scripts/remove_silence.py --tsv /path/to/train.tsv --vads train.vads --out /dir/to/save/audio/files
-python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py /dir/to/save/audio/files --ext wav --dest /path/to/new/train.tsv --valid-percent 0
+
+python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py /dir/to/save/audio/files --ext wav --dest /path/to/new/train.tsv --valid-percent 0.01
```

-You will need to add the path to the rVAD directory to vads.py.

Next, we need to preprocess the audio data to better match phonemized text data:

```shell
-zsh scripts/prepare_audio.sh /dir/with/{train,test,valid}.tsv /output/dir /path/to/wav2vec2/model.pt
+zsh scripts/prepare_audio.sh /dir/with/{train,test,valid}.tsv /output/dir /path/to/wav2vec2/model.pt 512 14
```
-Note that if you have splits different than train/valid/test, you will need to modify this script.
+Note that if you have splits other than train/valid/test, you will need to modify this script. The last two arguments are the PCA dimensionality and the 0-based index of the layer from which to extract representations.

Now we need to prepare text data:
```shell
-zsh scripts/prepare_text.sh language /path/to/text/file /output/dir
+zsh scripts/prepare_text.sh language /path/to/text/file /output/dir 1000 espeak /path/to/fasttext/lid/model
```

-Note that if you want to use a different phonemizer, such as G2P, you will need to modify this script.
+The fourth argument is the minimum number of observations of a phone for it to be kept. If your text corpus is small, you might want to reduce this number (a toy illustration of this thresholding follows below).
+
+The fifth argument is which phonemizer to use. Supported values are [espeak](http://espeak.sourceforge.net/), [espeak-ng](https://github.com/espeak-ng/espeak-ng), and [G2P](https://github.com/Kyubyong/g2p) (English only).
+
+Pre-trained fasttext LID models can be downloaded [here](https://fasttext.cc/docs/en/language-identification.html).
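Following up on the fourth argument above, a toy illustration of phone-frequency thresholding (the actual filtering presumably lives inside prepare_text.sh; this sketch is illustrative only):

```python
from collections import Counter

corpus = ["HH AH0 L OW1", "W ER1 L D", "HH AH0 L OW1"]  # toy phonemized lines
min_observations = 2  # plays the role of the fourth argument (1000 above)

counts = Counter(p for line in corpus for p in line.split())
kept = {p for p, c in counts.items() if c >= min_observations}
print(sorted(kept))  # ['AH0', 'HH', 'L', 'OW1']
```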

## Generative adversarial training (GAN)

@@ -46,26 +56,34 @@ Launching GAN training on top of preprocessed features, with default hyperparameters

```
PREFIX=w2v_unsup_gan_xp
-TASK_DATA=/path/to/features/unfiltered/precompute_unfiltered_pca512_cls128_mean_pooled
-TEXT_DATA=/path/to/data # path to fairseq-preprocessed GAN data
-KENLM_PATH=/path/to/data/kenlm.phn.o4.bin # KenLM 4-gram phoneme language model (LM data = GAN data here)
-PREFIX=$PREFIX fairseq-hydra-train \
-    -m --config-dir configs/gan \
-    --config-name w2vu \
-    task.data=${TASK_DATA} \
-    task.text_data=${TEXT_DATA} \
-    task.kenlm_path=${KENLM_PATH} \
-    'common.seed=range(0,5)' &
+TASK_DATA=/path/to/features/precompute_unfiltered_pca512_cls128_mean_pooled
+TEXT_DATA=/path/to/data/phones # path to fairseq-preprocessed GAN data (phones dir)
+KENLM_PATH=/path/to/data/phones/kenlm.phn.o4.bin # KenLM 4-gram phoneme language model (LM data = GAN data here)
+PYTHONPATH=$FAIRSEQ_ROOT PREFIX=$PREFIX fairseq-hydra-train \
+    -m --config-dir config/gan \
+    --config-name w2vu \
+    task.data=${TASK_DATA} \
+    task.text_data=${TEXT_DATA} \
+    task.kenlm_path=${KENLM_PATH} \
+    common.user_dir=${FAIRSEQ_ROOT}/examples/wav2vec/unsupervised \
+    model.code_penalty=2,4 model.gradient_penalty=1.5,2.0 \
+    model.smoothness_weight=0.5,0.75,1.0 'common.seed=range(0,5)'
```


Once we find the best checkpoint (chosen using an unsupervised metric that combines language model perplexity and vocabulary usage), we can use it to generate phone labels (or word labels with an appropriate kaldi WFST):

```shell
python w2vu_generate.py --config-dir config/generate --config-name viterbi \
-fairseq.task.data=/path/to/dir/with/tsvs fairseq.common_eval.path=/path/to/gan/checkpoint \
+fairseq.common.user_dir=${FAIRSEQ_ROOT}/examples/wav2vec/unsupervised \
+fairseq.task.data=/path/to/dir/with/features \
+fairseq.common_eval.path=/path/to/gan/checkpoint \
fairseq.dataset.gen_subset=valid results_path=/where/to/save/transcriptions
```

+Decoding without an LM works best on the same adjacent-timestep mean-pooled features that the GAN was trained on, while decoding with an LM works better on features from before the adjacent-timestep mean-pooling step (those without the "_pooled" suffix).

## Iterative self-training + Kaldi LM-decoding
After the GAN training provides a first unsupervised model, we can progressively refine the quality of transcriptions using several iterations of semi-supervised learning. We perform two iterations: first, pseudo-label the training data with the unsupervised GAN model and train an HMM on the pseudo-labels. Second, relabel the training data with the HMM and fine-tune the original wav2vec 2.0 model using the HMM pseudo-labels with a CTC loss. Note that HMM models use phonemes as output, while wav2vec 2.0 uses letters. Both are decoded into words using WFST decoders.

7 changes: 7 additions & 0 deletions examples/wav2vec/unsupervised/config/gan/w2vu.yaml
@@ -10,10 +10,15 @@ common:
  suppress_crashes: false

checkpoint:
+  save_interval: 1000
+  save_interval_updates: 1000
  no_epoch_checkpoints: true
  best_checkpoint_metric: weighted_lm_ppl
  save_dir: .
+
+distributed_training:
+  distributed_world_size: 1

task:
  _name: unpaired_audio_text
  data: ???
@@ -30,6 +35,8 @@ dataset:
  batch_size: 160
  skip_invalid_size_inputs_valid_test: true
  valid_subset: valid
+  validate_interval: 1000
+  validate_interval_updates: 1000

criterion:
  _name: model
1 change: 0 additions & 1 deletion examples/wav2vec/unsupervised/config/generate/viterbi.yaml
@@ -18,5 +18,4 @@ fairseq:
    batch_size: 1

w2l_decoder: VITERBI
-lm_model: ???
post_process: silence
65 changes: 22 additions & 43 deletions examples/wav2vec/unsupervised/models/wav2vec_u.py
@@ -76,7 +76,6 @@ class Wav2vec_UConfig(FairseqDataclass):
    hard_gumbel: bool = True
    temp: Tuple[float, float, float] = (2, 0.1, 0.99995)
    input_dim: int = 128
-    wgan_loss: bool = False

    segmentation: SegmentationConfig = SegmentationConfig()

@@ -397,12 +396,7 @@ def set_num_updates(self, num_updates):
        )

    def discrim_step(self, num_updates):
-        if num_updates < self.zero_pretrain_updates:
-            return False
-        if self.dynamic_step_thresh <= 0 or self.last_acc is None:
-            return num_updates % 2 == 1
-        else:
-            return self.last_acc < self.dynamic_step_thresh
+        return num_updates % 2 == 1

    def get_groups_for_update(self, num_updates):
        return "discriminator" if self.discrim_step(num_updates) else "generator"
@@ -413,7 +407,6 @@ def __init__(self, cfg: Wav2vec_UConfig, target_dict):
        self.cfg = cfg
        self.zero_index = target_dict.index("<SIL>") if "<SIL>" in target_dict else 0
        self.smoothness_weight = cfg.smoothness_weight
-        self.wgan_loss = cfg.wgan_loss

        output_size = len(target_dict)
        self.pad = target_dict.pad()
@@ -432,7 +425,7 @@ def __init__(self, cfg: Wav2vec_UConfig, target_dict):
        self.blank_index = target_dict.index("<SIL>") if cfg.blank_is_sil else 0
        assert self.blank_index != target_dict.unk()

-        self.discriminator = self.Discriminator(output_size, cfg)
+        self.discriminator = Discriminator(output_size, cfg)
        for p in self.discriminator.parameters():
            p.param_group = "discriminator"

@@ -441,9 +434,7 @@ def __init__(self, cfg: Wav2vec_UConfig, target_dict):

        self.segmenter = SEGMENT_FACTORY[cfg.segmentation.type](cfg.segmentation)

-        self.generator = self.Generator(
-            d, output_size, cfg, lambda x: self.normalize(x)[0]
-        )
+        self.generator = Generator(d, output_size, cfg)

        for p in self.generator.parameters():
            p.param_group = "generator"
@@ -589,20 +580,16 @@ def forward(
        code_pen = None

        if d_step:
-            if self.wgan_loss:
-                loss_dense = dense_y.sum()
-                loss_token = -1 * token_y.sum()
-            else:
-                loss_dense = F.binary_cross_entropy_with_logits(
-                    dense_y,
-                    dense_y.new_ones(dense_y.shape) - fake_smooth,
-                    reduction="sum",
-                )
-                loss_token = F.binary_cross_entropy_with_logits(
-                    token_y,
-                    token_y.new_zeros(token_y.shape) + real_smooth,
-                    reduction="sum",
-                )
+            loss_dense = F.binary_cross_entropy_with_logits(
+                dense_y,
+                dense_y.new_ones(dense_y.shape) - fake_smooth,
+                reduction="sum",
+            )
+            loss_token = F.binary_cross_entropy_with_logits(
+                token_y,
+                token_y.new_zeros(token_y.shape) + real_smooth,
+                reduction="sum",
+            )
            if self.training and self.gradient_penalty > 0:
                grad_pen = self.calc_gradient_penalty(token_x, dense_x)
                grad_pen = grad_pen.sum() * self.gradient_penalty
@@ -611,23 +598,15 @@ def forward(
        else:
            grad_pen = None
            loss_token = None
-            if self.update_num >= self.zero_pretrain_updates:
-                if self.wgan_loss:
-                    loss_dense = -1 * dense_y.sum()
-                else:
-                    loss_dense = F.binary_cross_entropy_with_logits(
-                        dense_y,
-                        dense_y.new_zeros(dense_y.shape) + fake_smooth,
-                        reduction="sum",
-                    )
-                num_vars = dense_x.size(-1)
-                if prob_perplexity is not None:
-                    code_pen = (num_vars - prob_perplexity) / num_vars
-                    if self.exponential_code_pen:
-                        code_pen = (1 - 1 / code_pen ** 2).exp()
-                    code_pen = code_pen * sample_size * self.code_penalty
-            else:
-                loss_dense = None
+            loss_dense = F.binary_cross_entropy_with_logits(
+                dense_y,
+                dense_y.new_zeros(dense_y.shape) + fake_smooth,
+                reduction="sum",
+            )
+            num_vars = dense_x.size(-1)
+            if prob_perplexity is not None:
+                code_pen = (num_vars - prob_perplexity) / num_vars
+                code_pen = code_pen * sample_size * self.code_penalty

        if self.smoothness_weight > 0:
            smoothness_loss = F.mse_loss(
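An aside on the loss code above: with the WGAN branch removed, what remains is a label-smoothed non-saturating GAN objective. A minimal runnable sketch of the three BCE terms (toy logits; the smoothing constants are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dense_y = torch.randn(4)  # discriminator logits on generated (dense) input
token_y = torch.randn(4)  # discriminator logits on real text (token) input
fake_smooth, real_smooth = 0.1, 0.1  # illustrative label-smoothing constants

# Discriminator step: targets ~1 for generated input, ~0 for real input
# (note the convention in the code above: high scores mean "fake").
d_dense = F.binary_cross_entropy_with_logits(
    dense_y, dense_y.new_ones(dense_y.shape) - fake_smooth, reduction="sum"
)
d_token = F.binary_cross_entropy_with_logits(
    token_y, token_y.new_zeros(token_y.shape) + real_smooth, reduction="sum"
)

# Generator step: push the discriminator's score on generated input toward "real" (~0).
g_dense = F.binary_cross_entropy_with_logits(
    dense_y, dense_y.new_zeros(dense_y.shape) + fake_smooth, reduction="sum"
)
print(d_dense.item(), d_token.item(), g_dense.item())
```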
8 changes: 6 additions & 2 deletions examples/wav2vec/unsupervised/scripts/apply_pca.py
@@ -50,8 +50,12 @@ def main():

    copyfile(source_path + ".tsv", save_path + ".tsv")
    copyfile(data_poth + ".lengths", save_path + ".lengths")
-    copyfile(source_path + ".phn", save_path + ".phn")
-    copyfile(source_path + ".wrd", save_path + ".wrd")
+
+    if osp.exists(source_path + ".phn"):
+        copyfile(source_path + ".phn", save_path + ".phn")
+
+    if osp.exists(source_path + ".wrd"):
+        copyfile(source_path + ".wrd", save_path + ".wrd")

    if osp.exists(save_path + ".npy"):
        os.remove(save_path + ".npy")
20 changes: 12 additions & 8 deletions examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py
@@ -12,28 +12,32 @@

def main():
    parser = argparse.ArgumentParser()
-    parser.add_argument("root_dirs", nargs="*")
-    parser.add_argument("--insert-silence", "-s", action="store_true")
+    parser.add_argument(
+        "--compact",
+        action="store_true",
+        help="if set, compacts phones",
+    )
    args = parser.parse_args()
-    sil = "<s>"
+
+    compact = args.compact

    wrd_to_phn = {}
    g2p = G2p()
    for line in sys.stdin:
        words = line.strip().split()
        phones = []
-        if args.insert_silence:
-            phones.append(sil)
        for w in words:
            if w not in wrd_to_phn:
                wrd_to_phn[w] = g2p(w)
+                if compact:
+                    wrd_to_phn[w] = [
+                        p[:-1] if p[-1].isnumeric() else p for p in wrd_to_phn[w]
+                    ]
            phones.extend(wrd_to_phn[w])
-        if args.insert_silence:
-            phones.append(sil)
        try:
            print(" ".join(phones))
        except:
-            print(wrd_to_phn, w, phones, file=sys.stderr)
+            print(wrd_to_phn, words, phones, file=sys.stderr)
            raise
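An aside on the new `--compact` flag above: it strips trailing ARPAbet stress digits, merging stressed and unstressed variants of a phone. A toy sketch (the example word is illustrative):

```python
phones = ["P", "ER0", "M", "IH1", "T"]  # e.g. G2p()("permit"): ARPAbet with stress digits

compacted = [p[:-1] if p[-1].isnumeric() else p for p in phones]
print(compacted)  # ['P', 'ER', 'M', 'IH', 'T']
```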


13 changes: 11 additions & 2 deletions examples/wav2vec/unsupervised/scripts/mean_pool.py
@@ -47,8 +47,17 @@ def main():
    save_path = osp.join(args.save_dir, args.split)

    copyfile(source_path + ".tsv", save_path + ".tsv")
-    copyfile(source_path + ".phn", save_path + ".phn")
-    copyfile(source_path + ".wrd", save_path + ".wrd")
+
+    if os.path.exists(source_path + ".phn"):
+        copyfile(source_path + ".phn", save_path + ".phn")
+    if os.path.exists(source_path + ".wrd"):
+        copyfile(source_path + ".wrd", save_path + ".wrd")
+
+    if os.path.exists(osp.join(args.source, "dict.phn.txt")):
+        copyfile(
+            osp.join(args.source, "dict.phn.txt"),
+            osp.join(args.save_dir, "dict.phn.txt"),
+        )

    if osp.exists(save_path + ".npy"):
        os.remove(save_path + ".npy")
