Lite Transformer

paper | website | slides

@inproceedings{Wu2020LiteTransformer,
  title={Lite Transformer with Long-Short Range Attention},
  author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2020}
}

Overview

[Figure: overview]

How to Use

Prerequisite

Installation

  1. Codebase

    To install fairseq from source and develop locally:

    pip install --editable .
    
  2. Customized Modules

    We also need to build the lightconv and dynamicconv CUDA extensions for GPU support.

    Lightconv_layer

    cd fairseq/modules/lightconv_layer
    python cuda_function_gen.py
    python setup.py install
    

    Dynamicconv_layer

    cd fairseq/modules/dynamicconv_layer
    python cuda_function_gen.py
    python setup.py install
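Both snippets above assume the shell starts in the repository root, so pasting them back to back fails at the second cd. A minimal sketch that builds both extensions in one call, assuming the directory layout shown above; the function name build_cuda_modules and the PYTHON override are hypothetical, not part of the repo:

```shell
# Sketch: build the lightconv and dynamicconv CUDA extensions in one call.
# Assumptions: $1 is the repository root (default: current directory);
# build_cuda_modules and the PYTHON override are hypothetical helpers.
build_cuda_modules() {
  local repo_root="${1:-.}"
  local py="${PYTHON:-python}"   # interpreter used for both build steps
  local module
  for module in lightconv_layer dynamicconv_layer; do
    echo "Building $module ..."
    # Subshell: the cd stays local, so every iteration starts from the caller's directory.
    ( cd "$repo_root/fairseq/modules/$module" \
        && "$py" cuda_function_gen.py \
        && "$py" setup.py install ) || return 1
  done
}
```

From the repository root, one would then run `build_cuda_modules .`; because each build runs in a subshell, the current directory is unchanged afterwards, and the function stops at the first failing step.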
    

Data Preparation

IWSLT'14 De-En

We follow the data preparation in fairseq. To download and preprocess the data, one can run

bash configs/iwslt14.de-en/prepare.sh

WMT'14 En-Fr

We follow the data pre-processing in fairseq. To download and preprocess the data, one can run

bash configs/wmt14.en-fr/prepare.sh

WMT'16 En-De

We follow the data pre-processing in fairseq. One should first download the preprocessed data from the Google Drive link provided by Google. To binarize the data, one can run

bash configs/wmt16.en-de/prepare.sh [path to the downloaded zip file]

WIKITEXT-103

As the language modeling task requires a lot of additional code, we keep it in a separate branch, language-model. We follow the data pre-processing in fairseq. To download and preprocess the data, one can run

git checkout language-model
bash configs/wikitext-103/prepare.sh
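The four preparation steps above differ only in the script path and, for WMT'16 En-De, the zip argument. A minimal sketch of a single entry point, assuming it is run from the repository root; prepare_data and the DRY_RUN switch are hypothetical names, and for WIKITEXT-103 the language-model branch must already be checked out (the function does not switch branches):

```shell
# Sketch: one entry point for the data-preparation scripts of this repo.
# Assumptions: run from the repository root; prepare_data and DRY_RUN are
# hypothetical names. With DRY_RUN=1 the command is printed, not executed.
prepare_data() {
  local dataset="$1"
  local -a cmd
  case "$dataset" in
    iwslt14.de-en) cmd=(bash configs/iwslt14.de-en/prepare.sh) ;;
    wmt14.en-fr)   cmd=(bash configs/wmt14.en-fr/prepare.sh) ;;
    wmt16.en-de)
      # WMT'16 En-De needs the zip downloaded from Google Drive beforehand.
      if [ ! -f "${2:-}" ]; then
        echo "usage: prepare_data wmt16.en-de <path to downloaded zip>" >&2
        return 1
      fi
      cmd=(bash configs/wmt16.en-de/prepare.sh "$2")
      ;;
    wikitext-103)
      # Only valid on the language-model branch (git checkout language-model).
      cmd=(bash configs/wikitext-103/prepare.sh)
      ;;
    *) echo "unknown dataset: $dataset" >&2; return 1 ;;
  esac
  if [ "${DRY_RUN:-0}" = 1 ]; then
    echo "${cmd[*]}"
  else
    "${cmd[@]}"
  fi
}
```

For example, `prepare_data wmt16.en-de ~/Downloads/wmt16_en_de.zip` (a hypothetical path) checks that the zip exists before handing it to the WMT'16 script.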

Testing

For example, to test the models on WMT'14 En-Fr, one can run

configs/wmt14.en-fr/test.sh [path to the model checkpoints] [gpu-id] [test|valid]

For instance, to evaluate Lite Transformer on GPU 0 and report the BLEU score on the WMT'14 En-Fr test set, one can run

configs/wmt14.en-fr/test.sh embed496/ 0 test

We provide several pretrained models in the Models section below. After downloading a model, one can extract it with

tar -xzvf [filename]
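Extraction and evaluation can be chained. A minimal sketch, assuming the archive unpacks into a single top-level checkpoint directory (such as embed496/ in the example above) and that the chosen config directory contains a test.sh taking the same arguments; evaluate_archive and DRY_RUN are hypothetical names:

```shell
# Sketch: unpack a downloaded checkpoint archive, then evaluate it with test.sh.
# Assumptions: run from the repository root; the archive holds one top-level
# checkpoint directory; evaluate_archive and DRY_RUN are hypothetical names.
evaluate_archive() {
  local archive="$1" gpu="${2:-0}" split="${3:-test}"
  local config_dir="${4:-configs/wmt14.en-fr}"   # dataset whose test.sh is used
  case "$split" in
    test|valid) ;;
    *) echo "split must be 'test' or 'valid', got '$split'" >&2; return 1 ;;
  esac
  # First path component of the first archive entry, e.g. "embed496".
  local ckpt_dir
  ckpt_dir=$(tar -tzf "$archive" | head -n 1 | sed 's|^\./||' | cut -d/ -f1)
  if [ -z "$ckpt_dir" ]; then
    echo "cannot find a checkpoint directory in $archive" >&2
    return 1
  fi
  tar -xzf "$archive" || return 1
  if [ "${DRY_RUN:-0}" = 1 ]; then
    echo "$config_dir/test.sh $ckpt_dir/ $gpu $split"   # print instead of running
  else
    "$config_dir/test.sh" "$ckpt_dir/" "$gpu" "$split"
  fi
}
```

Usage: `evaluate_archive [filename] 0 test` extracts the archive in the current directory and evaluates it on GPU 0; the split is checked against test and valid before anything is unpacked.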

Training

We provide several examples for training Lite Transformer with this repo:

To train Lite Transformer on WMT'14 En-Fr (with 8 GPUs), one can run

python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml

To train Lite Transformer with fewer GPUs, e.g. 4 GPUs, one can run

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 32

In general, to train a model, one can run

python train.py [path to the data binary] --configs [path to config file] [override options]

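Note that --update-freq should be adjusted according to the number of GPUs (16 for 8 GPUs, 32 for 4 GPUs) so that the number of GPUs times --update-freq, and hence the effective batch size, stays the same.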
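A minimal sketch of this adjustment, assuming GPUs are selected through CUDA_VISIBLE_DEVICES and taking the constant 128 from the examples in this README (8 x 16, 4 x 32, and 16 x 8 in the distributed run); num_visible_gpus and update_freq_for are hypothetical helpers:

```shell
# Sketch: derive --update-freq so that (#GPUs x update-freq) stays at 128.
# Assumptions: GPUs are chosen via CUDA_VISIBLE_DEVICES (comma-separated ids);
# num_visible_gpus and update_freq_for are hypothetical helpers.
num_visible_gpus() {
  local -a ids
  IFS=, read -r -a ids <<< "${CUDA_VISIBLE_DEVICES:-}"
  echo "${#ids[@]}"   # prints 0 when the variable is unset or empty
}

update_freq_for() {
  local gpus="$1" target=128
  case "$gpus" in
    ''|*[!0-9]*) echo "not a GPU count: '$gpus'" >&2; return 1 ;;
  esac
  if [ "$gpus" -eq 0 ] || [ $((target % gpus)) -ne 0 ]; then
    echo "cannot split $target evenly across $gpus GPUs" >&2
    return 1
  fi
  echo $((target / gpus))
}

# Example:
#   export CUDA_VISIBLE_DEVICES=0,1,2,3
#   python train.py data/binary/wmt14_en_fr \
#     --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
#     --update-freq "$(update_freq_for "$(num_visible_gpus)")"
```

When CUDA_VISIBLE_DEVICES is unset, num_visible_gpus reports 0, so in that case one would pass the GPU count to update_freq_for directly.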

Distributed Training (optional)

To train Lite Transformer in a distributed manner, for example on two GPU nodes with 16 GPUs in total, one can run:

# On host1
python -m torch.distributed.launch \
        --nproc_per_node=8 \
        --nnodes=2 --node_rank=0 \
        --master_addr=host1 --master_port=8080 \
        train.py data/binary/wmt14_en_fr \
        --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
        --distributed-no-spawn \
        --update-freq 8
# On host2
python -m torch.distributed.launch \
        --nproc_per_node=8 \
        --nnodes=2 --node_rank=1 \
        --master_addr=host1 --master_port=8080 \
        train.py data/binary/wmt14_en_fr \
        --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
        --distributed-no-spawn \
        --update-freq 8
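The two commands above differ only in --node_rank. A minimal sketch that generates either one, assuming the same 2-node x 8-GPU setup and data paths; launch_node, MASTER_ADDR, MASTER_PORT, and DRY_RUN are hypothetical names, and extra train.py options can be appended after the node rank:

```shell
# Sketch: generate the per-node torch.distributed.launch command.
# Assumptions: 2 nodes x 8 GPUs and the same data/config paths as above;
# launch_node, MASTER_ADDR, MASTER_PORT and DRY_RUN are hypothetical names.
launch_node() {
  local node_rank="${1:-}" nnodes=2 gpus_per_node=8
  case "$node_rank" in
    ''|*[!0-9]*) echo "usage: launch_node <node_rank> [train.py options]" >&2; return 1 ;;
  esac
  if [ "$node_rank" -ge "$nnodes" ]; then
    echo "node_rank must be below $nnodes" >&2
    return 1
  fi
  shift
  # --update-freq is 128 / (nnodes x gpus_per_node), i.e. 8 for 16 GPUs.
  local -a cmd=(
    python -m torch.distributed.launch
    --nproc_per_node="$gpus_per_node"
    --nnodes="$nnodes" --node_rank="$node_rank"
    --master_addr="${MASTER_ADDR:-host1}" --master_port="${MASTER_PORT:-8080}"
    train.py data/binary/wmt14_en_fr
    --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml
    --distributed-no-spawn
    --update-freq "$((128 / (nnodes * gpus_per_node)))"
    "$@"
  )
  if [ "${DRY_RUN:-0}" = 1 ]; then
    echo "${cmd[*]}"   # print the command instead of launching it
  else
    "${cmd[@]}"
  fi
}
```

On host1 one would run `launch_node 0`, and on host2 `launch_node 1`; keeping both hosts on the same function avoids the two copies drifting apart.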

Models

We provide the checkpoints for our Lite Transformer reported in the paper:

Dataset          #Mult-Adds  Test Score   Model and Test Set
WMT'14 En-Fr     90M         35.3 (BLEU)  download
                 360M        39.1 (BLEU)  download
                 527M        39.6 (BLEU)  download
WMT'16 En-De     90M         22.5 (BLEU)  download
                 360M        25.6 (BLEU)  download
                 527M        26.5 (BLEU)  download
CNN / DailyMail  800M        38.3 (R-L)   download
WIKITEXT-103     1147M       22.2 (PPL)   download