Winner-Take-All Column Row Sampling for Memory Efficient Adaptation of Language Model

This is the official code for Winner-Take-All Column Row Sampling for Memory Efficient Adaptation of Language Model.

Introduction

As model sizes grow rapidly, fine-tuning large pre-trained language models has become increasingly difficult due to their extensive memory usage. While previous approaches focus on reducing the number of trainable parameters, the primary memory bottleneck is storing the feature maps (activations) needed for gradient computation. We propose WTA-CRS, a family of unbiased estimators for matrix multiplication with reduced variance, which requires storing only sub-sampled activations for computing gradients. Both theoretically and experimentally, WTA-CRS exhibits lower variance than existing estimators. Applied to fine-tuning transformers, it reduces peak memory usage by up to 2.7× with almost no accuracy drop and enables up to a 6.4× larger batch size. Under the same hardware, WTA-CRS thus allows larger models and faster training, which translates into better downstream task performance.
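To make the core idea concrete, below is a minimal PyTorch sketch of plain column-row sampling (CRS), the building block that WTA-CRS refines. It is an illustration only, not code from this repository: the function name `crs_matmul` and the tensor sizes are made up, and the winner-take-all step from the paper (deterministically keeping the highest-probability column-row pairs and sampling only among the rest) is omitted.

```python
# Minimal sketch of unbiased column-row sampling (CRS) for approximating A @ B.
# Illustrative only; not the implementation used in this repository.
import torch

def crs_matmul(A: torch.Tensor, B: torch.Tensor, k: int) -> torch.Tensor:
    """Unbiased estimate of A @ B from k sampled column-row pairs.

    Column i of A and row i of B are sampled with probability proportional
    to ||A[:, i]|| * ||B[i, :]||, then rescaled so the estimator is unbiased.
    """
    norms = A.norm(dim=0) * B.norm(dim=1)      # importance of each column-row pair
    probs = norms / norms.sum()
    idx = torch.multinomial(probs, k, replacement=True)
    scale = 1.0 / (k * probs[idx])             # rescaling keeps the estimate unbiased
    # Sum of rescaled rank-one terms A[:, i] (outer) B[i, :]
    return (A[:, idx] * scale) @ B[idx, :]

# Example: the estimate approaches the exact product as k grows.
A, B = torch.randn(64, 512), torch.randn(512, 32)
approx = crs_matmul(A, B, k=128)
print((approx - A @ B).norm() / (A @ B).norm())
```

With a fixed sampling budget k, WTA-CRS lowers the variance of this estimator by deterministically keeping the column-row pairs with the largest probabilities and sampling only over the remainder, as described in the paper.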

Setup

conda create -n approx python=3.9
conda activate approx
pip install torch==2.0.0
pip install -e .
pip install protobuf==3.20.3
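Optionally, you can sanity-check the environment before running experiments. The command below is not part of the repository's scripts; it only verifies that PyTorch imports and reports whether a CUDA device is visible:

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"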

Run Main Experiments

Run WTA-CRS on T5 and BERT language models:

bash scripts/main_exp.sh

Run LoRA+WTA-CRS on T5 and BERT language models:

bash scripts/lora_exp.sh

Experiment Results

Accuracy of WTA-CRS on the GLUE datasets.

<div align=center> <img width="700" height="300" src="https://github.com/zirui-ray-liu/WTACRS/blob/main/figure/wta_accuracy.png"> </div>

The memory footprint of WTA-CRS.

<div align=center> <img width="800" height="150" src="https://github.com/zirui-ray-liu/WTACRS/blob/main/figure/wta_mem.png"> </div>

Throughput of fine-tuning with WTA-CRS.

<div align=center> <img width="500" height="300" src="https://github.com/zirui-ray-liu/WTACRS/blob/main/figure/wta_throughput.png"> </div>

Acknowledgment

Our code is based on the official implementation of Ladder Side-Tuning.

Cite this work

If you find this project useful, please cite it as:

@article{liu2023winner,
  title={Winner-Take-All Column Row Sampling for Memory Efficient Adaptation of Language Model},
  author={Liu, Zirui and Wang, Guanchu and Zhong, Shaochen and Xu, Zhaozhuo and Zha, Daochen and Tang, Ruixiang and Jiang, Zhimeng and Zhou, Kaixiong and Chaudhary, Vipin and Xu, Shuai and others},
  journal={arXiv preprint arXiv:2305.15265},
  year={2023}
}