
<p align="center"> <br> <img src="./pics/banner.png" width="500"/> <br> <p> <p> <p align="center"> <a href="https://github.com/airaria/TextPruner/blob/master/LICENSE"> <img alt="GitHub" src="https://img.shields.io/github/license/airaria/TextPruner.svg?color=green&style=flat-square"> </a> <a href="https://TextPruner.readthedocs.io/"> <img alt="Documentation" src="https://img.shields.io/website?down_message=offline&label=Documentation&up_message=online&url=https%3A%2F%2FTextPruner.readthedocs.io"> </a> <a href="https://pypi.org/project/TextPruner"> <img alt="PyPI" src="https://img.shields.io/pypi/v/TextPruner"> </a> <a href="https://github.com/airaria/TextPruner/releases"> <img alt="GitHub release" src="https://img.shields.io/github/v/release/airaria/TextPruner?include_prereleases"> </a> </p>

TextPruner is a model pruning toolkit for pre-trained language models. It provides low-cost, training-free methods that reduce model size and speed up inference by removing redundant neurons.

Table of Contents

<!-- TOC -->
| Section | Contents |
|-|-|
| Introduction | Introduction to TextPruner |
| Installation | Requirements and how to install |
| Pruning Modes | A brief introduction to the three pruning modes |
| Usage | A quick guide on how to use TextPruner |
| Experiments | Pruning experiments on typical tasks |
| FAQ | Frequently asked questions |
| Follow Us | - |

Introduction

TextPruner is a toolkit for pruning pre-trained transformer-based language models written in PyTorch. It offers structured training-free pruning methods and a user-friendly interface.

The main features of TextPruner include:

TextPruner currently supports vocabulary pruning and transformer pruning. For the explanation of the pruning modes, see Pruning Modes.

To use TextPruner, users can either import it in their Python scripts or run the TextPruner command line tool. See the examples in Usage.

For the performance of the pruned model on typical tasks, see Experiments.

Paper: TextPruner: A Model Pruning Toolkit for Pre-Trained Language Models

Supporting Models

TextPruner currently supports the following pre-trained models in transformers:

| Model | Vocabulary Pruning | Transformer Pruning |
|-|-|-|
| BERT | :heavy_check_mark: | :heavy_check_mark: |
| ALBERT | :heavy_check_mark: | :heavy_check_mark: |
| RoBERTa | :heavy_check_mark: | :heavy_check_mark: |
| ELECTRA | :heavy_check_mark: | :heavy_check_mark: |
| XLM-RoBERTa | :heavy_check_mark: | :heavy_check_mark: |
| XLM | :heavy_check_mark: | :x: |
| BART | :heavy_check_mark: | :x: |
| T5 | :heavy_check_mark: | :x: |
| mT5 | :heavy_check_mark: | :x: |

See the online documentation for the API reference.

Installation

Pruning Modes

In TextPruner, there are three pruning modes: vocabulary pruning, transformer pruning and pipeline pruning.

Vocabulary Pruning

The pre-trained models usually have a large vocabulary, but some tokens rarely appear in the datasets of the downstream tasks. These tokens can be removed to reduce the model size and accelerate MLM pre-training.
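The idea can be illustrated with a toy example. The following is a minimal sketch, not TextPruner's implementation: it assumes a whitespace tokenizer, a plain dict as the vocabulary, and a list of rows as the embedding matrix. The function `prune_vocabulary` and its `special_tokens` argument are hypothetical names introduced for illustration; `min_count` borrows its name from `VocabularyPruningConfig`, but its meaning here is defined only by this sketch.

```python
from collections import Counter


def prune_vocabulary(vocab, embeddings, texts,
                     special_tokens=("[PAD]", "[UNK]"), min_count=1):
    """Keep special tokens and tokens seen at least `min_count` times in `texts`.

    vocab:      dict mapping token -> row index in `embeddings`
    embeddings: list of embedding rows (one row per token)
    Returns the pruned (vocab, embeddings) with re-numbered token ids.
    """
    # Count token occurrences with a toy whitespace tokenizer (an assumption).
    counts = Counter(tok for line in texts for tok in line.split())
    # Walk the vocabulary in its original id order so relative order is kept.
    kept = [tok for tok, _ in sorted(vocab.items(), key=lambda kv: kv[1])
            if tok in special_tokens or counts[tok] >= min_count]
    new_vocab = {tok: new_id for new_id, tok in enumerate(kept)}
    # Copy only the embedding rows of the surviving tokens.
    new_embeddings = [embeddings[vocab[tok]] for tok in kept]
    return new_vocab, new_embeddings


vocab = {"[PAD]": 0, "[UNK]": 1, "the": 2, "cat": 3, "zebra": 4, "sat": 5}
embeddings = [[float(i)] * 4 for i in range(len(vocab))]
texts = ["the cat sat", "the cat"]

new_vocab, new_embeddings = prune_vocabulary(vocab, embeddings, texts)
print(new_vocab)
```

Here "zebra" never appears in the texts, so its embedding row is dropped and the remaining tokens are re-indexed. A real tokenizer and model need the same re-indexing applied consistently to both.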

Transformer Pruning

Another approach is pruning the transformer blocks. Some studies have shown that not all attention heads are equally important in the transformers. TextPruner reduces the model size and keeps the model performance as high as possible by locating and removing the unimportant attention heads and the feed-forward networks' neurons.
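Once every attention head has an importance score, choosing which heads to keep reduces to a ranking problem. The sketch below is one plausible way to do the selection and is not taken from TextPruner's source: the scores are made-up numbers, and `select_heads` and its `even` flag are hypothetical names. With `even=True` each layer keeps the same number of heads; with `even=False` heads are ranked across all layers, so the count may vary from layer to layer while the average stays the same.

```python
def select_heads(importance, target_heads_per_layer, even=True):
    """Return a 0/1 keep-mask with the same shape as `importance`.

    importance: list of per-layer lists of head importance scores.
    """
    if even:
        masks = []
        for scores in importance:
            # Rank the heads of this layer only and keep the best ones.
            ranked = sorted(range(len(scores)), key=lambda h: scores[h], reverse=True)
            keep = set(ranked[:target_heads_per_layer])
            masks.append([1 if h in keep else 0 for h in range(len(scores))])
        return masks

    # Uneven case: rank all heads globally and keep the same total number.
    total = target_heads_per_layer * len(importance)
    flat = [(score, layer, head)
            for layer, scores in enumerate(importance)
            for head, score in enumerate(scores)]
    flat.sort(key=lambda item: item[0], reverse=True)
    masks = [[0] * len(scores) for scores in importance]
    for _, layer, head in flat[:total]:
        masks[layer][head] = 1
    return masks


# Hypothetical importance scores for 2 layers with 4 heads each.
importance = [[0.9, 0.8, 0.7, 0.1],
              [0.6, 0.2, 0.05, 0.02]]

even_masks = select_heads(importance, 2, even=True)
uneven_masks = select_heads(importance, 2, even=False)
print(even_masks)
print(uneven_masks)
```

In this toy input the first layer holds most of the high-scoring heads, so the uneven selection keeps three heads there and one in the second layer. The same ranking idea applies to feed-forward neurons.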

Pipeline Pruning

In pipeline pruning, TextPruner performs transformer pruning and vocabulary pruning successively to fully reduce the model size.

Usage

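The pruners (VocabularyPruner, TransformerPruner and PipelinePruner) perform the pruning process. The configurations (GeneralConfig, VocabularyPruningConfig and TransformerPruningConfig) set their behaviors. Their names are self-explanatory.

See the online documentation for the API reference. The configuration objects are explained in Configurations. We demonstrate the basic usage below.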

Vocabulary Pruning

To perform vocabulary pruning, users should provide a text file or a list of strings. The tokens that do not appear in the texts are removed from the model and the tokenizer.

See the examples at examples/vocabulary_pruning and examples/vocabulary_pruning_xnli.

Use TextPruner as a package

Pruning the vocabulary in 3 lines of code:

from textpruner import VocabularyPruner
pruner = VocabularyPruner(model, tokenizer)
pruner.prune(dataiter=texts)

VocabularyPruner accepts GeneralConfig and VocabularyPruningConfig for fine control. Both are optional; if they are omitted, the default configurations are used. See the API reference for details.

Use TextPruner-CLI tool

textpruner-cli  \
  --pruning_mode vocabulary \
  --configurations gc.json vc.json \
  --model_class XLMRobertaForSequenceClassification \
  --tokenizer_class XLMRobertaTokenizer \
  --model_path /path/to/model/and/config/directory \
  --vocabulary /path/to/a/text/file

Transformer Pruning

See the examples at examples/transformer_pruning.

For self-supervised pruning, see the examples at examples/transformer_pruning_xnli.

Use TextPruner as a package

from textpruner import TransformerPruner, TransformerPruningConfig
transformer_pruning_config = TransformerPruningConfig(
    target_ffn_size=2048,
    target_num_of_heads=8,
    pruning_method='iterative',
    n_iters=4)
pruner = TransformerPruner(model, transformer_pruning_config=transformer_pruning_config)
pruner.prune(dataloader=dataloader, save_model=True)

TransformerPruner accepts GeneralConfig and TransformerPruningConfig for fine control. See the API reference for details.

Use TextPruner-CLI tool

textpruner-cli  \
  --pruning_mode transformer \
  --configurations gc.json tc.json \
  --model_class XLMRobertaForSequenceClassification \
  --tokenizer_class XLMRobertaTokenizer \
  --model_path ../models/xlmr_pawsx \
  --dataloader_and_adaptor dataloader_script

Pipeline Pruning

Pipeline pruning combines transformer pruning and vocabulary pruning into a single call.

See the examples at examples/pipeline_pruning.

Use TextPruner as a package

from textpruner import PipelinePruner, TransformerPruningConfig
transformer_pruning_config = TransformerPruningConfig(
    target_ffn_size=2048, target_num_of_heads=8, 
    pruning_method='iterative',n_iters=4)
pruner = PipelinePruner(model, tokenizer, transformer_pruning_config=transformer_pruning_config)
pruner.prune(dataloader=dataloader, dataiter=texts, save_model=True)

PipelinePruner accepts GeneralConfig, VocabularyPruningConfig and TransformerPruningConfig for fine control. See the API reference for details.

Use TextPruner-CLI tool

textpruner-cli  \
  --pruning_mode pipeline \
  --configurations gc.json tc.json vc.json \
  --model_class XLMRobertaForSequenceClassification \
  --tokenizer_class XLMRobertaTokenizer \
  --model_path ../models/xlmr_pawsx \
  --vocabulary /path/to/a/text/file \
  --dataloader_and_adaptor dataloader_script

Configurations

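The pruning process is configured by three configuration objects: GeneralConfig, VocabularyPruningConfig and TransformerPruningConfig.

They are used in different pruning modes: VocabularyPruner accepts GeneralConfig and VocabularyPruningConfig, TransformerPruner accepts GeneralConfig and TransformerPruningConfig, and PipelinePruner accepts all three.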

The configurations are dataclass objects (used in Python scripts) or JSON files (used on the command line). If no configurations are provided, TextPruner will use the default configurations. See the API reference for details.

In a Python script:

from textpruner import GeneralConfig, VocabularyPruningConfig, TransformerPruningConfig
from textpruner import VocabularyPruner, TransformerPruner, PipelinePruner

#GeneralConfig
general_config = GeneralConfig(device='auto',output_dir='./pruned_models')

#VocabularyPruningConfig
vocabulary_pruning_config = VocabularyPruningConfig(min_count=1,prune_lm_head='auto')

#TransformerPruningConfig
#Pruning with the given masks 
transformer_pruning_config = TransformerPruningConfig(pruning_method = 'masks')

#TransformerPruningConfig
#Pruning on labeled dataset iteratively
transformer_pruning_config = TransformerPruningConfig(
    target_ffn_size  = 2048,
    target_num_of_heads = 8,
    pruning_method = 'iterative',
    ffn_even_masking = True,
    head_even_masking = True,
    n_iters = 1,
    multiple_of = 1
)

As JSON files:
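A minimal sketch of producing such files with the standard library, assuming the JSON keys mirror the dataclass field names shown in the Python example above. That mapping is an assumption; check the API reference before relying on it. The file names `gc.json`, `vc.json` and `tc.json` follow the command line examples in Usage, and the temporary output directory is only for illustration.

```python
import json
import os
import tempfile

# Keys copied from the dataclass fields in the Python example; whether the
# command line tool expects exactly these keys is an assumption.
general_config = {"device": "auto", "output_dir": "./pruned_models"}
vocabulary_pruning_config = {"min_count": 1, "prune_lm_head": "auto"}
transformer_pruning_config = {
    "target_ffn_size": 2048,
    "target_num_of_heads": 8,
    "pruning_method": "iterative",
    "ffn_even_masking": True,
    "head_even_masking": True,
    "n_iters": 1,
    "multiple_of": 1,
}

out_dir = tempfile.mkdtemp()
paths = {}
for name, config in (("gc.json", general_config),
                     ("vc.json", vocabulary_pruning_config),
                     ("tc.json", transformer_pruning_config)):
    path = os.path.join(out_dir, name)
    with open(path, "w") as f:
        json.dump(config, f, indent=2)
    paths[name] = path

# Show one of the generated files.
with open(paths["tc.json"]) as f:
    print(f.read())
```

The resulting files could then be passed to `--configurations`, as in the command line examples.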

Helper functions

Example:

from transformers import BertForMaskedLM
import textpruner
import torch

model = BertForMaskedLM.from_pretrained('bert-base-uncased')
print("Model summary:")
print(textpruner.summary(model,max_level=3))

dummy_inputs = [torch.randint(low=0,high=10000,size=(32,512))]
print("Inference time:")
textpruner.inference_time(model.to('cuda'),dummy_inputs)

Outputs:

Model summary:
LAYER NAME                          	        #PARAMS	     RATIO	 MEM(MB)
--model:                            	    109,514,810	   100.00%	  417.77
  --bert:                           	    108,892,160	    99.43%	  415.39
    --embeddings:                   	     23,837,696	    21.77%	   90.94
      --position_ids:               	            512	     0.00%	    0.00
      --word_embeddings:            	     23,440,896	    21.40%	   89.42
      --position_embeddings:        	        393,216	     0.36%	    1.50
      --token_type_embeddings:      	          1,536	     0.00%	    0.01
      --LayerNorm:                  	          1,536	     0.00%	    0.01
    --encoder
      --layer:                      	     85,054,464	    77.66%	  324.46
  --cls
    --predictions(partially shared):	        622,650	     0.57%	    2.38
      --bias:                       	         30,522	     0.03%	    0.12
      --transform:                  	        592,128	     0.54%	    2.26
      --decoder(shared):            	              0	     0.00%	    0.00

Inference time:
Device: cuda:0
Mean inference time: 1214.41ms
Standard deviation: 2.39ms
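For readers who want to reproduce this kind of measurement without the helper, here is a generic timing sketch using only the standard library. It is not how `textpruner.inference_time` is implemented; `time_calls`, `warmup` and `repeats` are hypothetical names, and the timed workload is a stand-in for a model forward pass.

```python
import statistics
import time


def time_calls(fn, warmup=2, repeats=10):
    """Return (mean, standard deviation) of the run time of `fn` in milliseconds."""
    # Warm-up runs are discarded so one-off setup costs do not skew the result.
    for _ in range(warmup):
        fn()
    timings_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings_ms.append((time.perf_counter() - start) * 1000.0)
    return statistics.mean(timings_ms), statistics.stdev(timings_ms)


# Stand-in workload; with a real model this would be a forward pass.
# On a GPU, the device must be synchronized before reading the clock,
# because kernels are launched asynchronously.
mean_ms, std_ms = time_calls(lambda: sum(i * i for i in range(10000)))
print(f"Mean inference time: {mean_ms:.2f}ms")
print(f"Standard deviation: {std_ms:.2f}ms")
```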

Experiments

We prune an XLM-RoBERTa-base classification model trained on PAWS-X, a multilingual paraphrase identification task. The model is fine-tuned and evaluated on the English dataset.

Vocabulary Pruning

We use a 100k-lines subset of XNLI English training set as the vocabulary file. The pruning result is listed below.

| Model | Total size (MB) | Vocab size | Acc on en (%) |
|-|-|-|-|
| XLM-RoBERTa-base | 1060 (100%) | 250002 | 94.65 |
| + Vocabulary Pruning | 398 (37.5%) | 23936 | 94.20 |
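The size reduction is consistent with a back-of-the-envelope estimate. The sketch below assumes a hidden size of 768 (as in XLM-RoBERTa-base), float32 weights, and that the word embedding matrix is the only part of the model that shrinks. These are assumptions of the sketch, not statements from the experiment setup.

```python
HIDDEN_SIZE = 768        # assumed embedding dimension (XLM-RoBERTa-base)
BYTES_PER_PARAM = 4      # assumed float32 weights
MB = 1024 * 1024


def embedding_size_mb(vocab_size):
    """Size of a (vocab_size x HIDDEN_SIZE) word embedding matrix in MB."""
    return vocab_size * HIDDEN_SIZE * BYTES_PER_PARAM / MB


original_mb = embedding_size_mb(250002)
pruned_mb = embedding_size_mb(23936)
saved_mb = original_mb - pruned_mb

# Subtract the saved embedding memory from the reported original total size.
estimated_total_mb = 1060 - saved_mb
print(f"Embedding: {original_mb:.1f} MB -> {pruned_mb:.1f} MB")
print(f"Estimated total after pruning: {estimated_total_mb:.0f} MB")
```

The estimate lands on the 398 MB reported in the table, which suggests that almost all of the saving comes from the word embeddings.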

Transformer Pruning

We denote the model structure as (H, F) where H is the average number of attention heads per layer, F is the average FFN hidden size per layer. With this notation, (12,3072) stands for the original XLM-RoBERTa model. In addition we consider (8, 2048) and (6, 1536).

Inference time

The speed is measured on inputs of length 512 and batch size 32. Each layer of the model has the same number of attention heads and FFN hidden size.

| Model | Total size (MB) | Encoder size (MB) | Inference time (ms) | Speed up |
|-|-|-|-|-|
| (12, 3072) | 1060 | 324 | 1012 | 1.0x |
| (8, 2048) | 952 | 216 | 666 | 1.5x |
| (6, 1536) | 899 | 162 | 504 | 2.0x |
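The encoder sizes can be approximately reproduced from the (H, F) notation. The sketch below assumes a BERT-style encoder layer with hidden size 768, head dimension 64, 12 layers and float32 weights. These values match a base-size model but are assumptions of the sketch, as is the helper name `encoder_params`.

```python
HIDDEN = 768         # assumed hidden size
HEAD_DIM = 64        # assumed size of each attention head
LAYERS = 12          # assumed number of encoder layers
BYTES_PER_PARAM = 4  # assumed float32 weights
MB = 1024 * 1024


def encoder_params(num_heads, ffn_size):
    """Parameter count of an encoder with `num_heads` heads and FFN size `ffn_size` per layer."""
    attn_dim = num_heads * HEAD_DIM
    # Query, key and value projections (weights + biases), then the output projection.
    attention = 3 * (HIDDEN * attn_dim + attn_dim) + (attn_dim * HIDDEN + HIDDEN)
    # Two linear layers of the feed-forward network (weights + biases).
    ffn = (HIDDEN * ffn_size + ffn_size) + (ffn_size * HIDDEN + HIDDEN)
    # Two LayerNorms per layer, each with a weight and a bias vector.
    layer_norms = 2 * (2 * HIDDEN)
    return LAYERS * (attention + ffn + layer_norms)


def encoder_size_mb(num_heads, ffn_size):
    return encoder_params(num_heads, ffn_size) * BYTES_PER_PARAM / MB


for structure in [(12, 3072), (8, 2048), (6, 1536)]:
    print(structure, f"{encoder_size_mb(*structure):.1f} MB")
```

For (12, 3072) the count is 85,054,464 parameters, which equals the `--layer` entry in the model summary shown in Helper functions. The estimates for the two pruned structures round to the 216 MB and 162 MB listed in the table.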

Performance

We prune the model with different numbers of iterations (n_iters). The accuracies are listed below:

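| Model | n_iters=1 | n_iters=2 | n_iters=4 | n_iters=8 | n_iters=16 |
|-|-|-|-|-|-|
| (12, 3072) | 94.65 | - | - | - | - |
| (8, 2048) | 93.30 | 93.60 | 93.60 | 93.85 | 93.95 |
| (8, 2048) with uneven heads | 92.95 | 93.50 | 93.95 | 94.05 | 94.25 |
| (6, 1536) | 85.15 | 89.10 | 90.90 | 90.60 | 90.85 |
| (6, 1536) with uneven heads | 45.35 | 86.45 | 90.55 | 90.90 | 91.95 |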

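"Uneven heads" means the number of attention heads may vary from layer to layer. With the same model structure, the performance generally increases with the number of iterations n_iters, although the gain is not strictly monotonic (for example, (6, 1536) drops slightly from n_iters=4 to n_iters=8).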

FAQ

Q: Does TextPruner support TensorFlow 2?

A: No.

Q: Can you compare knowledge distillation and model pruning? Which one should I use?

A: Both model pruning and knowledge distillation are popular approaches for reducing the model size and accelerating model speed.

(Some pruning methods that involve training can also achieve a high compression ratio.)

If you are interested in applying knowledge distillation, please refer to our TextBrewer.

If you want to achieve the best performance, you may consider applying both distillation and pruning.

Citation

If you find TextPruner helpful, please cite our paper:

@inproceedings{yang-etal-2022-textpruner,
    title = "{T}ext{P}runer: A Model Pruning Toolkit for Pre-Trained Language Models",
    author = "Yang, Ziqing  and
      Cui, Yiming  and
      Chen, Zhigang",
    booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics: System Demonstrations",
    month = may,
    year = "2022",
    address = "Dublin, Ireland",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2022.acl-demo.4",
    pages = "35--43"
}

Follow Us

Follow our official WeChat account to keep updated with our latest technologies!