ByteTransformer: Optimized BERT Transformer Inference on NVIDIA GPUs

Introduction

ByteTransformer is a high-performance inference library for BERT-like transformers that offers the following features:

- C++ and Python (PyTorch plugin) APIs
- Architecture-aware optimizations for NVIDIA GPUs
- Padding-free processing of variable-length inputs

ByteTransformer has been widely deployed to improve in-house transformer inference serving systems at ByteDance, delivering superior performance over other transformer implementations for both fixed-length and variable-length inputs. The technical details have been published at IEEE IPDPS 2023.
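
The advantage on variable-length inputs comes from avoiding computation on padding tokens. As a rough illustration of the underlying idea (this sketch is not ByteTransformer code; pack_batch and all names in it are made up for exposition), a batch can be packed into one contiguous buffer with prefix-sum offsets instead of being padded to the longest sequence:

import torch

# Illustrative sketch only (not part of ByteTransformer's API): instead of
# padding every sequence to the batch maximum, concatenate the tokens and keep
# prefix-sum offsets that mark where each sequence starts.
def pack_batch(sequences):
    lengths = torch.tensor([s.size(0) for s in sequences])
    offsets = torch.cat([torch.zeros(1, dtype=torch.long), lengths.cumsum(0)])
    packed = torch.cat(sequences, dim=0)  # [total_tokens, hidden], no pad tokens
    return packed, offsets

seqs = [torch.randn(n, 768) for n in (37, 64, 12)]
packed, offsets = pack_batch(seqs)
print(packed.shape)  # torch.Size([113, 768]); padding to 64 would cost 192 rows
print(offsets)       # tensor([  0,  37, 101, 113])

As the paper describes, ByteTransformer keeps a padding-free layout like this throughout the encoder rather than zero-padding the batch.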

Cite Us

If you use our library, please cite our paper:

@article{zhai2022bytetransformer,
  title={ByteTransformer: A High-Performance Transformer Boosted for Variable-Length Inputs},
  author={Zhai, Yujia and Jiang, Chengquan and Wang, Leyuan and Jia, Xiaoying and Zhang, Shang and Chen, Zizhong and Liu, Xin and Zhu, Yibo},
  journal={arXiv preprint arXiv:2210.03052},
  year={2022}
}

Performance and Speedup

We compared ByteTransformer with PyTorch, TensorFlow, FasterTransformer, and DeepSpeed on an A100 GPU. The benchmark script is available in benchmark/bert_bench.sh.

1. Standard BERT, batch size = 1, average sequence length = 0.6 × maximal sequence length, execution time in milliseconds:

| Seqlen | PyTorch | TensorFlow | FasterTransformer | FasterTransformer with remove padding | DeepSpeed | ByteTransformer |
|-------:|--------:|-----------:|------------------:|--------------------------------------:|----------:|----------------:|
| 64 | 2.93 | 2.46 | 1.05 | 1.23 | 1.17 | 0.90 |
| 128 | 3.18 | 2.60 | 1.10 | 1.43 | 1.28 | 0.97 |
| 192 | 3.18 | 2.81 | 1.26 | 1.43 | 1.40 | 1.36 |
| 256 | 2.81 | 2.90 | 1.35 | 1.55 | 1.51 | 1.43 |
| 320 | 3.11 | 3.24 | 1.63 | 1.66 | 1.84 | 1.69 |
| 384 | 2.87 | 3.43 | 1.64 | 1.64 | 1.95 | 1.72 |
| 448 | 2.99 | 3.61 | 2.26 | 2.35 | 2.23 | 1.86 |
| 512 | 2.89 | 3.74 | 2.28 | 2.43 | 2.37 | 2.00 |
| 576 | 2.99 | 4.03 | 2.51 | 2.59 | 2.70 | 2.19 |
| 640 | 2.99 | 4.54 | 2.85 | 2.83 | 3.17 | 2.23 |
| 704 | 3.21 | 4.67 | 3.16 | 3.44 | 3.32 | 2.47 |
| 768 | 3.33 | 4.88 | 3.26 | 3.63 | 3.46 | 2.51 |
| 832 | 3.78 | 5.39 | 3.75 | 3.87 | 3.97 | 2.80 |
| 896 | 3.86 | 5.81 | 4.08 | 4.95 | 4.37 | 2.86 |
| 960 | 4.02 | 6.27 | 4.30 | 5.23 | 4.66 | 3.12 |
| 1024 | 4.20 | 6.37 | 4.51 | 4.96 | 4.86 | 3.16 |

2. Standard BERT, batch size = 16, average sequence length = 0.6 × maximal sequence length, execution time in milliseconds:

| Seqlen | PyTorch | TensorFlow | FasterTransformer | FasterTransformer with remove padding | DeepSpeed | ByteTransformer |
|-------:|--------:|-----------:|------------------:|--------------------------------------:|----------:|----------------:|
| 64 | 3.20 | 4.57 | 2.24 | 1.93 | 2.81 | 2.09 |
| 128 | 4.97 | 6.97 | 3.62 | 3.33 | 4.54 | 3.18 |
| 192 | 7.65 | 9.37 | 5.26 | 5.29 | 6.68 | 5.08 |
| 256 | 9.56 | 12.17 | 6.77 | 5.49 | 9.03 | 6.85 |
| 320 | 13.21 | 15.87 | 8.85 | 6.47 | 12.81 | 7.49 |
| 384 | 15.01 | 18.56 | 10.37 | 7.05 | 15.19 | 8.44 |
| 448 | 19.06 | 23.01 | 15.97 | 12.54 | 18.83 | 8.89 |
| 512 | 21.00 | 26.03 | 18.03 | 13.79 | 21.55 | 9.22 |
| 576 | 24.33 | 31.24 | 21.11 | 17.65 | 26.20 | 10.15 |
| 640 | 28.03 | 35.07 | 24.52 | 20.34 | 30.24 | 12.04 |
| 704 | 32.33 | 41.43 | 28.94 | 24.52 | 34.65 | 13.55 |
| 768 | 35.31 | 44.62 | 32.09 | 28.21 | 37.95 | 16.30 |
| 832 | 40.75 | 51.87 | 36.33 | 31.69 | 45.32 | 16.92 |
| 896 | 44.47 | 55.65 | 42.17 | 38.05 | 49.48 | 20.67 |
| 960 | 49.72 | 63.59 | 47.01 | 42.98 | 55.72 | 23.27 |
| 1024 | 53.21 | 65.94 | 50.28 | 45.22 | 59.96 | 24.70 |
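
Read as speedups, dividing a baseline's latency by ByteTransformer's gives the improvement factor. A small script with values copied from the last row of each table, for illustration:

# Latencies (ms) copied from the seqlen-1024 rows of the two tables above.
baselines = ("PyTorch", "TensorFlow", "FasterTransformer",
             "FT with remove padding", "DeepSpeed")
rows = {
    # (batch, seqlen): (PyTorch, TF, FT, FT-nopad, DeepSpeed, ByteTransformer)
    (1, 1024):  (4.20, 6.37, 4.51, 4.96, 4.86, 3.16),
    (16, 1024): (53.21, 65.94, 50.28, 45.22, 59.96, 24.70),
}
for (batch, seqlen), times in rows.items():
    bt = times[-1]
    print(f"batch={batch}, seqlen={seqlen}:",
          ", ".join(f"{name} {t / bt:.2f}x" for name, t in zip(baselines, times[:-1])))
# batch=1,  seqlen=1024: PyTorch 1.33x, ..., DeepSpeed 1.54x
# batch=16, seqlen=1024: PyTorch 2.15x, ..., DeepSpeed 2.43x

Note in the tables above how the advantage grows with batch size and sequence length.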

Supported Models

Currently, this repository provides only the standard BERT transformer encoder.

Environment Requirements

Tested on: A100 + CUDA 11.6 + PyTorch 1.13.0+cu116 + Python 3.9.16
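
A quick sanity check that your environment matches the tested configuration (assuming a CUDA build of PyTorch is already installed):

import torch

# Tested configuration from above: PyTorch 1.13.0+cu116 on CUDA 11.6, A100.
print(torch.__version__)              # expect 1.13.0+cu116
print(torch.version.cuda)             # expect 11.6
print(torch.cuda.get_device_name(0))  # expect an NVIDIA A100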

Building from Source

To build from source, run the following commands:

git submodule update --init
mkdir build && cd build
# TORCH_CUDA_ARCH_LIST / CUDAARCHS target compute capability 8.0 (A100);
# DataType selects the precision; BUILD_THS=ON builds the PyTorch (TorchScript) plugin
cmake -DTORCH_CUDA_ARCH_LIST="8.0" -DDataType=FP16 -DBUILD_THS=ON -DCUDAARCHS="80" ..
make
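
The "8.0"/"80" architecture flags above target the A100; when building for a different GPU, set them to match your hardware's compute capability. If PyTorch is installed, one way to look it up:

python3 -c "import torch; print(torch.cuda.get_device_capability(0))"  # prints (8, 0) on an A100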

Getting Started with Unit Tests

Unit Tests in C++

To generate test data, run the following commands:

cd build
# batch sz = 16, seqlen = 64, head num = 12, head sz = 64, avg seqlen = 32
python3 bert_transformer_test.py 16 64 12 64 --avg_seqlen 32 --dtype fp16 --export_data

Here, 16, 64, 12, and 64 are the batch size, sequence length, number of attention heads, and head size, respectively. --avg_seqlen 32 sets the average sequence length, --dtype fp16 selects half precision, and --export_data writes the test data to disk.

After the test data is generated (the *.in and *.out files are saved in the current directory), run the following command:

./bin/bert_transformer_test 16 64 12 64

Here, the arguments represent the same parameters as used in generating the test data.

Unit Tests in Python (PyTorch Plugin)

To run the unit tests through the PyTorch plugin in Python, use the same script as for C++ but without the --export_data flag:

# batch sz = 16, seqlen = 64, head num = 12, head sz = 64, avg seqlen = 32
python3 bert_transformer_test.py 16 64 12 64 --avg_seqlen 32 --dtype fp16

Again, the arguments represent the same parameters as used in generating the test data.
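
For context, a TorchScript plugin built with -DBUILD_THS=ON is typically consumed by loading the shared library and calling the op it registers. The path and op namespace below are placeholders, not names this repository necessarily uses; see bert_transformer_test.py for the actual invocation:

import torch

# Generic TorchScript-extension pattern; the .so path and op namespace are
# HYPOTHETICAL placeholders -- consult bert_transformer_test.py for the real names.
torch.ops.load_library("build/lib/libths_bytetransformer.so")  # placeholder path
# Once loaded, registered ops are callable as torch.ops.<namespace>.<op_name>(...)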

Benchmark

To reproduce the benchmark results above, run:

cd build
../benchmark/bert_bench.sh