# Winner-Take-All Column Row Sampling for Memory Efficient Adaptation of Language Model
This is the official code for **Winner-Take-All Column Row Sampling for Memory Efficient Adaptation of Language Model**.
## Introduction
As model sizes grow rapidly, fine-tuning large pre-trained language models has become increasingly difficult due to their extensive memory usage. While previous approaches focus on reducing the number of trainable parameters, the primary memory bottleneck is storing the feature maps (activations) needed for gradient computation. WTA-CRS is a family of unbiased estimators for matrix multiplication with reduced variance that requires storing only sub-sampled activations for gradient calculation. Theoretical and experimental evidence shows lower variance than existing estimators, enabling up to 2.7× peak memory reduction with almost no accuracy loss and up to 6.4× larger batch sizes in transformers. Under the same hardware budget, WTA-CRS enables better downstream task performance by fitting larger models and training faster with larger batch sizes.
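To give a sense of what activation sub-sampling looks like, below is a minimal PyTorch sketch of plain column-row sampling (CRS) for approximating a matrix product. It is only an illustration of the general idea, not the repository's WTA-CRS implementation or API: the function name `crs_matmul`, the norm-product sampling distribution, and the budget `k` are assumptions made for this example. WTA-CRS builds on this kind of estimator and further reduces variance by concentrating the sampling budget on the highest-probability column-row pairs.

```python
import torch

def crs_matmul(A: torch.Tensor, B: torch.Tensor, k: int) -> torch.Tensor:
    """Unbiased column-row sampling (CRS) estimate of A @ B.

    Illustrative only: samples k column-row index pairs with probability
    proportional to ||A[:, i]|| * ||B[i, :]|| and rescales so that the
    estimate is unbiased. Only the k sampled columns of A (the activations)
    would need to be stored for the backward pass.
    """
    # Sampling probabilities proportional to the column/row norm products.
    probs = torch.linalg.norm(A, dim=0) * torch.linalg.norm(B, dim=1)
    probs = probs / probs.sum()

    # Draw k index pairs with replacement.
    idx = torch.multinomial(probs, k, replacement=True)

    # Rescale each sampled outer product by 1 / (k * p_i) for unbiasedness.
    scale = 1.0 / (k * probs[idx])          # shape (k,)
    A_sub = A[:, idx] * scale               # shape (m, k)
    B_sub = B[idx, :]                       # shape (k, n)
    return A_sub @ B_sub

# Quick check of the approximation quality (hypothetical shapes).
A = torch.randn(128, 4096)   # activations
B = torch.randn(4096, 1024)  # weights
err = torch.linalg.norm(crs_matmul(A, B, k=1024) - A @ B) / torch.linalg.norm(A @ B)
print(f"relative error: {err.item():.3f}")
```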
## Setup
```bash
conda create -n approx python=3.9
conda activate approx
pip install torch==2.0.0
pip install -e .
pip install protobuf==3.20.3
```
## Run Main Experiments
Run WTA-CRS on the T5 and BERT language models:

```bash
bash scripts/main_exp.sh
```

Run LoRA + WTA-CRS on the T5 and BERT language models:

```bash
bash scripts/lora_exp.sh
```
## Experiment Results
Accuracy of WTA-CRS on the GLUE datasets:

<div align=center> <img width="700" height="300" src="https://github.com/zirui-ray-liu/WTACRS/blob/main/figure/wta_accuracy.png"> </div>

Memory footprint of WTA-CRS:

<div align=center> <img width="800" height="150" src="https://github.com/zirui-ray-liu/WTACRS/blob/main/figure/wta_mem.png"> </div>

Throughput of fine-tuning with WTA-CRS:

<div align=center> <img width="500" height="300" src="https://github.com/zirui-ray-liu/WTACRS/blob/main/figure/wta_throughput.png"> </div>

## Acknowledgment
Our code is based on the official code of Ladder Side-Tuning.
## Cite This Work
If you find this project useful, please cite it as:
```bibtex
@article{liu2023winner,
  title={Winner-Take-All Column Row Sampling for Memory Efficient Adaptation of Language Model},
  author={Liu, Zirui and Wang, Guanchu and Zhong, Shaochen and Xu, Zhaozhuo and Zha, Daochen and Tang, Ruixiang and Jiang, Zhimeng and Zhou, Kaixiong and Chaudhary, Vipin and Xu, Shuai and others},
  journal={arXiv preprint arXiv:2305.15265},
  year={2023}
}
```