Skip to content

Commit

Permalink
update speechio whisper ft results (#1605)
Browse files Browse the repository at this point in the history
* update speechio whisper ft results
  • Loading branch information
yuekaizhang authored Apr 30, 2024
1 parent b49351f commit 6d7c1d1
Show file tree
Hide file tree
Showing 14 changed files with 1,812 additions and 207 deletions.
43 changes: 43 additions & 0 deletions egs/multi_zh-hans/ASR/RESULTS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,48 @@
## Results

### Multi Chinese datasets (without datatang 200h) finetuning results on Whisper-large-v2
#### Whisper
[./whisper](./whisper)

Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search.

| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------|
| Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting |
| Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 |

The command for training is:
```bash
pip install -r whisper/requirements.txt

# We updated the label of wenetspeech to remove OCR deletion errors, see https://github.com/wenet-e2e/WenetSpeech/discussions/54

torchrun --nproc-per-node 8 ./whisper/train.py \
--max-duration 200 \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--deepspeed \
--deepspeed_config ./whisper/ds_config_zero1.json
```

The command for decoding with the fine-tuned models is:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper
ln -s icefall_asr_multi-hans-zh_whisper/v1.1/epoch-3-avg-10.pt whisper/exp_large_v2/epoch-999.pt

python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch 999 --avg 1 \
--beam-size 10 --max-duration 50
```

Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
are available at
<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>


### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model

This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall.
Expand Down
4 changes: 2 additions & 2 deletions egs/multi_zh-hans/ASR/prepare.sh
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Prepare WenetSpeech"
if [ -e ../../wenetspeech/ASR/data/fbank/.preprocess_complete ]; then
cd data/fbank
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV_fixed.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L_fixed.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .

Expand Down
50 changes: 49 additions & 1 deletion egs/multi_zh-hans/ASR/whisper/decode.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert

Expand Down Expand Up @@ -214,7 +215,7 @@ def get_parser():
"--model-name",
type=str,
default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"],
help="""The model name to use.
""",
)
Expand All @@ -226,6 +227,13 @@ def get_parser():
help="replace whisper encoder forward method to remove input length restriction",
)

parser.add_argument(
"--use-distill-whisper",
type=str2bool,
default=False,
help="Whether to use architecture of distill whisper.",
)

return parser


Expand Down Expand Up @@ -307,6 +315,43 @@ def decode_dataset(
Returns:
Return a dict, whose key may be "beam-search".
"""

def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
    """Normalize an AliMeeting-style transcript before scoring.

    Text normalization similar to M2MeT challenge baseline.
    See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl

    Args:
      text:
        The transcript to normalize.
      normalize:
        "none" returns ``text`` unchanged. "m2met" removes spaces,
        annotation tags and punctuation, upper-cases the text when it
        contains ASCII letters, and maps the full-width letters
        A/B/C/K/T to their ASCII upper-case form.

    Returns:
      The normalized transcript.

    Raises:
      ValueError: If ``normalize`` is neither "none" nor "m2met".
    """
    if normalize == "none":
        return text
    if normalize != "m2met":
        # Previously an unknown mode fell through and returned None, which
        # only failed later, far away from the actual mistake.
        raise ValueError(
            f"Unsupported normalize mode: {normalize!r} (expected 'none' or 'm2met')"
        )

    import re

    # Spaces, annotation tags and ASCII punctuation are removed first, in the
    # same order as before.
    for token in (
        " ",
        "<sil>",
        "<%>",
        "<->",
        "<$>",
        "<#>",
        "<_>",
        "<space>",
        "`",
        "&",
        ",",
    ):
        text = text.replace(token, "")

    if re.search("[a-zA-Z]", text):
        text = text.upper()

    # Full-width Latin letters, written as escapes so that a Unicode
    # compatibility normalization of this file cannot silently turn them into
    # ASCII no-ops. Both cases are listed because str.upper() converts a
    # full-width lower-case letter into its full-width upper-case form.
    for wide, ascii_upper in (
        ("\uff21", "A"),
        ("\uff41", "A"),
        ("\uff22", "B"),
        ("\uff42", "B"),
        ("\uff23", "C"),
        ("\uff43", "C"),
        ("\uff2b", "K"),
        ("\uff4b", "K"),
        ("\uff34", "T"),
        ("\uff54", "T"),
    ):
        text = text.replace(wide, ascii_upper)

    # CJK punctuation: full-width comma (U+FF0C), "丶", "。", "、", and both
    # the ASCII and the full-width (U+FF1F) question mark.
    for mark in ("\uff0c", "丶", "。", "、", "?", "\uff1f"):
        text = text.replace(mark, "")
    return text

