Whisper Fine-tuning Recipe on Aishell1 (#1466)
* add decode seamlessm4t

* add requirements

* add decoding with avg model

* add token files

* add custom tokenizer

* support deepspeed to finetune large model

* support large-v3

* add model saving

* using monkey patch to replace models

* add manifest dir option
yuekaizhang authored Jan 26, 2024
1 parent 8d39f95 commit 1c30847
Showing 14 changed files with 1,682 additions and 21 deletions.
7 changes: 7 additions & 0 deletions egs/aishell/ASR/README.md
@@ -24,3 +24,10 @@ The following table lists the differences among them.
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.

# Whisper

Recipe to fine-tune large pretrained models.
| | Encoder | Decoder | Comment |
|-----------|-------------|-------------|--------------------------------------|
| `whisper` | Transformer | Transformer | supports fine-tuning using DeepSpeed |
67 changes: 62 additions & 5 deletions egs/aishell/ASR/RESULTS.md
@@ -1,5 +1,63 @@
## Results

### Aishell training results (Fine-tuning Pretrained Models)
#### Whisper
[./whisper](./whisper)
##### Fine-tuning results (CER) on the Aishell test set for Whisper medium, large-v2, and large-v3

| Model | test (before fine-tuning) | test (after fine-tuning) | comment |
|------------------------|------|------|-----------------------------------------|
| medium | 7.23 | 3.27 | --epoch 10 --avg 4, ddp |
| large-v2 | 6.56 | 2.47 | --epoch 10 --avg 6, deepspeed zero stage1 |
| large-v3 | 6.06 | 2.84 | --epoch 5 --avg 3, deepspeed zero stage1 |
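
The `--avg N` entries refer to checkpoint averaging: the weights of the last `N` epoch checkpoints are averaged before decoding. Below is an illustrative sketch of the idea, not icefall's exact helper; it assumes each `epoch-*.pt` file holds a plain state dict of tensors, and the paths are placeholders.

```python
# Illustrative sketch of checkpoint averaging, not icefall's exact helper.
# Assumes each epoch-*.pt file holds a plain state_dict of tensors
# (icefall checkpoints may wrap it, e.g. under a "model" key).
import torch


def average_checkpoints(filenames):
    avg = torch.load(filenames[0], map_location="cpu")
    for f in filenames[1:]:
        state = torch.load(f, map_location="cpu")
        for k in avg:
            avg[k] = avg[k] + state[k]
    n = len(filenames)
    for k in avg:
        avg[k] = avg[k] / n if avg[k].is_floating_point() else avg[k] // n
    return avg


# e.g. the "--epoch 10 --avg 4" row averages epochs 7..10 (hypothetical paths):
# avg = average_checkpoints([f"whisper/exp_medium/epoch-{i}.pt" for i in range(7, 11)])
```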

Command for training is:
```bash
pip install -r whisper/requirements.txt

./prepare.sh --stage 30 --stop_stage 30

# fine-tuning with deepspeed zero stage 1
torchrun --nproc-per-node 8 ./whisper/train.py \
--max-duration 200 \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--deepspeed \
--deepspeed_config ./whisper/ds_config_zero1.json

# fine-tuning with ddp
torchrun --nproc-per-node 8 ./whisper/train.py \
--max-duration 200 \
--exp-dir whisper/exp_medium \
--base-lr 1e-5 \
--model-name medium
```
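
With `--deepspeed`, the training script hands the model to the DeepSpeed engine rather than wrapping it in DDP, and the optimizer, fp16, and ZeRO stage-1 settings come from the JSON config. A rough sketch of that pattern (not the recipe's exact code; `compute_loss` is a hypothetical helper):

```python
# Rough sketch of the DeepSpeed setup behind --deepspeed; not the recipe's exact code.
import deepspeed
import whisper

model = whisper.load_model("large-v2")  # start from the pretrained weights

# The engine owns the optimizer, fp16 and ZeRO stage-1 partitioning,
# all configured through ds_config_zero1.json.
model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="./whisper/ds_config_zero1.json",
)

# Inside the training loop (compute_loss is a hypothetical helper):
# loss = compute_loss(model_engine, batch)
# model_engine.backward(loss)
# model_engine.step()
```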

Command for decoding using fine-tuned models:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt

python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch 999 --avg 1 \
--beam-size 10 --max-duration 50
```
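
The symlink above simply gives the released averaged checkpoint a name that `--epoch 999 --avg 1` will pick up. Loading it by hand looks roughly like the sketch below, assuming the file stores a plain state dict matching the large-v2 architecture (the actual wrapping may differ):

```python
# Sketch of loading the averaged checkpoint by hand; assumes epoch-999.pt
# stores a plain state_dict matching the large-v2 architecture.
import torch
import whisper

model = whisper.load_model("large-v2")
state_dict = torch.load("whisper/exp_large_v2/epoch-999.pt", map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model.eval()
```
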
Command for decoding using pretrained models (before fine-tuning):
```bash
python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch -1 --avg 1 \
--remove-whisper-encoder-input-length-restriction False \
--beam-size 10 --max-duration 50
```
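
The `--remove-whisper-encoder-input-length-restriction False` flag appears to relate to the monkey patch mentioned in the commit message: the stock OpenAI encoder expects every mel spectrogram padded or trimmed to 30 s (3000 frames), while the patched encoder used for fine-tuning accepts features at their true length. Setting the flag to `False` keeps the original padding behaviour, which is what the unmodified pretrained weights expect. A sketch of that padding convention:

```python
# The padding convention the stock Whisper encoder expects: every mel
# spectrogram is padded or trimmed to 3000 frames (30 s at a 10 ms hop).
import torch
import torch.nn.functional as F


def pad_or_trim_mel(mel: torch.Tensor, target_frames: int = 3000) -> torch.Tensor:
    """Pad with zeros or trim a (num_mels, num_frames) tensor to 3000 frames."""
    num_frames = mel.shape[-1]
    if num_frames < target_frames:
        return F.pad(mel, (0, target_frames - num_frames))
    return mel[..., :target_frames]
```
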
Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
are available at
<https://huggingface.co/yuekai/icefall_asr_aishell_whisper>

### Aishell training result (Stateless Transducer)

#### Zipformer (Byte-level BPE)
@@ -71,7 +129,7 @@ It's reworked Zipformer with Pruned RNNT loss.

Command for training is:
```bash
./prepare.sh

export CUDA_VISIBLE_DEVICES="0,1"

@@ -136,7 +194,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--max-duration 1200
```

Command for decoding is:
@@ -186,7 +244,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--max-duration 800
```

Command for decoding is:
@@ -202,7 +260,7 @@ for m in greedy_search modified_beam_search fast_beam_search ; do
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192
done
```

@@ -755,7 +813,6 @@ python3 ./transducer_stateless/decode.py \
--max-sym-per-frame 3
```

### Aishell training results (Transducer-stateless)
#### 2022-02-18
(Pingfeng Luo) : The tensorboard log for training is available at <https://tensorboard.dev/experiment/k3QL6QMhRbCwCKYKM9po9w/>
And pretrained model is available at <https://huggingface.co/pfluo/icefall-aishell-transducer-stateless-char-2021-12-29>
45 changes: 38 additions & 7 deletions egs/aishell/ASR/local/compute_fbank_aishell.py
@@ -29,7 +29,14 @@
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool
@@ -42,9 +49,14 @@
torch.set_num_interop_threads(1)


def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
def compute_fbank_aishell(
    num_mel_bins: int = 80,
    perturb_speed: bool = False,
    whisper_fbank: bool = False,
    output_dir: str = "data/fbank",
):
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")
    output_dir = Path(output_dir)
    num_jobs = min(15, os.cpu_count())

    dataset_parts = (
@@ -68,8 +80,12 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
        list(manifests.keys()),
        dataset_parts,
    )

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
    if whisper_fbank:
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
        )
    else:
        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
Expand All @@ -82,7 +98,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
                supervisions=m["supervisions"],
            )
            if "train" in partition and perturb_speed:
                logging.info(f"Doing speed perturb")
                logging.info("Doing speed perturb")
                cut_set = (
                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                )
@@ -111,6 +127,18 @@ def get_args():
        default=False,
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )
    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="data/fbank",
        help="Output directory. Default: data/fbank.",
    )
    return parser.parse_args()


@@ -121,5 +149,8 @@ def get_args():

    args = get_args()
    compute_fbank_aishell(
        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
        num_mel_bins=args.num_mel_bins,
        perturb_speed=args.perturb_speed,
        whisper_fbank=args.whisper_fbank,
        output_dir=args.output_dir,
    )
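
For reference, the new `--whisper-fbank` path boils down to extracting features with Lhotse's `WhisperFbank` instead of the default `Fbank`. A standalone sketch is below; the paths are placeholders, and the real script also handles speed perturbation, manifest caching, and job dispatch.

```python
# Standalone sketch of the --whisper-fbank path with Lhotse; paths are placeholders.
from lhotse import CutSet, LilcomChunkyWriter, WhisperFbank, WhisperFbankConfig

cuts = CutSet.from_file("data/manifests/aishell_cuts_test.jsonl.gz")
extractor = WhisperFbank(WhisperFbankConfig(num_filters=80, device="cuda"))

cuts = cuts.compute_and_store_features(
    extractor=extractor,
    storage_path="data/fbank_whisper/aishell_feats_test",
    num_jobs=1,  # GPU-based extraction; the real script dispatches jobs via get_executor()
    storage_type=LilcomChunkyWriter,
)
cuts.to_file("data/fbank_whisper/aishell_cuts_test.jsonl.gz")
```
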
13 changes: 13 additions & 0 deletions egs/aishell/ASR/prepare.sh
@@ -376,3 +376,16 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
--vocab-size 4336 \
--master-port 12345
fi

# whisper large-v3 using 128 mel bins, others using 80 mel bins
whisper_mel_bins=80
output_dir=data/fbank_whisper
if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
  log "Stage 30: Compute ${whisper_mel_bins} dim fbank for whisper model fine-tuning"
  if [ ! -f $output_dir/.aishell.whisper.done ]; then
    mkdir -p $output_dir
    ./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir
    ./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir
    touch $output_dir/.aishell.whisper.done
  fi
fi
1 change: 1 addition & 0 deletions egs/aishell/ASR/whisper/asr_datamodule.py