This is the implementation of our paper DistilALHuBERT: A Distilled Parameter Sharing Audio Representation Model. This repository contains the pre-trained models and the code. To use our model, you can either add our code to S3PRL as an additional "Upstream model" or use it standalone.
Our model is implemented with S3PRL, a flexible toolkit for pre-trained speech models. S3PRL supports adding customized pre-trained models (called Upstream models) and customized pre-training methods.
To add our code to S3PRL, you should (the corresponding shell commands are sketched after this list):
- Make sure your S3PRL is installed in development (editable) mode.
- Copy src/upstream/alhubert to s3prl/upstream/alhubert.
- Copy src/upstream/hubert to s3prl/upstream/hubert, since we have changed some of the HuBERT implementation code in S3PRL to support multi-GPU training. Please remember to keep a backup of the original S3PRL files!
- Copy src/pretrain/alhubert to s3prl/pretrain/alhubert.
- Add from .upstream.alhubert.hubconf import * to s3prl/hub.py. You will find many similar import statements there; just append this one after the last of them.
- Copy config/alhubert/*.yaml to anywhere you like, e.g., s3prl/pretrain/alhubert.
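The list above maps to roughly the following shell commands. This is a minimal sketch: the S3PRL variable (pointing at the s3prl package directory inside your editable S3PRL checkout) is a placeholder, and the commands assume they are run from the root of this repository.
#!/bin/bash
set -ue
# If S3PRL is not yet installed in development (editable) mode:
#   git clone https://github.com/s3prl/s3prl.git && pip install -e ./s3prl
# Placeholder: the s3prl package directory inside your S3PRL checkout.
S3PRL=/path/to/s3prl/s3prl
# Keep a backup of S3PRL's original HuBERT upstream before overwriting it.
cp -r "$S3PRL/upstream/hubert" "$S3PRL/upstream/hubert.orig"
# Copy our upstream, pre-training, and config files into S3PRL.
cp -r src/upstream/alhubert "$S3PRL/upstream/"
cp -r src/upstream/hubert/. "$S3PRL/upstream/hubert/"
cp -r src/pretrain/alhubert "$S3PRL/pretrain/"
cp config/alhubert/*.yaml "$S3PRL/pretrain/alhubert/"
# Register the new upstream in S3PRL's hub.
echo "from .upstream.alhubert.hubconf import *" >> "$S3PRL/hub.py"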
We recommend adding our code to your S3PRL installation to evaluate our pre-trained models on all the downstream tasks.
To use our code without S3PRL, run pip install -r requirement.txt to install all the dependencies.
You can extract features from the pre-trained model as follows:
import torch
from src.upstream.alhubert.expert import UpstreamExpert
# when using in s3prl, you can use
# from s3prl.upstream.alhubert.expert import UpstreamExpert
model_ckpt_path = "small.ckpt"
model = UpstreamExpert(model_ckpt_path)
data = [torch.randn(10000) for _ in range(2)] # a batch of two 16 kHz waveforms
states = model(data)
print(states["last_hidden_state"].shape) # torch.Tensor: hidden state of the last layer
print(len(states["hidden_states"])) # list[torch.Tensor]: hidden states of each layer
# Note: if layer_norm_first=False (the default), "hidden_states" contains the outputs of transformer layers 0, 1, ..., 11.
# If layer_norm_first=True (for HuBERT Large teachers), "hidden_states" contains the output of the CNN feature extractor
# followed by the outputs of transformer layers 0, 1, ..., 10; in that case, the output of transformer layer 11 is in
# states["last_hidden_state"]. This is because the features after layer normalization are better targets for distillation.
The pre-trained models can be downloaded at:
| Model  | Link         |
|--------|--------------|
| small  | Google Drive |
| middle | Google Drive |
| large  | Google Drive |
To run the distillation yourself, take the small model (2×6 = 12 layers) as an example.
Prepare the data according to the instructions in S3PRL.
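For LibriSpeech, the audio length files used for bucketing are generated with S3PRL's preprocessing script, roughly as below; the script location and flags may differ across S3PRL versions, and the paths are placeholders, so check the S3PRL pre-training documentation.
# Run from the S3PRL source directory that contains preprocess/ (paths are placeholders).
python preprocess/generate_len_for_bucket.py \
    -i /mnt/data/LibriSpeech/ \
    -o /mnt/exp/len_for_bucket \
    -a .flac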
Edit s3prl/pretrain/alhubert/config_runner.yaml to set the dataset paths:
libri_root: '/mnt/data/LibriSpeech/'
# path to the LibriSpeech dataset,
# which contains folders like 'train-clean-100' or 'train-other-500'
file_path: '/mnt/exp/len_for_bucket'
# path to the audio length files generated in the previous data-preparation step
Edit s3prl/pretrain/alhubert/config_model_l2.yaml to add the path of the teacher model:
teacher:
  model: hubert_local
  path: "/path/to/teacher/model.pt"
We use the pre-trained HuBERT Base model provided by Hugging Face as the teacher.
Finally, perform the distillation:
#!/bin/bash
export OMP_NUM_THREADS=1
set -ue
name=l2
expdir=exp
python run_pretrain.py \
--upstream alhubert \
--upstream_config "pretrain/alhubert/config_model_l2.yaml" \
--config "pretrain/alhubert/config_runner.yaml" \
--expname $name \
--expdir $expdir/$name
If you have added our code to S3PRL, you can follow the official S3PRL instructions to evaluate our model on all the downstream tasks. Our model is registered as alhubert_local. For example, you can perform ASR fine-tuning with:
#!/bin/bash
set -ue
export OMP_NUM_THREADS=1
name="asr"
python3 run_downstream.py \
--config downstream/asr/config.yaml \
--upstream alhubert_local \
--upstream_feature_selection hidden_states \
--downstream asr \
--expname $name \
--mode train \
--upstream_ckpt "small.ckpt" \
--expdir /mnt/exp/$name
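After fine-tuning, evaluation follows S3PRL's usual downstream workflow. The sketch below uses S3PRL's standard evaluate mode with the -t/-e short flags and a placeholder checkpoint path; verify the exact options and the checkpoint file name against your S3PRL version and experiment directory.
#!/bin/bash
set -ue
export OMP_NUM_THREADS=1
# Evaluate the best checkpoint saved during ASR fine-tuning (placeholder path).
python3 run_downstream.py \
    --mode evaluate \
    -t "test-clean" \
    -e /mnt/exp/asr/dev-best.ckpt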
We have also added ASR fine-tuning code to this repository, so you can use similar commands to evaluate the pre-trained model on the ASR task without S3PRL. For other tasks, we still recommend using S3PRL's official implementations.
Most of the source code is based on s3prl and DistilHuBERT.