Speculative Decoding

Introduction

This repository contains the code used to reimplement our paper: Speculative Decoding: Exploiting Speculative Execution for Accelerating Seq2seq Generation.

SpecDec

Download model

| Description | Model |
| --- | --- |
| wmt14.en-de | ar-verifier-base, nar-drafter-base (k=25) |
| wmt14.de-en | ar-verifier-base, nar-drafter-base (k=25) |
| wmt16.en-ro | ar-verifier-base, nar-drafter-base (k=25) |
| wmt16.ro-en | ar-verifier-base, nar-drafter-base (k=25) |

Requirements

Installation

conda create -n specdec python=3.7
conda activate specdec
cd SpecDec
pip install --editable .

Preprocess

The datasets we used can be obtained by following the script released by Mask-Predict. We release the BPE codes and our dictionaries in ./data. The tokenized datasets are preprocessed as follows:

text=PATH_TO_YOUR_DATA
src=source_language
tgt=target_language
bin_path=PATH_TO_BIN_DIR

model_path=PATH_TO_MODEL_DICT_DIR

fairseq-preprocess --source-lang ${src} --target-lang ${tgt} \
    --trainpref $text/train --validpref $text/valid --testpref $text/test \
    --destdir ${bin_path} --workers 60 \
    --srcdict ${model_path}/dict.${src}.txt \
    --tgtdict ${model_path}/dict.${tgt}.txt

Encoder Initialization

We recommend using the AR verifier's encoder to initialize the weights of the NAR drafter. For preparing the initialization checkpoints, check encoder_initial.py.
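
As a rough illustration of what such an initialization step might involve, the sketch below keeps only the encoder parameters from a checkpoint's parameter dictionary. It assumes parameter names are prefixed with `encoder.`; the helper name `extract_encoder_weights` and the toy dictionary are hypothetical, and encoder_initial.py remains the authoritative script.

```python
def extract_encoder_weights(state_dict, prefix="encoder."):
    """Return only the entries whose parameter name starts with `prefix`."""
    return {
        name: value
        for name, value in state_dict.items()
        if name.startswith(prefix)
    }


# Toy stand-in for an AR verifier's parameters (real values would be tensors).
ar_state = {
    "encoder.embed_tokens.weight": [0.1, 0.2],
    "encoder.layers.0.fc1.weight": [0.3],
    "decoder.embed_tokens.weight": [0.4],
}

# Only the encoder entries are carried over to initialize the drafter.
drafter_init = extract_encoder_weights(ar_state)
print(sorted(drafter_init))
```

Decoder parameters are deliberately left out, so the drafter's decoder would still start from its own initialization.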

Train

The AR verifier of SpecDec is a standard Transformer that can be trained with fairseq:

fairseq-train ${bin_path} --arch transformer --share-all-embeddings \
      --task translation --source-lang ${src} --target-lang ${tgt} \
      --criterion label_smoothed_cross_entropy --dropout ${dropout} \
      --label-smoothing 0.1 --lr ${lr} --clip-norm 3.0 \
      --warmup-init-lr 1e-7 --min-lr 1e-9 --lr-scheduler inverse_sqrt \
      --weight-decay 0.00001 --update-freq ${update_freq} --fp16 --seed ${seed} \
      --warmup-updates ${warmup} --optimizer adam \
      --adam-betas '(0.9, 0.98)' --max-tokens ${max_tokens} --max-epoch ${max_epoch} \
      --save-dir ./checkpoints \
      --eval-bleu \
      --eval-bleu-args '{"beam":5}' \
      --eval-bleu-detok moses \
      --eval-bleu-remove-bpe \
      --eval-bleu-print-samples \
      --best-checkpoint-metric bleu --maximize-best-checkpoint-metric

For training the NAR drafter of SpecDec (check train.sh):

python train.py ${bin_path} --arch block --noise block_mask --share-all-embeddings \
    --criterion glat_loss --label-smoothing 0.1 --lr ${lr} --warmup-init-lr 1e-7 \
    --stop-min-lr 1e-9 --lr-scheduler inverse_sqrt --warmup-updates ${warmup} \
    --optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-6 \
    --task translation_lev_modified --max-tokens ${max_tokens} --weight-decay 0.01 \
    --dropout ${dropout} --encoder-layers 6 --encoder-embed-dim 512 --decoder-layers 6 \
    --decoder-embed-dim 512 --fp16 --max-source-positions 1000 \
    --max-target-positions 1000 --max-update ${update} --seed ${seed} --clip-norm 5 \
    --save-dir ./checkpoints --src-embedding-copy --log-interval 1000 \
    --user-dir specdec_plugins --block-size ${size} --total-up ${update} \
    --update-freq ${update_freq} --decoder-learned-pos --encoder-learned-pos \
    --apply-bert-init --activation-fn gelu \
    --restore-file ./checkpoints/initial_checkpoint.pt \
    --reset-optimizer --reset-meters --reset-lr-scheduler --reset-dataloader
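
Both commands use `--lr-scheduler inverse_sqrt` with a linear warmup. The sketch below shows the shape of that schedule as it is commonly defined: a linear ramp from the warmup-init learning rate to the peak, followed by decay proportional to the inverse square root of the update number. The function name `inverse_sqrt_lr` is hypothetical, and the example values are the WMT14 EN-DE drafter settings listed under Hyperparameters.

```python
def inverse_sqrt_lr(step, peak_lr, warmup_updates, warmup_init_lr=1e-7):
    """Learning rate at a given update under linear warmup + inverse-sqrt decay."""
    if step < warmup_updates:
        # Linear ramp from warmup_init_lr up to peak_lr.
        return warmup_init_lr + step * (peak_lr - warmup_init_lr) / warmup_updates
    # After warmup: decay so that the rate equals peak_lr at step == warmup_updates.
    return peak_lr * (warmup_updates ** 0.5) * (step ** -0.5)


for step in (0, 5000, 10000, 40000):
    print(step, inverse_sqrt_lr(step, peak_lr=0.0005, warmup_updates=10000))
```

With these settings the rate peaks at the end of warmup and halves every time the update count quadruples.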

Hyperparameters

The hyperparameters of the NAR drafter are shown as follows:

| Hyperparameters \ Datasets | WMT14 EN-DE | WMT16 EN-RO |
| --- | --- | --- |
| learning rate | 0.0005 | 0.001 |
| dropout | 0.1 | 0.2 |
| warm up | 10000 | 4000 |
| max update | 300K | 50K |
| batch size (tokens) | 128K | 64K |

The effective batch size in tokens is calculated as GPU_NUM * MAX_TOKENS * UPDATE_FREQ.
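
A quick way to check a configuration against the token budgets in the table is to evaluate that product directly. The GPU counts and per-GPU token limits below are hypothetical examples, not the authors' actual hardware setup.

```python
def effective_batch_tokens(gpu_num, max_tokens, update_freq):
    """Tokens contributing to one optimizer step across all GPUs."""
    return gpu_num * max_tokens * update_freq


# Hypothetical setups that reach the 128K and 64K token budgets.
print(effective_batch_tokens(gpu_num=8, max_tokens=4096, update_freq=4))  # 131072
print(effective_batch_tokens(gpu_num=8, max_tokens=4096, update_freq=2))  # 65536
```

With fewer GPUs, raising `--update-freq` by the same factor keeps the effective batch size unchanged.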

Inference

For SpecDec inference (see inference.sh; set beta=1 to get results identical to AR greedy decoding):

python inference.py ${data_dir} --path ${checkpoint_path} --user-dir specdec_plugins \
    --task translation_lev_modified --remove-bpe --max-sentences 20 \
    --source-lang ${src} --target-lang ${tgt} --iter-decode-max-iter 0 \
    --iter-decode-eos-penalty 0 --iter-decode-with-beam 1 --gen-subset test \
    --AR-path ${AR_checkpoint_path} --input-path ${input_path} \
    --output-path ${output_path} --block-size ${block_size} --beta ${beta} --tau ${tau} \
    --batch ${batch} --beam ${beam} --strategy ${strategy}
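
To make the role of `beta` and `tau` more concrete, here is a toy sketch of a draft-then-verify acceptance rule. It assumes that `beta` is a rank threshold (the drafted token must be among the verifier's top-beta candidates) and that `tau` bounds the log-probability gap to the verifier's top-1 token. This README does not confirm those semantics beyond stating that beta=1 reproduces AR greedy decoding. The function names and toy log-probabilities are hypothetical; inference.py is the actual implementation.

```python
def accepts(draft_token, ar_logprobs, beta=1, tau=0.0):
    """Check one drafted token against the verifier's log-probs at that position."""
    ranked = sorted(ar_logprobs, key=ar_logprobs.get, reverse=True)
    if draft_token not in ranked[:beta]:
        return False
    gap = ar_logprobs[ranked[0]] - ar_logprobs[draft_token]
    return gap <= tau


def verify_block(draft_tokens, ar_logprobs_per_pos, beta=1, tau=0.0):
    """Keep drafted tokens until the first rejection, then emit the verifier's top-1."""
    output = []
    for token, logprobs in zip(draft_tokens, ar_logprobs_per_pos):
        if accepts(token, logprobs, beta, tau):
            output.append(token)
        else:
            output.append(max(logprobs, key=logprobs.get))
            break
    return output


draft = ["Das", "ist", "gut"]
ar = [
    {"Das": -0.1, "Dies": -2.5},
    {"ist": -0.2, "war": -1.9},
    {"schön": -0.4, "gut": -1.2},
]
print(verify_block(draft, ar, beta=1))           # strict: matches greedy output
print(verify_block(draft, ar, beta=2, tau=1.0))  # relaxed: keeps the drafted token
```

With beta=1 a drafted token survives only when it equals the verifier's top-1 choice, which is why the output coincides with greedy decoding in this sketch.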

We measure the inference latency of SpecDec using the batch-size-1 implementation; see inference_paper.py for details.

See inference_drafter.py for inference with the NAR drafter only.

To calculate compound-split BLEU:

./ref.sh
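
The contents of ref.sh are not reproduced here. For orientation, a widely used convention for compound-split BLEU on EN-DE is to split hyphenated compounds with an `##AT##-##AT##` marker in both hypothesis and reference before scoring. The sketch below shows only that splitting step; whether ref.sh does exactly this is an assumption, and the function name is hypothetical.

```python
import re


def compound_split(line):
    """Split hyphenated compounds so each part is scored as a separate token."""
    return re.sub(r"(\S)-(\S)", r"\1 ##AT##-##AT## \2", line)


print(compound_split("ein gut-gelaunter Mann"))
```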

Example

We put the first three tokenized sentences of WMT14 EN-DE in data/wmt14.en-de/example.en. Use this file as the input_path of the inference script. The results below were obtained by running inference.sh with inference_paper.py (on 1 NVIDIA P100 GPU, PyTorch 1.10, CUDA 11).

| Model | Accepted Tokens (average) | Latency (s) |
| --- | --- | --- |
| Fairseq (beam5) | 1.00 | 0.83 |
| Fairseq (beam1) | 1.00 | 0.81 |
| SpecDec | 6.18 | 0.27 |

You can find the translation results in ./output.
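
The speedup implied by the latencies in the table above can be checked with a short calculation. The numbers are copied from the table; only the variable names are introduced here.

```python
# Latencies in seconds, as reported in the example table.
latency = {"Fairseq (beam5)": 0.83, "Fairseq (beam1)": 0.81, "SpecDec": 0.27}

# Speedup of SpecDec relative to each Fairseq baseline.
speedup = {
    name: round(seconds / latency["SpecDec"], 2)
    for name, seconds in latency.items()
    if name != "SpecDec"
}
print(speedup)
```

On this three-sentence example, SpecDec is roughly 3x faster than either baseline.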

Extra Memory Cost

Since there is no need to save intermediate variables during inference, SpecDec can achieve a 3x to 5x decoding speedup (by alternating NAR and AR decoding) with only about 300 MiB of extra memory. Below is the nvidia-smi memory comparison of AR and SpecDec, tested on WMT14 EN-DE:

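| Model \ Batch Size | Model States (Params) | 1 | 4 | 8 | 16 | 32 |
| --- | --- | --- | --- | --- | --- | --- |
| Fairseq (beam1) | 232.38 | 1670 | 1712 | 1758 | 1844 | 2028 |
| SpecDec | 469.75 (AR + NAR) | 1902 | 1938 | 2012 | 2108 | 2298 |
| Extra Memory | 237.38 (NAR) | 232 | 226 | 254 | 264 | 270 |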
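
The "Extra Memory" row is the per-batch-size difference between the SpecDec and Fairseq rows, which the following sketch recomputes from the table. The values are assumed to be in MiB, as nvidia-smi reports them; the variable names are illustrative.

```python
batch_sizes = [1, 4, 8, 16, 32]
fairseq_mem = [1670, 1712, 1758, 1844, 2028]  # Fairseq (beam1)
specdec_mem = [1902, 1938, 2012, 2108, 2298]  # SpecDec (AR + NAR)

# Extra memory used by SpecDec at each batch size.
extra = [s - f for s, f in zip(specdec_mem, fairseq_mem)]
for batch, mem in zip(batch_sizes, extra):
    print(f"batch {batch}: +{mem}")
```

Every difference stays below 300, consistent with the extra memory cost stated above.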

Note

This code is based on GLAT (https://github.com/FLC777/GLAT).

Citation

If you find the resources in this repository useful, please cite our paper:

@inproceedings{xia-etal-2023-speculative,
    title = "Speculative Decoding: Exploiting Speculative Execution for Accelerating Seq2seq Generation",
    author = "Xia, Heming  and
      Ge, Tao  and
      Wang, Peiyi  and
      Chen, Si-Qing  and
      Wei, Furu  and
      Sui, Zhifang",
    editor = "Bouamor, Houda  and
      Pino, Juan  and
      Bali, Kalika",
    booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2023",
    month = dec,
    year = "2023",
    address = "Singapore",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2023.findings-emnlp.257",
    pages = "3909--3925",
}