
English | 中文说明

<p align="center"> <br> <img src="./pics/banner.png" width="500"/> <br> </p> <p align="center"> <a href="https://github.com/airaria/TextBrewer/blob/master/LICENSE"> <img alt="GitHub" src="https://img.shields.io/github/license/airaria/TextBrewer.svg?color=blue&style=flat-square"> </a> <a href="https://textbrewer.readthedocs.io/"> <img alt="Documentation" src="https://img.shields.io/website?down_message=offline&label=Documentation&up_message=online&url=https%3A%2F%2Ftextbrewer.readthedocs.io"> </a> <a href="https://pypi.org/project/textbrewer"> <img alt="PyPI" src="https://img.shields.io/pypi/v/textbrewer"> </a> <a href="https://github.com/airaria/TextBrewer/releases"> <img alt="GitHub release" src="https://img.shields.io/github/v/release/airaria/TextBrewer?include_prereleases"> </a> </p>

TextBrewer is a PyTorch-based model distillation toolkit for natural language processing. It includes various distillation techniques from both the NLP and CV fields and provides an easy-to-use distillation framework that allows users to quickly experiment with state-of-the-art distillation methods, compressing models with a relatively small sacrifice in performance while increasing inference speed and reducing memory usage.

Check out our paper in the ACL Anthology or the arXiv pre-print.

Full Documentation

News

Dec 17, 2021

Oct 24, 2021

Jul 8, 2021

Mar 1, 2021

<details> <summary>Click here to see old news</summary>

Nov 11, 2020

Aug 27, 2020

We are happy to announce that our model is on top of the GLUE benchmark; check the leaderboard.

Aug 24, 2020

Jul 29, 2020

Jul 14, 2020

Apr 26, 2020

Apr 22, 2020

Mar 17, 2020

Mar 11, 2020

Mar 2, 2020

</details>

Table of Contents

<!-- TOC -->
| Section | Contents |
|---------|----------|
| Introduction | Introduction to TextBrewer |
| Installation | How to install |
| Workflow | Two stages of TextBrewer workflow |
| Quickstart | Example: distilling BERT-base to a 3-layer BERT |
| Experiments | Distillation experiments on typical English and Chinese datasets |
| Core Concepts | Brief explanations of the core concepts in TextBrewer |
| FAQ | Frequently asked questions |
| Known Issues | Known issues |
| Citation | Citation to TextBrewer |
| Follow Us | - |
<!-- /TOC -->

Introduction

TextBrewer is designed for knowledge distillation of NLP models. It provides various distillation methods and offers a distillation framework for quickly setting up experiments.

The main features of TextBrewer are:

TextBrewer currently ships with the following distillation techniques:

TextBrewer includes:

  1. Distillers: the cores of distillation. Different distillers perform different distillation modes. There are GeneralDistiller, MultiTeacherDistiller, BasicTrainer, etc.
  2. Configurations and presets: Configuration classes for training and distillation, and predefined distillation loss functions and strategies.
  3. Utilities: auxiliary tools such as model parameters analysis.
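
As a quick orientation, here is a minimal sketch of where these pieces live. The names follow the list above and the Quickstart below; whether every distiller is exported at the top level of the package should be checked against the full documentation.

```python
# Minimal sketch of the toolkit's main components (names taken from the list
# above and the Quickstart example below; check the docs for the full list).
import textbrewer
from textbrewer import GeneralDistiller, MultiTeacherDistiller, BasicTrainer  # distillers
from textbrewer import TrainingConfig, DistillationConfig                     # configurations
# Utilities, e.g. parameter statistics:
# textbrewer.utils.display_parameters(model, max_level=3)
```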

To start distillation, users need to provide

  1. the models (the trained teacher model and the un-trained student model)
  2. datasets and experiment configurations

TextBrewer has achieved impressive results on several typical NLP tasks. See Experiments.

See Full Documentation for detailed usages.

Architecture

Installation

Workflow

Quickstart

Here we show the usage of TextBrewer by distilling BERT-base to a 3-layer BERT.

Before distillation, we assume users have provided:

  1. a trained `teacher_model` and an initialized `student_model` (their outputs are assumed to contain the logits and hidden states, as indexed in `simple_adaptor` below);
  2. the `dataloader` of the dataset, the `optimizer`, and the learning rate `scheduler_class` and `scheduler_args`.

Distill with TextBrewer:

```python
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig

# Show the statistics of model parameters
print("\nteacher_model's parameters:")
result, _ = textbrewer.utils.display_parameters(teacher_model, max_level=3)
print(result)

print("student_model's parameters:")
result, _ = textbrewer.utils.display_parameters(student_model, max_level=3)
print(result)

# Define an adaptor for interpreting the model inputs and outputs
def simple_adaptor(batch, model_outputs):
    # The second and third elements of the model outputs are the logits and hidden states
    return {'logits': model_outputs[1],
            'hidden': model_outputs[2]}

# Training configuration
train_config = TrainingConfig()
# Distillation configuration
# Matching different layers of the student and the teacher
distill_config = DistillationConfig(
    intermediate_matches=[
        {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
        {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1}])

# Build distiller
distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)

# Start!
with distiller:
    distiller.train(optimizer, dataloader, num_epochs=1,
                    scheduler_class=scheduler_class, scheduler_args=scheduler_args,
                    callback=None)
```
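
For completeness, below is a hedged sketch of what the user-provided objects referenced above (`teacher_model`, `student_model`, `optimizer`, `dataloader`, `scheduler_class`, `scheduler_args`) might look like. The use of HuggingFace `transformers`, the model classes, and the hyperparameters are illustrative assumptions, not part of TextBrewer itself.

```python
# Hedged sketch of the assumed user-provided objects (illustrative only).
from torch.optim import AdamW
from transformers import (BertConfig, BertForSequenceClassification,
                          get_linear_schedule_with_warmup)

# Teacher: a fine-tuned 12-layer BERT; student: a 3-layer BERT.
# output_hidden_states=True and return_dict=False keep the outputs as tuples
# (loss, logits, hidden_states), matching the indexing in simple_adaptor above
# (this assumes labels are included in each batch).
teacher_model = BertForSequenceClassification.from_pretrained(
    'bert-base-cased', output_hidden_states=True, return_dict=False)  # assumed already fine-tuned
student_config = BertConfig.from_pretrained(
    'bert-base-cased', num_hidden_layers=3, output_hidden_states=True, return_dict=False)
student_model = BertForSequenceClassification(student_config)

# Only the student's parameters are optimized during distillation.
optimizer = AdamW(student_model.parameters(), lr=1e-4)

# The scheduler is expected to be built as scheduler_class(optimizer, **scheduler_args);
# see the documentation of distiller.train for the exact convention.
scheduler_class = get_linear_schedule_with_warmup
scheduler_args = {'num_warmup_steps': 100, 'num_training_steps': 1000}

# dataloader: any torch.utils.data.DataLoader yielding batches accepted by both models.
```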

Examples

Experiments

We have performed distillation experiments on several typical English and Chinese NLP datasets. The setups and configurations are listed below.

Models

We have tested different student models. To compare with public results, the student models are built with standard transformer blocks, except for BiGRU, which is a single-layer bidirectional GRU. The architectures are listed below. Note that the number of parameters includes the embedding layer but does not include the output layer of each specific task.

English models

| Model | #Layers | Hidden size | Feed-forward size | #Params | Relative size |
|-------|---------|-------------|-------------------|---------|---------------|
| BERT-base-cased (teacher) | 12 | 768 | 3072 | 108M | 100% |
| T6 (student) | 6 | 768 | 3072 | 65M | 60% |
| T3 (student) | 3 | 768 | 3072 | 44M | 41% |
| T3-small (student) | 3 | 384 | 1536 | 17M | 16% |
| T4-Tiny (student) | 4 | 312 | 1200 | 14M | 13% |
| T12-nano (student) | 12 | 256 | 1024 | 17M | 16% |
| BiGRU (student) | - | 768 | - | 31M | 29% |

Chinese models

| Model | #Layers | Hidden size | Feed-forward size | #Params | Relative size |
|-------|---------|-------------|-------------------|---------|---------------|
| RoBERTa-wwm-ext (teacher) | 12 | 768 | 3072 | 102M | 100% |
| Electra-base (teacher) | 12 | 768 | 3072 | 102M | 100% |
| T3 (student) | 3 | 768 | 3072 | 38M | 37% |
| T3-small (student) | 3 | 384 | 1536 | 14M | 14% |
| T4-Tiny (student) | 4 | 312 | 1200 | 11M | 11% |
| Electra-small (student) | 12 | 256 | 1024 | 12M | 12% |

Distillation Configurations

```python
distill_config = DistillationConfig(temperature=8, intermediate_matches=matches)
# Other arguments take their default values
```

`matches` are different for different models:

| Model | matches |
|-------|---------|
| BiGRU | None |
| T6 | L6_hidden_mse + L6_hidden_smmd |
| T3 | L3_hidden_mse + L3_hidden_smmd |
| T3-small | L3n_hidden_mse + L3_hidden_smmd |
| T4-Tiny | L4t_hidden_mse + L4_hidden_smmd |
| T12-nano | small_hidden_mse + small_hidden_smmd |
| Electra-small | small_hidden_mse + small_hidden_smmd |

The definitions of matches are at examples/matches/matches.py.
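
For illustration, here is a hedged sketch of how a combination from the table above could be assembled into `intermediate_matches`, assuming `matches.py` exposes a dict named `matches` keyed by these names; check the file in the repository for its actual structure.

```python
# Hedged sketch: assumes examples/matches/matches.py defines a dict `matches`
# whose values are lists of intermediate-match dicts like those in the Quickstart.
from textbrewer import DistillationConfig
from matches import matches  # i.e. examples/matches/matches.py on the Python path

# e.g. the T4-Tiny configuration from the table above
intermediate_matches = matches['L4t_hidden_mse'] + matches['L4_hidden_smmd']
distill_config = DistillationConfig(temperature=8,
                                    intermediate_matches=intermediate_matches)
```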

We use GeneralDistiller in all the distillation experiments.

Training Configurations

Results on English Datasets

We experiment on the following typical English datasets:

| Dataset | Task type | Metrics | #Train | #Dev | Note |
|---------|-----------|---------|--------|------|------|
| MNLI | text classification | m/mm Acc | 393K | 20K | sentence-pair 3-class classification |
| SQuAD 1.1 | reading comprehension | EM/F1 | 88K | 11K | span-extraction machine reading comprehension |
| CoNLL-2003 | sequence labeling | F1 | 23K | 6K | named entity recognition |

We list the public results from DistilBERT, BERT-PKD, BERT-of-Theseus, TinyBERT and our results below for comparison.

Public results:

| Model (public) | MNLI | SQuAD | CoNLL-2003 |
|----------------|------|-------|------------|
| DistilBERT (T6) | 81.6 / 81.1 | 78.1 / 86.2 | - |
| BERT<sub>6</sub>-PKD (T6) | 81.5 / 81.0 | 77.1 / 85.3 | - |
| BERT-of-Theseus (T6) | 82.4 / 82.1 | - | - |
| BERT<sub>3</sub>-PKD (T3) | 76.7 / 76.3 | - | - |
| TinyBERT (T4-tiny) | 82.8 / 82.9 | 72.7 / 82.1 | - |

Our results:

| Model (ours) | MNLI | SQuAD | CoNLL-2003 |
|--------------|------|-------|------------|
| BERT-base-cased (teacher) | 83.7 / 84.0 | 81.5 / 88.6 | 91.1 |
| BiGRU | - | - | 85.3 |
| T6 | 83.5 / 84.0 | 80.8 / 88.1 | 90.7 |
| T3 | 81.8 / 82.7 | 76.4 / 84.9 | 87.5 |
| T3-small | 81.3 / 81.7 | 72.3 / 81.4 | 78.6 |
| T4-tiny | 82.0 / 82.6 | 75.2 / 84.0 | 89.1 |
| T12-nano | 83.2 / 83.9 | 79.0 / 86.6 | 89.6 |

Note:

  1. The equivalent model structures of public models are shown in the brackets after their names.
  2. When distilling to T4-tiny, NewsQA is used for data augmentation on SQuAD and HotpotQA is used for data augmentation on CoNLL-2003.
  3. When distilling to T12-nano, HotpotQA is used for data augmentation on CoNLL-2003.

Results on Chinese Datasets

We experiment on the following typical Chinese datasets:

| Dataset | Task type | Metrics | #Train | #Dev | Note |
|---------|-----------|---------|--------|------|------|
| XNLI | text classification | Acc | 393K | 2.5K | Chinese translation version of MNLI |
| LCQMC | text classification | Acc | 239K | 8.8K | sentence-pair matching, binary classification |
| CMRC 2018 | reading comprehension | EM/F1 | 10K | 3.4K | span-extraction machine reading comprehension |
| DRCD | reading comprehension | EM/F1 | 27K | 3.5K | span-extraction machine reading comprehension (Traditional Chinese) |
| MSRA NER | sequence labeling | F1 | 45K | 3.4K (#Test) | Chinese named entity recognition |

The results are listed below.

| Model | XNLI | LCQMC | CMRC 2018 | DRCD |
|-------|------|-------|-----------|------|
| RoBERTa-wwm-ext (teacher) | 79.9 | 89.4 | 68.8 / 86.4 | 86.5 / 92.5 |
| T3 | 78.4 | 89.0 | 66.4 / 84.2 | 78.2 / 86.4 |
| T3-small | 76.0 | 88.1 | 58.0 / 79.3 | 75.8 / 84.8 |
| T4-tiny | 76.2 | 88.4 | 61.8 / 81.8 | 77.3 / 86.1 |

| Model | XNLI | LCQMC | CMRC 2018 | DRCD | MSRA NER |
|-------|------|-------|-----------|------|----------|
| Electra-base (teacher) | 77.8 | 89.8 | 65.6 / 84.7 | 86.9 / 92.3 | 95.14 |
| Electra-small | 77.7 | 89.3 | 66.5 / 84.9 | 85.5 / 91.3 | 93.48 |

Note:

  1. Learning rate decay is not used in distillation on CMRC 2018 and DRCD.
  2. CMRC 2018 and DRCD take each other as the augmentation dataset during distillation.
  3. The settings for training the Electra-base teacher model can be found at Chinese-ELECTRA.
  4. The Electra-small student model is initialized with pretrained weights.

Core Concepts

Configurations

Distillers

Distillers are in charge of conducting the actual experiments. The following distillers are available:

User-Defined Functions

In TextBrewer, there are two functions that should be implemented by users: callback and adaptor.

Callback

At each checkpoint, after saving the student model, the distiller calls the callback function. It can be used, for example, to evaluate the student model's performance at each checkpoint.
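
Below is a minimal sketch of a possible callback. The call convention `callback(model, step)` and the hypothetical `evaluate_on_dev` evaluation routine are assumptions here; see the full documentation for the exact signature.

```python
# Hedged sketch of a callback; assumes the distiller calls it as
# callback(model=student_model, step=global_step). evaluate_on_dev is a
# hypothetical user-defined evaluation function.
def my_callback(model, step):
    model.eval()
    dev_accuracy = evaluate_on_dev(model)  # hypothetical evaluation routine
    print(f"step {step}: dev accuracy = {dev_accuracy:.4f}")
    model.train()

# Passed to the distiller in place of callback=None in the Quickstart:
# distiller.train(optimizer, dataloader, num_epochs=1, ..., callback=my_callback)
```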

Adaptor

The adaptor converts the model inputs and outputs into the format expected by the distiller, so that the distillation losses can be computed. At each training step, the batch and the model outputs are passed to the adaptor; the adaptor reorganizes the data and returns a dictionary.
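
Below is a hedged sketch of a slightly fuller adaptor than the `simple_adaptor` in the Quickstart. The set of keys the distiller recognizes is described in the full documentation; the indexing of `batch` and `model_outputs` here is an assumption about the user's own model.

```python
# Hedged sketch of an adaptor; key names follow the documentation's conventions,
# and the batch/model_outputs indexing is an assumption about the user's model.
def my_adaptor(batch, model_outputs):
    return {
        'logits': model_outputs[1],              # logits used for the soft-label (KD) loss
        'hidden': model_outputs[2],              # hidden states for intermediate matches
        'inputs_mask': batch['attention_mask'],  # marks non-padding positions
    }
```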

For more details, see the explanations in Full Documentation.

FAQ

Q: How to initialize the student model?

A: The student model can be randomly initialized (i.e., with no prior knowledge) or initialized with pre-trained weights. For example, when distilling a BERT-base model to a 3-layer BERT, you could initialize the student model with RBT3 (for Chinese tasks) or with the first three layers of BERT (for English tasks) to avoid the cold-start problem. We recommend initializing the student from pre-trained weights whenever possible to take full advantage of large-scale pre-training.
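
For example, here is a hedged sketch of initializing a 3-layer student from the first three layers of a 12-layer BERT. The use of HuggingFace `transformers` and the specific checkpoint name are assumptions for illustration.

```python
# Hedged sketch: build a 3-layer student and initialize it from the lowest
# three encoder layers (plus embeddings) of a 12-layer BERT.
from transformers import BertConfig, BertModel

teacher = BertModel.from_pretrained('bert-base-cased')
student_config = BertConfig.from_pretrained('bert-base-cased', num_hidden_layers=3)
student = BertModel(student_config)

# strict=False loads the parameters whose names match (embeddings, encoder
# layers 0-2, pooler) and ignores the teacher's remaining layers.
missing, unexpected = student.load_state_dict(teacher.state_dict(), strict=False)
```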

Q: How to set training hyperparameters for the distillation experiments?

A: Knowledge distillation usually requires more training epochs and a larger learning rate than fine-tuning on the labeled dataset. For example, fine-tuning BERT-base on SQuAD usually takes 3 epochs with lr=3e-5, whereas distillation takes 30~50 epochs with lr=1e-4. These conclusions are based on our experiments; you are advised to experiment on your own data.
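
As a concrete (hedged) illustration of those numbers, reusing the objects defined in the Quickstart:

```python
# Hedged sketch: the distillation hyperparameters mentioned above
# (lr=1e-4, 30~50 epochs), reusing student_model, dataloader, distiller,
# scheduler_class, and scheduler_args from the Quickstart.
from torch.optim import AdamW

optimizer = AdamW(student_model.parameters(), lr=1e-4)
with distiller:
    distiller.train(optimizer, dataloader, num_epochs=30,
                    scheduler_class=scheduler_class, scheduler_args=scheduler_args,
                    callback=None)
```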

Q: My teacher model and student model take different inputs (they do not share vocabularies), so how can I distill?

A: You need to feed different batches to the teacher and the student. See the section Feed Different batches to Student and Teacher, Feed Cached Values in the full documentation.

Q: I have stored the logits from my teacher model. Can I use them in the distillation to save the forward pass time?

A: Yes, see the section Feed Different batches to Student and Teacher, Feed Cached Values in the full documentation.

Known Issues

Citation

If you find TextBrewer helpful, please cite our paper:

@InProceedings{textbrewer-acl2020-demo,
    title = "{T}ext{B}rewer: {A}n {O}pen-{S}ource {K}nowledge {D}istillation {T}oolkit for {N}atural {L}anguage {P}rocessing",
    author = "Yang, Ziqing and Cui, Yiming and Chen, Zhipeng and Che, Wanxiang and Liu, Ting and Wang, Shijin and Hu, Guoping",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations",
    year = "2020",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.acl-demos.2",
    pages = "9--16",
}

Follow Us

Follow our official WeChat account to keep updated with our latest technologies!