wav2vec2 speech translation OSS
Summary:
wav2vec2 speech translation OSS
- Based on fairinternal/fairseq-py#1829
- Updated the `Wav2VecEncoder` API so that it is consistent between `Wav2VecCTC` (ASR) and `Wav2Vec2Seq2Seq` (ST); a sketch of the shared encoder-output contract follows the changed-files count below
- Small fixes in `Wav2Vec2Seq2Seq`
- Split the `audio_pretraining` task into `audio_pretraining` and `audio_finetuning`

Reviewed By: sravyapopuri388, cndn

Differential Revision: D29285182

fbshipit-source-id: 89f93b42caa88079940a4b2cac0f8952547d3ff0
kahne authored and facebook-github-bot committed Jul 27, 2021
1 parent 7ca95a6 commit 75051ec
Showing 17 changed files with 419 additions and 262 deletions.
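
A note on the encoder API change referenced in the summary: after this commit, `Wav2VecEncoder.forward()` always returns `encoder_out` time-major (`T x B x C`) and `padding_mask` batch-major (`B x T`), and both the CTC head (ASR) and the seq2seq decoder (ST) consume the same dictionary. A minimal toy sketch of that contract, in plain PyTorch rather than fairseq code (all names here are illustrative):

```python
import torch

def toy_encoder_forward(features, padding_mask):
    # features: B x T x C frames; mirrors the now-unconditional transpose
    x = features.transpose(0, 1)  # B x T x C -> T x B x C
    return {
        "encoder_out": x,              # T x B x C, consumed by CTC and seq2seq alike
        "padding_mask": padding_mask,  # B x T, no longer transposed to T x B
        "layer_results": [],
    }

def toy_reorder_encoder_out(encoder_out, new_order):
    # Time-major tensor is reordered on dim 1 (batch); batch-major mask on dim 0.
    encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(1, new_order)
    if encoder_out["padding_mask"] is not None:
        encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select(0, new_order)
    return encoder_out

B, T, C = 2, 50, 768
out = toy_encoder_forward(torch.randn(B, T, C), torch.zeros(B, T, dtype=torch.bool))
out = toy_reorder_encoder_out(out, torch.tensor([1, 0]))
assert out["encoder_out"].shape == (T, B, C)
assert out["padding_mask"].shape == (B, T)
```
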
2 changes: 1 addition & 1 deletion examples/wav2vec/README.md
@@ -143,7 +143,7 @@ Next, run the evaluation command:
 
 ```shell script
 subset=dev_other
-python examples/speech_recognition/infer.py /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw --task audio_pretraining \
+python examples/speech_recognition/infer.py /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw --task audio_finetuning \
 --nbest 1 --path /path/to/model --gen-subset $subset --results-path /path/to/save/results/for/sclite --w2l-decoder kenlm \
 --lm-model /path/to/kenlm.bin --lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 \
 --post-process letter

3 changes: 1 addition & 2 deletions examples/wav2vec/config/finetuning/base_100h.yaml
@@ -10,7 +10,7 @@ checkpoint:
   best_checkpoint_metric: wer
 
 task:
-  _name: audio_pretraining
+  _name: audio_finetuning
   data: ???
   normalize: false
   labels: ltr
@@ -56,4 +56,3 @@ model:
   activation_dropout: 0.1
   feature_grad_mult: 0.0
   freeze_finetune_updates: 0
-

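The nine config diffs that follow apply the same one-line rename. For orientation, here is a rough sketch of the registry mechanism that the `_name` key drives; the decorator and classes are simplified stand-ins, and `AudioFinetuningTask` inheriting from the pretraining task is an assumption based on the summary rather than something shown in this diff:

```python
# Simplified stand-ins for fairseq's task registry; not the real implementation.
TASK_REGISTRY = {}

def register_task(name):
    def wrapper(cls):
        TASK_REGISTRY[name] = cls
        return cls
    return wrapper

@register_task("audio_pretraining")
class AudioPretrainingTask:
    """Pretraining keeps only the unlabeled-audio handling after the split."""
    def __init__(self, cfg):
        self.cfg = cfg

@register_task("audio_finetuning")
class AudioFinetuningTask(AudioPretrainingTask):
    """Assumed home of the label-aware (CTC / seq2seq) fine-tuning logic."""
    pass

# What `_name: audio_finetuning` in the YAML above ultimately selects:
task_cls = TASK_REGISTRY["audio_finetuning"]
assert task_cls is AudioFinetuningTask
```
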
3 changes: 1 addition & 2 deletions examples/wav2vec/config/finetuning/base_10h.yaml
@@ -13,7 +13,7 @@ checkpoint:
   best_checkpoint_metric: wer
 
 task:
-  _name: audio_pretraining
+  _name: audio_finetuning
   data: ???
   normalize: false
   labels: ltr
@@ -61,4 +61,3 @@ model:
   activation_dropout: 0.1
   feature_grad_mult: 0.0
   freeze_finetune_updates: 10000
-

3 changes: 1 addition & 2 deletions examples/wav2vec/config/finetuning/base_10m.yaml
@@ -13,7 +13,7 @@ checkpoint:
   best_checkpoint_metric: wer
 
 task:
-  _name: audio_pretraining
+  _name: audio_finetuning
   data: ???
   normalize: false
   labels: ltr
@@ -61,4 +61,3 @@ model:
   activation_dropout: 0.1
   feature_grad_mult: 0.0
   freeze_finetune_updates: 10000
-

3 changes: 1 addition & 2 deletions examples/wav2vec/config/finetuning/base_1h.yaml
@@ -13,7 +13,7 @@ checkpoint:
   best_checkpoint_metric: wer
 
 task:
-  _name: audio_pretraining
+  _name: audio_finetuning
   data: ???
   normalize: false
   labels: ltr
@@ -61,4 +61,3 @@ model:
   activation_dropout: 0.1
   feature_grad_mult: 0.0
   freeze_finetune_updates: 10000
-

3 changes: 1 addition & 2 deletions examples/wav2vec/config/finetuning/base_960h.yaml
@@ -10,7 +10,7 @@ checkpoint:
   best_checkpoint_metric: wer
 
 task:
-  _name: audio_pretraining
+  _name: audio_finetuning
   data: ???
   normalize: false
   labels: ltr
@@ -55,4 +55,3 @@ model:
   activation_dropout: 0.1
   feature_grad_mult: 0.0
   freeze_finetune_updates: 0
-

3 changes: 1 addition & 2 deletions examples/wav2vec/config/finetuning/vox_100h.yaml
@@ -10,7 +10,7 @@ checkpoint:
   best_checkpoint_metric: wer
 
 task:
-  _name: audio_pretraining
+  _name: audio_finetuning
   data: ???
   normalize: true
   labels: ltr
@@ -56,4 +56,3 @@ model:
   activation_dropout: 0.1
   feature_grad_mult: 0.0
   freeze_finetune_updates: 10000
-

3 changes: 1 addition & 2 deletions examples/wav2vec/config/finetuning/vox_10h.yaml
@@ -13,7 +13,7 @@ checkpoint:
   best_checkpoint_metric: wer
 
 task:
-  _name: audio_pretraining
+  _name: audio_finetuning
   data: ???
   normalize: true
   labels: ltr
@@ -61,4 +61,3 @@ model:
   activation_dropout: 0.1
   feature_grad_mult: 0.0
   freeze_finetune_updates: 10000
-

3 changes: 1 addition & 2 deletions examples/wav2vec/config/finetuning/vox_10m.yaml
@@ -13,7 +13,7 @@ checkpoint:
   best_checkpoint_metric: wer
 
 task:
-  _name: audio_pretraining
+  _name: audio_finetuning
   data: ???
   normalize: true
   labels: ltr
@@ -61,4 +61,3 @@ model:
   activation_dropout: 0.1
   feature_grad_mult: 0.0
   freeze_finetune_updates: 10000
-

3 changes: 1 addition & 2 deletions examples/wav2vec/config/finetuning/vox_1h.yaml
@@ -13,7 +13,7 @@ checkpoint:
   best_checkpoint_metric: wer
 
 task:
-  _name: audio_pretraining
+  _name: audio_finetuning
   data: ???
   normalize: true
   labels: ltr
@@ -61,4 +61,3 @@ model:
   activation_dropout: 0.1
   feature_grad_mult: 0.0
   freeze_finetune_updates: 10000
-

3 changes: 1 addition & 2 deletions examples/wav2vec/config/finetuning/vox_960h.yaml
@@ -10,7 +10,7 @@ checkpoint:
   best_checkpoint_metric: wer
 
 task:
-  _name: audio_pretraining
+  _name: audio_finetuning
   data: ???
   normalize: true
   labels: ltr
@@ -55,4 +55,3 @@ model:
   activation_dropout: 0.1
   feature_grad_mult: 0.0
   freeze_finetune_updates: 10000
-

(file name not shown in this capture)
@@ -11,7 +11,7 @@ checkpoint:
   save_interval_updates: 20000
 
 task:
-  _name: audio_pretraining
+  _name: audio_finetuning
   data: ???
   normalize: true
   labels: ltr

6 changes: 6 additions & 0 deletions fairseq/data/add_target_dataset.py
@@ -71,3 +71,9 @@ def collater(self, samples):
         ).long()
         collated["ntokens"] += target.size(0)
         return collated
+
+    def filter_indices_by_size(self, indices, max_sizes):
+        indices, ignored = data_utils._filter_by_size_dynamic(
+            indices, self.size, max_sizes
+        )
+        return indices, ignored
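
`AddTargetDataset.size()` reports a (source, target) size pair, so delegating `filter_indices_by_size` to `data_utils._filter_by_size_dynamic` lets length filtering respect target lengths as well, which seq2seq ST batching needs. A simplified sketch of what dynamic filtering does (an approximation of the helper's intent, not its exact code):

```python
def filter_by_size_dynamic(indices, size_fn, max_sizes):
    # Keep an index only if every component of its size tuple fits the limit.
    def fits(idx):
        sizes, limits = size_fn(idx), max_sizes
        if not isinstance(sizes, tuple):
            sizes, limits = (sizes,), (limits,)
        return all(s <= m for s, m in zip(sizes, limits))
    kept = [i for i in indices if fits(i)]
    ignored = [i for i in indices if not fits(i)]
    return kept, ignored

# (source_frames, target_length) per example; limits (200, 20):
sizes = {0: (100, 5), 1: (300, 8), 2: (120, 50)}
kept, ignored = filter_by_size_dynamic([0, 1, 2], lambda i: sizes[i], (200, 20))
assert kept == [0] and ignored == [1, 2]  # 1 is too long on source, 2 on target
```
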
77 changes: 42 additions & 35 deletions fairseq/models/wav2vec/wav2vec2_asr.py
@@ -65,6 +65,19 @@ class Wav2Vec2AsrConfig(FairseqDataclass):
             "help": "dropout probability after activation in FFN inside wav2vec 2.0 model"
         },
     )
+    conv_feature_layers: Optional[str] = field(
+        default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
+        metadata={
+            "help": (
+                "string describing convolutional feature extraction "
+                "layers in form of a python list that contains "
+                "[(dim, kernel_size, stride), ...]"
+            ),
+        },
+    )
+    encoder_embed_dim: Optional[int] = field(
+        default=768, metadata={"help": "encoder embedding dimension"}
+    )
 
     # masking
     apply_mask: bool = field(
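
Hoisting `conv_feature_layers` and `encoder_embed_dim` from `Wav2Vec2CtcConfig` into the shared `Wav2Vec2AsrConfig` (see the matching deletion hunk further down) makes them available to `Wav2Vec2Seq2Seq` as well. The layer spec is a Python list expression, as the help text says; a small sketch of interpreting the default:

```python
# Interpret the default conv_feature_layers string; each tuple is
# (dim, kernel_size, stride), as the help text above describes.
spec = "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]"
layers = eval(spec)  # the string is a Python list expression

total_stride = 1
for dim, kernel_size, stride in layers:
    total_stride *= stride

print(len(layers), total_stride)  # 7 conv layers, 320x downsampling of raw audio
```
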
@@ -92,6 +105,10 @@ class Wav2Vec2AsrConfig(FairseqDataclass):
     no_mask_overlap: bool = field(
         default=False, metadata={"help": "whether to allow masks to overlap"}
     )
+    mask_min_space: Optional[int] = field(
+        default=1,
+        metadata={"help": "min space between spans (if no overlap is enabled)"},
+    )
 
     # channel masking
     mask_channel_length: int = field(
@@ -123,6 +140,10 @@ class Wav2Vec2AsrConfig(FairseqDataclass):
     layerdrop: float = field(
         default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"}
     )
+    mask_channel_min_space: Optional[int] = field(
+        default=1,
+        metadata={"help": "min space between spans (if no overlap is enabled)"},
+    )
     mask_channel_before: bool = False
     normalize: bool = II("task.normalize")
     data: str = II("task.data")
@@ -134,27 +155,6 @@ class Wav2Vec2AsrConfig(FairseqDataclass):
 class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig):
     blank_weight: float = 0
     blank_mode: str = "add"
-    mask_min_space: Optional[int] = field(
-        default=1,
-        metadata={"help": "min space between spans (if no overlap is enabled)"},
-    )
-    mask_channel_min_space: Optional[int] = field(
-        default=1,
-        metadata={"help": "min space between spans (if no overlap is enabled)"},
-    )
-    conv_feature_layers: Optional[str] = field(
-        default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
-        metadata={
-            "help": (
-                "string describing convolutional feature extraction "
-                "layers in form of a python list that contains "
-                "[(dim, kernel_size, stride), ...]"
-            ),
-        },
-    )
-    encoder_embed_dim: Optional[int] = field(
-        default=768, metadata={"help": "encoder embedding dimension"}
-    )
 
 
 @register_model("wav2vec_ctc", dataclass=Wav2Vec2CtcConfig)
@@ -299,7 +299,7 @@ def build_decoder(cls, cfg: Wav2Vec2Seq2SeqConfig, tgt_dict, embed_tokens):
         return TransformerDecoder(cfg, tgt_dict, embed_tokens)
 
     def forward(self, **kwargs):
-        encoder_out = self.encoder(tbc=False, **kwargs)
+        encoder_out = self.encoder(**kwargs)
         decoder_out = self.decoder(encoder_out=encoder_out, **kwargs)
         return decoder_out
 
@@ -386,7 +386,8 @@ def set_num_updates(self, num_updates):
         super().set_num_updates(num_updates)
         self.num_updates = num_updates
 
-    def forward(self, source, padding_mask, tbc=True, **kwargs):
+    def forward(self, source, padding_mask, **kwargs):
+
         w2v_args = {
             "source": source,
             "padding_mask": padding_mask,
@@ -401,9 +402,8 @@ def forward(self, source, padding_mask, tbc=True, **kwargs):
             x = res["x"]
             padding_mask = res["padding_mask"]
 
-        if tbc:
-            # BTC -> TBC
-            x = x.transpose(0, 1)
+        # B x T x C -> T x B x C
+        x = x.transpose(0, 1)
 
         x = self.final_dropout(x)
 
@@ -412,21 +412,24 @@ def forward(self, source, padding_mask, **kwargs):
 
         return {
             "encoder_out": x,  # T x B x C
-            "encoder_padding_mask": padding_mask.transpose(0, 1)
-            if padding_mask is not None
-            else None,  # T x B
-            "padding_mask": padding_mask,
+            "padding_mask": padding_mask,  # B x T
             "layer_results": res["layer_results"],
         }
 
+    def forward_torchscript(self, net_input):
+        if torch.jit.is_scripting():
+            return self.forward(net_input["source"], net_input["padding_mask"])
+        else:
+            return self.forward_non_torchscript(net_input)
+
     def reorder_encoder_out(self, encoder_out, new_order):
         if encoder_out["encoder_out"] is not None:
             encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
                 1, new_order
             )
-        if encoder_out["encoder_padding_mask"] is not None:
-            encoder_out["encoder_padding_mask"] = encoder_out[
-                "encoder_padding_mask"
+        if encoder_out["padding_mask"] is not None:
+            encoder_out["padding_mask"] = encoder_out[
+                "padding_mask"
             ].index_select(0, new_order)
         return encoder_out
 
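The new `forward_torchscript` override exists because TorchScript cannot unpack `**kwargs`: the scripted branch passes the tensors explicitly, while the eager branch keeps the flexible dict path. A self-contained sketch of the pattern; `forward_non_torchscript` here is modeled on fairseq's base-class helper and is an assumption, since it does not appear in this diff:

```python
import torch
from typing import Dict

class TinyEncoder(torch.nn.Module):
    def forward(self, source: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
        return source.masked_fill(padding_mask, 0.0)

    @torch.jit.export
    def forward_torchscript(self, net_input: Dict[str, torch.Tensor]) -> torch.Tensor:
        if torch.jit.is_scripting():
            # Scripted path: explicit tensor arguments only.
            return self.forward(net_input["source"], net_input["padding_mask"])
        else:
            return self.forward_non_torchscript(net_input)

    @torch.jit.unused  # compiled to a stub; never called from scripted code
    def forward_non_torchscript(self, net_input: Dict[str, torch.Tensor]) -> torch.Tensor:
        return self.forward(**net_input)

enc = torch.jit.script(TinyEncoder())
out = enc.forward_torchscript({
    "source": torch.randn(2, 4),
    "padding_mask": torch.zeros(2, 4, dtype=torch.bool),
})
```
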
@@ -469,7 +472,7 @@ def __init__(
 
         self.layerdrop = cfg.decoder_layerdrop
 
-        padding_idx = embed_tokens.padding_idx
+        self.padding_idx = embed_tokens.padding_idx
         self.max_target_positions = cfg.max_target_positions
 
         self.embed_tokens = embed_tokens
@@ -485,7 +488,7 @@ def __init__(
             PositionalEmbedding(
                 cfg.max_target_positions,
                 embed_dim,
-                padding_idx,
+                self.padding_idx,
                 learned=cfg.decoder_learned_pos,
             )
             if not cfg.no_token_positional_embeddings
@@ -589,6 +592,9 @@ def extract_features(
         inner_states = [x]
 
         # decoder layers
+        self_attn_padding_mask = None
+        if prev_output_tokens.eq(self.padding_idx).any():
+            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
         for layer in self.layers:
             dropout_probability = np.random.random()
             if not self.training or (dropout_probability > self.layerdrop):
@@ -600,6 +606,7 @@ def extract_features(
                     self_attn_mask=self.buffered_future_mask(x)
                     if incremental_state is None
                     else None,
+                    self_attn_padding_mask=self_attn_padding_mask,
                 )
                 inner_states.append(x)
 
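
The decoder changes matter for ST because target sentences in a batch have unequal lengths; without a self-attention padding mask the decoder attends over pad positions. A toy illustration of the mask construction, using exactly the predicate from the diff (plain PyTorch; `1` is fairseq's conventional pad index):

```python
import torch

padding_idx = 1  # fairseq dictionaries use 1 for <pad> by default
prev_output_tokens = torch.tensor([
    [2, 45, 67, 89, 1, 1],    # shorter target, right-padded with <pad>
    [2, 12, 34, 56, 78, 90],  # full-length target
])

# Same logic as the lines added in extract_features():
self_attn_padding_mask = None
if prev_output_tokens.eq(padding_idx).any():
    self_attn_padding_mask = prev_output_tokens.eq(padding_idx)  # B x T, True at pads

print(self_attn_padding_mask)
# tensor([[False, False, False, False,  True,  True],
#         [False, False, False, False, False, False]])
```
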
(Diffs for the remaining changed files were not loaded on this page.)