
Transformer Distiller

Introduction

The Transformer Distiller is a framework for knowledge distillation on Transformer-based models, inspired by TextBrewer. It currently implements the distillation methods presented in TinyBERT, MiniLM, and Patient-KD.

Part of this work comes from my master's thesis, in which I propose a two-step distillation method that separates prediction-layer and intermediate-layer distillation and show experimentally that it produces a robust and effective student model.

Workflows

Update the Transformers library

We modify the Transformers source code so that models expose some intermediate behaviors; see the README.md for more details. A quick patch to the installed library can be applied with

python update_transformers.py

Finetune

Knowledge distillation usually starts with fine-tuning a teacher model, and the finetuner module is provided for this first step. The short script below shows how to use it to fine-tune a BERT model on the tweet_eval dataset.

from datasets import load_dataset
from pytorch_lightning import Trainer
from data.data_module import ClfDataModule
from helper.finetuner import ClfFinetune, HgCkptIO
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import AutoModelForSequenceClassification

model_name = 'bert-base-uncased'

dataset = load_dataset('tweet_eval', 'sentiment')
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
dm = ClfDataModule(dataset, tokenizer=model_name)

finetuner = ClfFinetune(model, dm)

trainer = Trainer(
    plugins=[HgCkptIO()],
    callbacks=[ModelCheckpoint(monitor='val_loss', mode='min')]
)

trainer.fit(finetuner, dm)

In finetune.py you can see how fine-tuning experiments are configured through finetune.yaml.
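
Purely as an illustration, a config-driven version of the setup above might look like the sketch below; the keys (model, dataset, subset, num_labels) are hypothetical placeholders, and the real schema is the one defined in finetune.yaml.

import yaml
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification

from data.data_module import ClfDataModule

# Hypothetical keys, for illustration only; consult finetune.yaml for the real schema.
with open('finetune.yaml') as f:
    cfg = yaml.safe_load(f)  # e.g. {'model': 'bert-base-uncased', 'dataset': 'tweet_eval', 'subset': 'sentiment', 'num_labels': 3}

dataset = load_dataset(cfg['dataset'], cfg['subset'])
model = AutoModelForSequenceClassification.from_pretrained(cfg['model'], num_labels=cfg['num_labels'])
dm = ClfDataModule(dataset, tokenizer=cfg['model'])
# The finetuner and Trainer are then built exactly as in the script above.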

Mix-Step Distillation

Mix-step distillation refers to the original distillation setup for Transformer models, in which the intermediate layers and the prediction layer are distilled together through the defined adaptors. The distiller module performs the distillation; the short script below shows how to use it to distill a fine-tuned BERT model into a compact student model.

from datasets import load_dataset
from pytorch_lightning import Trainer
from data.data_module import ClfDataModule
from helper.distiller import BaseDistiller, HgCkptIO
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import AutoModelForSequenceClassification

def get_model(name, num_labels):
    model = AutoModelForSequenceClassification.from_pretrained(name, num_labels=num_labels)
    model.config.output_attentions = True
    model.config.output_hidden_states = True
    model.config.output_values = True

    return model

teacher_model = '/path-to-teacher/'
student_model = 'huawei-noah/TinyBERT_General_4L_312D'

adaptors = [
  'AttnMiniLM',
  'ValMiniLM',
  'LogitMSE',
]

dataset = load_dataset('tweet_eval', 'sentiment')
teacher = get_model(teacher_model, num_labels=3)
student = get_model(student_model, num_labels=3)

dm = ClfDataModule(dataset, tokenizer=teacher_model)

distiller = BaseDistiller(teacher, student, adaptors, dm)

trainer = Trainer(
    plugins=[HgCkptIO()],
    callbacks=[ModelCheckpoint(monitor='val_loss', mode='min')]
)

trainer.fit(distiller, dm)

In distillation_ms.py you can see how mix-step distillation experiments are configured through distillation.yaml.
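
Again purely as an illustration with hypothetical keys, the adaptor combination for a mix-step run could be read from the config and handed to BaseDistiller:

import yaml

# Hypothetical key, for illustration only; consult distillation.yaml for the real schema.
with open('distillation.yaml') as f:
    cfg = yaml.safe_load(f)  # e.g. {'adaptors': ['AttnMiniLM', 'ValMiniLM', 'LogitMSE'], ...}

adaptors = cfg['adaptors']
# teacher, student, and dm are built as in the script above, then:
# distiller = BaseDistiller(teacher, student, adaptors, dm)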

Two-Step Distillation

Two-step distillation is the method we propose in this project to alleviate over-fitting. It splits distillation into two steps: the first step updates only the intermediate layers, while the second step updates only the prediction layer. The example below shows how to perform the two-step method.

from datasets import load_dataset
from pytorch_lightning import Trainer
from data.data_module import ClfDataModule
from helper.distiller import InterDistiller, PredDistiller, HgCkptIO
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import AutoModelForSequenceClassification

def get_model(name, num_labels):
    model = AutoModelForSequenceClassification.from_pretrained(name, num_labels=num_labels)
    model.config.output_attentions = True
    model.config.output_hidden_states = True
    model.config.output_values = True

    return model

teacher_model = '/path-to-teacher/'
student_model = 'huawei-noah/TinyBERT_General_4L_312D'

adaptors_1 = ['AttnMiniLM', 'ValMiniLM']
adaptors_2 = ['LogitMSE']

dataset = load_dataset('tweet_eval', 'sentiment')
teacher = get_model(teacher_model, num_labels=3)
student = get_model(student_model, num_labels=3)

dm = ClfDataModule(dataset, tokenizer=teacher_model)

distiller_1 = InterDistiller(teacher, student, adaptors_1, dm)

trainer_1 = Trainer(
    plugins=[HgCkptIO()],
    callbacks=[ModelCheckpoint(dirpath='student_1st', monitor='val_loss', mode='min', save_last=True)]
)

trainer_1.fit(distiller_1, dm)

distiller_2 = PredDistiller(teacher, student, adaptors_2, dm)

trainer_2 = Trainer(
    plugins=[HgCkptIO()],
    callbacks=[ModelCheckpoint(dirpath='student_2nd', monitor='val_loss', mode='min', save_last=True)]
)

trainer_2.fit(distiller_2, dm)

To monitor the experiments with wandb, the two steps have to be separated into two files, distillation_ts_1st.py and distillation_ts_2nd.py; running both steps in a single script leads to unexpected behavior.
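
Assuming PyTorch Lightning's standard WandbLogger (not a helper shipped with this repository), each script attaches its own logger so that the two steps are tracked as separate runs; the project and run names below are placeholders.

# distillation_ts_1st.py (sketch): give the intermediate-layer step its own wandb run.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from helper.distiller import HgCkptIO

wandb_logger = WandbLogger(project='transformer-distiller', name='two-step-1st')  # placeholder names

trainer_1 = Trainer(
    logger=wandb_logger,
    plugins=[HgCkptIO()],
    callbacks=[ModelCheckpoint(dirpath='student_1st', monitor='val_loss', mode='min', save_last=True)]
)
# Build distiller_1 and dm as above, then call trainer_1.fit(distiller_1, dm).
# distillation_ts_2nd.py creates its own WandbLogger and Trainer in the same way,
# so the prediction-layer step appears as a second, independent run.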

Adaptor Options

An adaptor is a module that defines how the loss between a student layer and a teacher layer is calculated. The options currently available in the configuration file are listed below, followed by a rough sketch of two of the losses:

  • AttnMiniLM: KLD between attention matrices
  • AttnTinyBERT: MSE between attention matrices
  • ValMiniLM: KLD between value matrices
  • HidnTinyBERT: MSE between hidden states
  • HidnPKD: MSE between the [CLS] token vectors
  • LogitMSE: MSE between logits after softmax
  • LogitCE: CE between logits after softmax
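
As a framework-independent illustration only (not the repository's actual implementation, and with assumed tensor shapes), the LogitMSE and AttnMiniLM losses could be written as:

import torch
import torch.nn.functional as F

def logit_mse(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    """LogitMSE-style loss: MSE between the softmaxed student and teacher logits."""
    return F.mse_loss(F.softmax(student_logits, dim=-1),
                      F.softmax(teacher_logits, dim=-1))

def attn_minilm(student_attn: torch.Tensor, teacher_attn: torch.Tensor) -> torch.Tensor:
    """AttnMiniLM-style loss: KL divergence KL(teacher || student) between attention
    distributions, assumed to be shaped (batch, heads, seq, seq) and already
    normalized over the last dimension."""
    eps = 1e-12
    kld = teacher_attn * (torch.log(teacher_attn + eps) - torch.log(student_attn + eps))
    return kld.sum(dim=-1).mean()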

Citations

TextBrewer

@InProceedings{textbrewer-acl2020-demo,
    title = "{T}ext{B}rewer: {A}n {O}pen-{S}ource {K}nowledge {D}istillation {T}oolkit for {N}atural {L}anguage {P}rocessing",
    author = "Yang, Ziqing and Cui, Yiming and Chen, Zhipeng and Che, Wanxiang and Liu, Ting and Wang, Shijin and Hu, Guoping",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations",
    year = "2020",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.acl-demos.2",
    pages = "9--16",
}

TinyBERT

@article{jiao2019tinybert,
  title={Tinybert: Distilling bert for natural language understanding},
  author={Jiao, Xiaoqi and Yin, Yichun and Shang, Lifeng and Jiang, Xin and Chen, Xiao and Li, Linlin and Wang, Fang and Liu, Qun},
  journal={arXiv preprint arXiv:1909.10351},
  year={2019}
}

MiniLM

@article{wang2020minilm,
  title={Minilm: Deep self-attention distillation for task-agnostic compression of pre-trained transformers},
  author={Wang, Wenhui and Wei, Furu and Dong, Li and Bao, Hangbo and Yang, Nan and Zhou, Ming},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  pages={5776--5788},
  year={2020}
}

Patient Knowledge Distillation

@article{sun2019patient,
  title={Patient knowledge distillation for bert model compression},
  author={Sun, Siqi and Cheng, Yu and Gan, Zhe and Liu, Jingjing},
  journal={arXiv preprint arXiv:1908.09355},
  year={2019}
}