results = []

num_cuts = 0
Expand All @@ -331,6 +376,7 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = normalize_text_alimeeting(ref_text)
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))

Expand Down Expand Up @@ -430,6 +476,8 @@ def main():

if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
if params.use_distill_whisper:
replace_whisper_decoder_forward()
model = whisper.load_model(params.model_name, "cpu")
if params.epoch > 0:
if params.avg > 1:
Expand Down
91 changes: 21 additions & 70 deletions egs/multi_zh-hans/ASR/whisper/multi_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, fbank_dir: str):
- thchs_30_cuts_train.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
- wenetspeech/cuts_L.jsonl.gz
- wenetspeech/cuts_L_fixed.jsonl.gz
"""
self.fbank_dir = Path(fbank_dir)

Expand Down Expand Up @@ -105,7 +105,7 @@ def train_cuts(self) -> CutSet:
# WeNetSpeech
logging.info("Loading WeNetSpeech in lazy mode")
wenetspeech_L_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz"
self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
)

# KeSpeech
Expand All @@ -124,10 +124,10 @@ def train_cuts(self) -> CutSet:
aishell_4_L_cuts,
aishell_4_M_cuts,
aishell_4_S_cuts,
alimeeting_cuts,
stcmds_cuts,
primewords_cuts,
magicdata_cuts,
alimeeting_cuts,
wenetspeech_L_cuts,
kespeech_1_cuts,
kespeech_2_cuts,
Expand All @@ -138,10 +138,10 @@ def train_cuts(self) -> CutSet:
len(aishell_4_L_cuts),
len(aishell_4_M_cuts),
len(aishell_4_S_cuts),
len(alimeeting_cuts),
len(stcmds_cuts),
len(primewords_cuts),
len(magicdata_cuts),
len(alimeeting_cuts),
len(wenetspeech_L_cuts),
len(kespeech_1_cuts),
len(kespeech_2_cuts),
Expand All @@ -151,55 +151,13 @@ def train_cuts(self) -> CutSet:
def dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")

# AISHELL
logging.info("Loading Aishell DEV set in lazy mode")
aishell_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
)

# AISHELL-2
logging.info("Loading Aishell-2 DEV set in lazy mode")
aishell2_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
)

# Ali-Meeting
logging.info("Loading Ali-Meeting DEV set in lazy mode")
alimeeting_dev_cuts = load_manifest_lazy(
self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
)

# MagicData
logging.info("Loading MagicData DEV set in lazy mode")
magicdata_dev_cuts = load_manifest_lazy(
self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
)

# KeSpeech
logging.info("Loading KeSpeech DEV set in lazy mode")
kespeech_dev_phase1_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
)
kespeech_dev_phase2_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
)

# WeNetSpeech
logging.info("Loading WeNetSpeech DEV set in lazy mode")
wenetspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
)

return wenetspeech_dev_cuts
# return [
# aishell_dev_cuts,
# aishell2_dev_cuts,
# alimeeting_dev_cuts,
# magicdata_dev_cuts,
# kespeech_dev_phase1_cuts,
# kespeech_dev_phase2_cuts,
# wenetspeech_dev_cuts,
# ]

def test_cuts(self) -> Dict[str, CutSet]:
logging.info("About to get multidataset test cuts")
Expand Down Expand Up @@ -267,30 +225,23 @@ def test_cuts(self) -> Dict[str, CutSet]:
self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
)
wenetspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
)

return {
"aishell-2_test": aishell2_test_cuts,
"aishell-4": aishell4_test_cuts,
"magicdata_test": magicdata_test_cuts,
"kespeech-asr_test": kespeech_test_cuts,
"wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
# "aishell_test": aishell_test_cuts,
# "aishell_dev": aishell_dev_cuts,
# "ali-meeting_test": alimeeting_test_cuts,
# "ali-meeting_eval": alimeeting_eval_cuts,
# "aishell-4_test": aishell4_test_cuts,
# "aishell-2_test": aishell2_test_cuts,
# "aishell-2_dev": aishell2_dev_cuts,
# "magicdata_test": magicdata_test_cuts,
# "magicdata_dev": magicdata_dev_cuts,
# "kespeech-asr_test": kespeech_test_cuts,
# "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
# "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
# "wenetspeech-net_test": wenetspeech_test_net_cuts,
# "wenetspeech_dev": wenetspeech_dev_cuts,
}

# return {
# "alimeeting_test": alimeeting_test_cuts,
# "alimeeting_eval": alimeeting_eval_cuts,
# "aishell_test": aishell_test_cuts,
# "aishell_dev": aishell_dev_cuts,
# "aishell-2_test": aishell2_test_cuts,
# "aishell-2_dev": aishell2_dev_cuts,
# "aishell-4": aishell4_test_cuts,
# "magicdata_test": magicdata_test_cuts,
# "magicdata_dev": magicdata_dev_cuts,
# "kespeech-asr_test": kespeech_test_cuts,
# "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
# "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
# "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
# "wenetspeech-net_test": wenetspeech_test_net_cuts,
# "wenetspeech_dev": wenetspeech_dev_cuts,
# }
50 changes: 48 additions & 2 deletions egs/multi_zh-hans/ASR/whisper/train.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from torch.nn.functional import pad as pad_tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

from icefall import diagnostics
Expand Down Expand Up @@ -146,7 +147,7 @@ def get_parser():
"--model-name",
type=str,
default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"],
help="""The model name to use.
""",
)
Expand Down Expand Up @@ -232,6 +233,13 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--use-distill-whisper",
type=str2bool,
default=False,
help="Whether to use architecture of distill whisper.",
)

parser = deepspeed.add_config_arguments(parser)

return parser
Expand Down Expand Up @@ -441,6 +449,42 @@ def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:
padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
return torch.stack([tensor for tensor in padded_tensors], dim=0)

def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
    """Normalize an AliMeeting-style transcript before tokenization.

    Text normalization similar to M2MeT challenge baseline.
    See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl

    Args:
      text:
        The transcript to normalize.
      normalize:
        "none" returns ``text`` unchanged. "m2met" removes spaces,
        annotation tags and punctuation, upper-cases the text when it
        contains ASCII letters, and maps the full-width letters
        A/B/C/K/T to their ASCII upper-case form.

    Returns:
      The normalized transcript.

    Raises:
      ValueError: If ``normalize`` is neither "none" nor "m2met".
    """
    if normalize == "none":
        return text
    if normalize != "m2met":
        # An unknown mode used to fall through and return None; fail loudly
        # here instead of letting None reach the tokenizer.
        raise ValueError(
            f"Unsupported normalize mode: {normalize!r} (expected 'none' or 'm2met')"
        )

    import re

    # Removal order matches the original statement order.
    removable_tokens = (
        " ",
        "<sil>",
        "<%>",
        "<->",
        "<$>",
        "<#>",
        "<_>",
        "<space>",
        "`",
        "&",
        ",",
    )
    for token in removable_tokens:
        text = text.replace(token, "")

    if re.search("[a-zA-Z]", text):
        text = text.upper()

    # Full-width Latin letters are spelled as escapes so they survive any
    # Unicode compatibility normalization applied to this file. Upper- and
    # lower-case forms are both mapped because str.upper() turns a full-width
    # lower-case letter into a full-width upper-case one.
    fullwidth_to_ascii = (
        ("\uff21", "A"),
        ("\uff41", "A"),
        ("\uff22", "B"),
        ("\uff42", "B"),
        ("\uff23", "C"),
        ("\uff43", "C"),
        ("\uff2b", "K"),
        ("\uff4b", "K"),
        ("\uff34", "T"),
        ("\uff54", "T"),
    )
    for wide, ascii_upper in fullwidth_to_ascii:
        text = text.replace(wide, ascii_upper)

    # CJK punctuation: full-width comma (U+FF0C), "丶", "。", "、", and both
    # the ASCII and the full-width (U+FF1F) question mark.
    for mark in ("\uff0c", "丶", "。", "、", "?", "\uff1f"):
        text = text.replace(mark, "")
    return text

max_frames = params.max_duration * 1000 // params.frame_shift_ms
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
Expand All @@ -459,7 +503,7 @@ def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:

texts = batch["supervisions"]["text"]
# normalize texts: strip spaces, annotation tags and punctuation
texts = [text.replace(" ", "") for text in texts]
texts = [normalize_text_alimeeting(text) for text in texts]

text_tokens_list = [
list(tokenizer.sot_sequence_including_notimestamps)
Expand Down Expand Up @@ -759,6 +803,8 @@ def run(rank, world_size, args):
logging.info("About to create model")

replace_whisper_encoder_forward()
if params.use_distill_whisper:
replace_whisper_decoder_forward()
model = whisper.load_model(params.model_name, "cpu")
del model.alignment_heads

Expand Down
Loading

0 comments on commit 6d7c1d1

Please sign in to comment.