

<div align="center"> <h1><img src="figures/logo.png" height="34px" align="top"/>Locret: Accelerating Long-Context LLM Inference with Retaining Heads</h1> <p align="center"> <a href="https://huangyuxiang03.github.io/blogs_locret" target="_blank">Blog</a> | <a href="https://arxiv.org/abs/2410.01805" target="_blank">Paper (arXiv)</a> </p>

A Lightweight, Training-based KV Cache Compression Algorithm

</div>

Locret is a lightweight, training-based KV cache eviction method. It achieves KV cache compression ratios of 20x for Phi-3-mini-128K and 8x for Llama-3.1-8B-instruct, enabling 128K+ long-context inference on a single NVIDIA RTX 4090 GPU.
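To see why these ratios matter on a 24 GB RTX 4090, here is a back-of-the-envelope estimate of the full, uncompressed KV cache at 128K tokens. It assumes fp16/bf16 storage (2 bytes per element) and the model shapes from the public Hugging Face configs (Llama-3.1-8B: 32 layers, 8 KV heads, head dimension 128; Phi-3-mini-128K: 32 layers, 32 KV heads, head dimension 96). It ignores model weights, activations, and any extra state Locret keeps, so treat the numbers as rough.

```python
# Back-of-the-envelope KV cache size.
# Assumption: fp16/bf16 storage, i.e. 2 bytes per element.
GIB = 2**30


def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to store keys AND values for every layer and every token."""
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem  # 2 = K and V
    return per_token * seq_len


SEQ_LEN = 128 * 1024  # 128K tokens

# Shapes are assumed from the public model configs; ratios are the ones quoted above.
models = {
    # name: (layers, kv_heads, head_dim, compression ratio)
    "Llama-3.1-8B-instruct": (32, 8, 128, 8),
    "Phi-3-mini-128K": (32, 32, 96, 20),
}

for name, (layers, kv_heads, head_dim, ratio) in models.items():
    full = kv_cache_bytes(layers, kv_heads, head_dim, SEQ_LEN)
    print(f"{name}: full {full / GIB:.1f} GiB -> ~{full / ratio / GIB:.1f} GiB at {ratio}x")
```

Under these assumptions the full cache is 16 GiB for Llama-3.1-8B and 48 GiB for Phi-3-mini-128K (Phi-3-mini has no grouped-query attention, so every head stores its own K and V). After compression this drops to roughly 2 GiB and 2.4 GiB, which leaves room for the weights on a 24 GB card.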

<div align="center"> <h4>This project was made possible thanks to a collaboration with <img src="figures/univ.png" height="60px" align="center"/></h4> </div>

Design

Overall Framework

Locret Inference
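A simplified, hypothetical model of budgeted, score-based KV cache eviction during chunked prefill can help build intuition. The input is processed chunk by chunk. After each chunk, the cache is trimmed back to a fixed budget by dropping the entries with the lowest importance scores. This is not Locret's code. The per-token scores are passed in directly as a stand-in for what a retaining head would predict. A single cache is tracked, whereas a real system would keep one per layer and attention head. The function names are illustrative.

```python
from typing import List, Sequence, Tuple


def evict_to_budget(positions: List[int], scores: List[float],
                    budget: int) -> Tuple[List[int], List[float]]:
    """Keep the `budget` highest-scoring cache entries, preserving token order."""
    if len(positions) <= budget:
        return positions, scores
    # sorted() is stable even with reverse=True, so ties keep the earlier token.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = sorted(ranked[:budget])  # restore original positional order
    return [positions[i] for i in keep], [scores[i] for i in keep]


def chunked_prefill(token_scores: Sequence[float], chunk_size: int,
                    budget: int) -> List[int]:
    """Return the token positions still cached after prefilling the whole input.

    `token_scores[t]` is a hypothetical stand-in for the importance score a
    retaining head would assign to token t's KV entry; no model is run here.
    """
    positions: List[int] = []
    scores: List[float] = []
    for start in range(0, len(token_scores), chunk_size):
        chunk = range(start, min(start + chunk_size, len(token_scores)))
        positions += list(chunk)
        scores += [token_scores[t] for t in chunk]
        positions, scores = evict_to_budget(positions, scores, budget)
    return positions
```

For example, `chunked_prefill([0.9, 0.1, 0.8, 0.2, 0.7, 0.3], chunk_size=2, budget=3)` returns `[0, 2, 4]`. Memory stays bounded by the budget plus one chunk, no matter how long the input is. Because each entry's score is fixed when it is produced, eviction never has to revisit earlier tokens.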

Usage

Note that we currently support only Phi-3-mini-128K and Llama-3.1-8B-instruct. More models will be supported in the future.

Environment Setup

Run the following commands to set up the environment:

```bash
cd locret
conda env create -f locret_env.yaml
conda activate locret_env
pip install -e .
```

Official checkpoints of Locret can be downloaded from hyx21/Locret-phi-3-mini-128K and hyx21/Locret-llama-3.1-8B-instruct.

Now you are all set!

Training

First, enter the working directory:

```bash
cd locret/train
```

Then, generate the training dataset:

```bash
python data_gen.py --model_dir <model_dir>
```

Finally, start training the model.

```bash
python train.py --model_dir <model_dir>
```

All hyperparameters are set automatically according to our paper. Make sure the model name appears in model_dir so the training script can detect which model you are using (e.g. if the path contains "phi", the script treats it as Phi-3-mini-128K).
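Detection of this kind typically amounts to a case-insensitive substring check on the path. The following is a hypothetical sketch. The function name and keyword table are illustrative and are not taken from train.py, whose actual rules may differ.

```python
from typing import Optional

# Hypothetical keyword -> model-family table; train.py's real rules may differ.
_MODEL_KEYWORDS = {
    "phi": "phi-3-mini-128k",
    "llama": "llama-3.1-8b-instruct",
}


def detect_model_family(model_dir: str) -> Optional[str]:
    """Return the supported model family whose keyword appears in the path, if any."""
    path = model_dir.lower()  # match "Phi", "phi", "PHI", ...
    for keyword, family in _MODEL_KEYWORDS.items():
        if keyword in path:
            return family
    return None  # unknown model: hyperparameters cannot be chosen automatically
```

This is why a checkpoint copied into a generically named directory such as `/models/ckpt` would not be recognized, while `/models/Phi-3-mini-128k-instruct` would be.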

After training, the trained checkpoint is saved to locret/train/checkpoints/<model_name>. You can either load the complete model saved with save_pretrained, or load only the retaining heads on top of the original model. We also provide a script, locret/train/convert.py, which converts checkpoints saved in safetensors format to PyTorch format.
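Saving or loading only the retaining heads comes down to filtering the model's state dict by parameter name. The minimal sketch below assumes that all retaining-head parameters share a common name fragment. The `retaining_head` marker is a placeholder, not a confirmed parameter name, so check the real keys in your checkpoint first (e.g. by printing `model.state_dict().keys()`).

```python
from typing import Dict, TypeVar

T = TypeVar("T")  # works for tensors or any other stored value


def retaining_head_subset(state_dict: Dict[str, T],
                          marker: str = "retaining_head") -> Dict[str, T]:
    """Keep only the entries whose parameter name contains `marker`.

    `marker` is a hypothetical naming convention; adjust it to the actual
    parameter names of the retaining heads in your checkpoint.
    """
    return {name: value for name, value in state_dict.items() if marker in name}
```

With PyTorch, the filtered dict could be written with `torch.save(...)`. It could later be loaded into a model that already defines the retaining-head modules via `model.load_state_dict(heads, strict=False)`, which leaves the base weights untouched.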

Inference

example.py contains a single entry from the R.PassKey task of InfiniteBench. To run it, execute:

```bash
python example.py --model_dir <model_dir>  # saved full checkpoint, or
python example.py --model_dir <model_dir> --retaining_head_path <*.bin>  # original model + saved retaining heads
```
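A passkey-retrieval entry hides a short number inside a long stretch of irrelevant filler text and asks the model to repeat it. An eviction method that drops the wrong KV entries therefore fails visibly. The sketch below builds a toy prompt in that style and scores an answer. The filler sentences, needle wording, scoring rule, and function names are illustrative only. They are not the exact InfiniteBench data or evaluation code.

```python
import re

# Illustrative filler; InfiniteBench's actual noise text and format may differ.
NOISE = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. "


def make_passkey_prompt(passkey: str, n_noise: int, insert_at: int) -> str:
    """Bury a passkey sentence after `insert_at` of `n_noise` filler blocks."""
    before = NOISE * insert_at
    after = NOISE * (n_noise - insert_at)
    needle = f"The pass key is {passkey}. Remember it. {passkey} is the pass key. "
    return before + needle + after + "What is the pass key?"


def passkey_correct(model_output: str, passkey: str) -> bool:
    """Count the answer as correct if the first number in the output is the passkey."""
    match = re.search(r"\d+", model_output)
    return match is not None and match.group() == passkey
```

Raising `n_noise` stretches the prompt toward 128K tokens. The needle then survives only if the cache policy keeps its KV entries while discarding most of the repetitive filler.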

To reproduce the other experiments in our paper, run the scripts in benchmark/infinite_bench and benchmark/LEval-main. Each script corresponds to one experimental setting, which can be identified from the script name.

Development Roadmap

Citation

Please cite our paper if you find our work valuable.

```bibtex
@article{huang2024locret,
  title={Locret: Accelerating Long-Context LLM Inference with Retaining Heads},
  author={Yuxiang Huang and Binhang Yuan and Xu Han and Chaojun Xiao and Zhiyuan Liu},
  journal={arXiv preprint arXiv:2410.01805},
  year={2024}
}
```

Acknowledgement

We acknowledge the discussion with Ruisi Cai@UT Austin, the author of LoCoCo, on the implementation of LoCoCo. We deeply thank Xinrong Zhang@Tsinghua for details on InfiniteBench, Weilin Zhao@Tsinghua for advice on the system implementation, Chenyang Song@Tsinghua for discussions on model sparsity, Shuo Wang@Tsinghua for the long-context training recipe, and Yuan Yao@NUS for future work on multimodal LLMs.