FlatQuant: Flatness Matters for LLM Quantization
This repository contains the PyTorch implementation of FlatQuant: Flatness Matters for LLM Quantization.
FlatQuant leverages Fast and Learnable Affine Transformations tailored for each linear layer to alleviate outliers in LLMs. As the name suggests, it also yields notably flat weight and activation distributions that are friendly to quantization. FlatQuant significantly improves quantization accuracy under low-bit settings (e.g., W4A4) while introducing little inference overhead, which may help promote the deployment of W4A4-quantized LLMs.
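To give an intuition for why flatness helps, below is a minimal, self-contained sketch of the underlying idea rather than the repository's implementation: an invertible transform is applied to the activations and its inverse is folded into the weights, so the layer output is unchanged while both tensors become easier to quantize. The random orthogonal transform and the fake_quant helper are purely illustrative; FlatQuant instead learns a fast affine transform per layer.

```python
import torch

# Toy illustration of the idea behind FlatQuant (not the repository's code):
# an invertible transform P flattens the activations, and P^{-1} is folded
# into the weights, so x @ W == (x @ P) @ (P^{-1} @ W) while the transformed
# tensors quantize with far less error.

def fake_quant(t, n_bits=4):
    # simple symmetric per-tensor fake quantization, for illustration only
    qmax = 2 ** (n_bits - 1) - 1
    scale = t.abs().max() / qmax
    return torch.round(t / scale).clamp(-qmax - 1, qmax) * scale

torch.manual_seed(0)
x = torch.randn(8, 64)
x[:, 3] *= 30.0                               # inject an outlier channel
W = torch.randn(64, 64) / 8

P = torch.linalg.qr(torch.randn(64, 64))[0]   # random orthogonal stand-in; FlatQuant learns this
x_flat, W_flat = x @ P, torch.linalg.inv(P) @ W

y_ref = x @ W
err_plain = (fake_quant(x) @ fake_quant(W) - y_ref).abs().mean().item()
err_flat = (fake_quant(x_flat) @ fake_quant(W_flat) - y_ref).abs().mean().item()
print(f"4-bit error without transform: {err_plain:.3f}, with transform: {err_flat:.3f}")
```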
News
- [2024/10] 🔥 FlatQuant is publicly released! Check out our paper at https://arxiv.org/abs/2410.09426.
Contents
- Preparations
- Usage
- Efficient Kernel
- Model Zoo
- Results
- Acknowledgements
- References
Preparations
Installation
conda create -n flatquant python=3.10 -y
conda activate flatquant
pip install -r requirements.txt && pip install -e . && pip install triton==3.0.0
Note: To run models like LLaMA-3.1 or Qwen-2.5, we use transformers==4.45.0 instead.
Data Preparation
Download the datasets into ./datasets. A programmatic download sketch is given after the tables below.
Calibration set or PPL evaluation
Dataset | Local Dir | URL |
---|---|---|
WikiText2 | ./datasets/wikitext | https://huggingface.co/datasets/wikitext |
C4 | ./datasets/allenai/c4 | https://huggingface.co/datasets/allenai/c4 |
Pile | ./datasets/pile-val-backup | https://huggingface.co/datasets/mit-han-lab/pile-val-backup |
Commonsense QA evaluation
Dataset | Local Dir | URL |
---|---|---|
ARC-E and ARC-C | ./datasets/ai2_arc | https://huggingface.co/datasets/ai2_arc |
HellaSwag | ./datasets/hellaswag | https://huggingface.co/datasets/hellaswag |
LAMBADA | ./datasets/lambada_openai | https://huggingface.co/datasets/EleutherAI/lambada_openai |
PIQA | ./datasets/piqa | https://huggingface.co/datasets/ybisk/piqa |
WinoGrande | ./datasets/winogrande | https://huggingface.co/datasets/winogrande |
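If you prefer to fetch these datasets with a script, the sketch below mirrors them into the local directories listed above using huggingface_hub.snapshot_download. The repo IDs are taken from the URLs in the tables; the helper itself is illustrative, and for very large datasets such as C4 you will likely want to pass allow_patterns so that only the files your setup actually needs are downloaded.

```python
# Illustrative helper to mirror the calibration/evaluation datasets into ./datasets.
# For huge datasets (e.g. allenai/c4), pass allow_patterns=... to snapshot_download
# so that only the required shards are fetched.
from huggingface_hub import snapshot_download

DATASETS = {
    "wikitext": "./datasets/wikitext",
    "allenai/c4": "./datasets/allenai/c4",
    "mit-han-lab/pile-val-backup": "./datasets/pile-val-backup",
    "ai2_arc": "./datasets/ai2_arc",
    "hellaswag": "./datasets/hellaswag",
    "EleutherAI/lambada_openai": "./datasets/lambada_openai",
    "ybisk/piqa": "./datasets/piqa",
    "winogrande": "./datasets/winogrande",
}

for repo_id, local_dir in DATASETS.items():
    snapshot_download(repo_id=repo_id, repo_type="dataset", local_dir=local_dir)
```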
Model Preparation
Download the models into ./modelzoo. A download sketch is given after the table below.
Model | Local Dir | URL |
---|---|---|
LLaMA-2-7B | ./modelzoo/llama-2/llama-2-7b | https://huggingface.co/meta-llama/Llama-2-7b |
LLaMA-2-13B | ./modelzoo/llama-2/llama-2-13b | https://huggingface.co/meta-llama/Llama-2-13b |
LLaMA-2-70B | ./modelzoo/llama-2/llama-2-70b | https://huggingface.co/meta-llama/Llama-2-70b |
LLaMA-3-8B | ./modelzoo/llama-3/llama-3-8b | https://huggingface.co/meta-llama/Meta-Llama-3-8B |
LLaMA-3-70B | ./modelzoo/llama-3/llama-3-70b | https://huggingface.co/meta-llama/Meta-Llama-3-70B |
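The models can be mirrored the same way; the repo IDs below come from the table above. Most LLaMA repositories are gated, so this assumes you have been granted access and are logged in (e.g., via huggingface-cli login).

```python
# Illustrative helper to download the models into ./modelzoo
# (assumes access to the gated meta-llama repositories).
from huggingface_hub import snapshot_download

MODELS = {
    "meta-llama/Llama-2-7b": "./modelzoo/llama-2/llama-2-7b",
    "meta-llama/Llama-2-13b": "./modelzoo/llama-2/llama-2-13b",
    "meta-llama/Llama-2-70b": "./modelzoo/llama-2/llama-2-70b",
    "meta-llama/Meta-Llama-3-8B": "./modelzoo/llama-3/llama-3-8b",
    "meta-llama/Meta-Llama-3-70B": "./modelzoo/llama-3/llama-3-70b",
}

for repo_id, local_dir in MODELS.items():
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
```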
Usage
Calibration
We provide full scripts to run FlatQuant in ./scripts/. Here we use LLaMA-3-8B as an example:
- Weight-Activation-KV Cache Quantization
# W4A4KV4
python ./main.py \
--model ./modelzoo/llama-3/llama-3-8b \
--w_bits 4 --a_bits 4 \
--k_bits 4 --k_asym --k_groupsize 128 \
--v_bits 4 --v_asym --v_groupsize 128 \
--cali_bsz 4 --epoch 15 --flat_lr 5e-3 \
--lwc --lac --cali_trans --add_diag \
--output_dir ./outputs --save_matrix \
--lm_eval --lm_eval_batch_size 16
- Weight-Only Quantization
# W4A16
python ./main.py \
--model ./modelzoo/llama-3/llama-3-8b \
--w_bits 4 \
--cali_bsz 4 --epoch 15 --flat_lr 5e-3 \
--lwc --lac --cali_trans --add_diag \
--output_dir ./outputs --exp_name wonly --save_matrix \
--lm_eval --lm_eval_batch_size 16
- Reproduce Evaluation Results of Our Paper
1) Download the pre-trained FlatQuant parameters you want from the model zoo.
2) Run inference with --reload_matrix and --matrix_path PATH_TO_XXX. Taking LLaMA-3-8B with W4A4KV4 quantization as an example:
python ./main.py \
--model ./modelzoo/llama-3/llama-3-8b \
--w_bits 4 --a_bits 4 \
--k_bits 4 --k_asym --k_groupsize 128 \
--v_bits 4 --v_asym --v_groupsize 128 \
--cali_bsz 4 --epoch 15 --flat_lr 5e-3 \
--lwc --lac --cali_trans --add_diag \
--output_dir ./outputs --save_matrix \
--lm_eval --lm_eval_batch_size 16 \
--reload_matrix --matrix_path PATH_TO_XXX
More detailed and optional arguments:
- --model: The model name or the path to the model weights.
- --w_bits: Number of bits for weight quantization.
- --a_bits: Number of bits for activation quantization.
- --k_bits: Number of bits for key cache quantization.
- --k_asym: Enable asymmetric quantization for the key cache.
- --k_groupsize: Group size for key cache quantization.
- --v_bits: Number of bits for value cache quantization.
- --v_asym: Enable asymmetric quantization for the value cache.
- --v_groupsize: Group size for value cache quantization.
- --cali_bsz: The batch size used in the calibration of FlatQuant.
- --epochs: Number of training epochs for the calibration of FlatQuant.
- --flat_lr: Learning rate for FlatQuant calibration.
- --cali_trans: Enable the calibration of transformations.
- --lwc: Use learnable weight clipping during quantization.
- --lac: Use learnable activation clipping during quantization.
- --add_diag: Add per-channel scaling to transformations.
- --save_matrix: Save the matrix-style parameters of FlatQuant.
- --reload_matrix: Load the pre-trained matrix-style parameters of FlatQuant.
- --matrix_path: Path to the pre-trained matrix-style parameters of FlatQuant.
- --deactive_amp: Disable AMP (automatic mixed precision) during training.
- --direct_inv: Use PyTorch's inverse method to compute the inverse matrix instead of the SVD method.
- --lm_eval: Evaluate the model on language model (LM) evaluation tasks.
- --lm_eval_batch_size: Batch size for evaluation using the LM eval harness.
Inference Latency
To measure the speedup of FlatQuant and our efficient kernel, run the corresponding benchmark commands provided below:
# Run end-to-end latency benchmark
python ./benchmarks/layer_benchmark.py
# Run kernel latency benchmark
python ./benchmarks/kernel_benchmark.py
# Run linear layer latency benchmark
python ./benchmarks/qlinear_benchmark.py
# Run attention latency benchmark
python ./benchmarks/qattention_benchmark.py
Apply to Other Models
To apply FlatQuant to your own models, some modifications are required in the forward pass of the model, particularly within the Attention and MLP modules. You can refer to flatquant/model_tools for our implementations for LLaMA-2, LLaMA-3, LLaMA-3.1, and Qwen-2.5.
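As a rough, hypothetical sketch of the kind of change involved, the wrapper below transforms and quantizes the input of an MLP module inside its forward pass. The FlatQuantMLP, transform, and act_quantizer names are illustrative, not the repository's classes; the actual adaptations in flatquant/model_tools use the project's own transform and quantizer implementations.

```python
import torch.nn as nn

# Hypothetical wrapper, for illustration only: the real per-model adaptations
# live in flatquant/model_tools and use the repository's own classes.
class FlatQuantMLP(nn.Module):
    """Runs an existing MLP on transformed, fake-quantized activations."""

    def __init__(self, mlp, transform, act_quantizer):
        super().__init__()
        self.mlp = mlp                      # original gate/up/down projections
        self.transform = transform          # learnable affine transform for this layer
        self.act_quantizer = act_quantizer  # low-bit activation quantizer (e.g. 4-bit)

    def forward(self, x):
        x = self.transform(x)        # 1) flatten the activation distribution
        x = self.act_quantizer(x)    # 2) quantize the transformed activations
        # 3) the MLP weights are assumed to have been pre-multiplied by the
        #    inverse transform offline, so the overall function is preserved
        return self.mlp(x)
```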
Efficient Kernel
The detailed implementation of our efficient kernel can be found in deploy/kernels/kron_matmul.py and deploy/kernels/block_matmul.py.
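These kernels exploit the Kronecker structure of FlatQuant's fast transforms: multiplying by a transform of the form P1 ⊗ P2 takes two small matmuls instead of one large one. The snippet below is only a plain-PyTorch equivalence check of that identity with illustrative sizes n1 and n2, not the Triton kernel itself.

```python
import torch

# Reference semantics of a Kronecker-structured transform:
# x @ kron(P1, P2) == reshape(P1.T @ X @ P2), where X is x reshaped to (n1, n2).
torch.manual_seed(0)
n1, n2 = 8, 16                       # illustrative split of the hidden size n = n1 * n2
P1, P2 = torch.randn(n1, n1), torch.randn(n2, n2)
x = torch.randn(4, n1 * n2)          # a small batch of activations

# Naive: materialize the full (n1*n2) x (n1*n2) transform.
y_full = x @ torch.kron(P1, P2)

# Kronecker trick: reshape and apply the two small factors.
y_kron = (P1.T @ x.reshape(-1, n1, n2) @ P2).reshape(4, n1 * n2)

print(torch.allclose(y_full, y_kron, atol=1e-4))  # True up to float32 round-off
```

For a hidden size n = n1 * n2, this reduces the per-token cost of the transform from n * n multiplies to n * (n1 + n2), which is what keeps the online transforms cheap at inference time.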
Model Zoo
The model zoo is coming soon.
Results
Accuracy Results
Table 1: WikiText-2 perplexity of 4-bit weight & activation quantized LLaMA models.
Method | W Quantizer | 2-7B | 2-13B | 2-70B | 3-8B | 3-70B |
---|---|---|---|---|---|---|
FP16 | - | 5.47 | 4.88 | 3.32 | 6.14 | 2.86 |
SmoothQuant | RTN | 83.12 | 35.88 | 26.01 | 210.19 | 9.60 |
OmniQuant | RTN | 14.74 | 12.28 | - | - | - |
AffineQuant | RTN | 12.69 | 11.45 | - | - | - |
QuaRot | RTN | 8.56 | 6.10 | 4.14 | 10.60 | 55.44 |
SpinQuant | RTN | 6.14 | 5.44 | 3.82 | 7.96 | 7.58 |
FlatQuant | RTN | 5.79 | 5.12 | 3.55 | 6.98 | 3.78 |
QUIK-4B | GPTQ | 8.87 | 7.78 | 6.91 | - | - |
QuaRot | GPTQ | 6.10 | 5.40 | 3.79 | 8.16 | 6.60 |
SpinQuant | GPTQ | 5.96 | 5.24 | 3.70 | 7.39 | 6.21 |
FlatQuant | GPTQ | 5.78 | 5.11 | 3.54 | 6.90 | 3.77 |
Table 2: C4 perplexity of 4-bit weight & activation quantized LLaMA models.
Method | W Quantizer | 2-7B | 2-13B | 2-70B | 3-8B | 3-70B |
---|---|---|---|---|---|---|
FP16 | - | 7.26 | 6.73 | 5.71 | 9.45 | 7.17 |
SmoothQuant | RTN | 77.27 | 43.19 | 34.61 | 187.93 | 16.90 |
OmniQuant | RTN | 21.40 | 16.24 | - | - | - |
AffineQuant | RTN | 15.76 | 13.97 | - | - | - |
QuaRot | RTN | 11.86 | 8.67 | 6.42 | 17.19 | 79.48 |
SpinQuant | RTN | 9.19 | 8.11 | 6.26 | 13.45 | 15.39 |
FlatQuant | RTN | 7.79 | 7.09 | 5.91 | 11.13 | 7.86 |
QUIK-4B | GPTQ | - | - | - | - | - |
QuaRot | GPTQ | 8.32 | 7.54 | 6.12 | 13.38 | 12.87 |
SpinQuant | GPTQ | 8.28 | 7.48 | 6.07 | 12.19 | 12.82 |
FlatQuant | GPTQ | 7.86 | 7.11 | 5.92 | 11.21 | 7.93 |
Table 3: Zero-shot QA task results of 4-bit weight & activation quantized LLaMA models.
Method | W Quantizer | 2-7B | 2-13B | 2-70B | 3-8B | 3-70B |
---|---|---|---|---|---|---|
FP16 | - | 69.79 | 72.55 | 77.05 | 73.23 | 79.95 |
QuaRot | RTN | 57.73 | 66.25 | 73.47 | 61.34 | 35.36 |
SpinQuant | RTN | 63.52 | 68.56 | 75.09 | 66.98 | 65.66 |
FlatQuant | RTN | 67.96 | 71.42 | 76.62 | 71.23 | 79.01 |
QuaRot | GPTQ | 65.01 | 68.91 | 75.68 | 65.79 | 70.45 |
SpinQuant | GPTQ | 66.23 | 70.93 | 76.06 | 68.70 | 71.66 |
FlatQuant | GPTQ | 67.47 | 71.64 | 76.53 | 71.33 | 78.58 |
Latency Results
Table 4: Prefill speedup of the LLaMA-2-7B model across different batch sizes on a single RTX 3090 GPU. We decode 256 tokens after a prefill with sequence length 2048.
Batch Size | Int4 | QuaRot | FlatQuant |
---|---|---|---|
1 | 2.17 | 1.97 | 2.12 |
2 | 2.21 | 1.99 | 2.16 |
4 | 2.25 | 2.04 | 2.21 |
8 | 2.28 | 2.05 | 2.23 |
16 | 2.32 | 2.08 | 2.27 |
32 | 2.35 | 2.09 | 2.28 |
64 | 2.37 | 2.11 | 2.30 |
Table 5: Decoding speedup of the LLaMA-2-7B model across different batch sizes on a single RTX 3090 GPU. We decode 256 tokens after a prefill with sequence length 2048.
Batch Size | Int4 | QuaRot | FlatQuant |
---|---|---|---|
1 | 0.81 | 0.70 | 0.71 |
2 | 0.78 | 0.66 | 0.69 |
4 | 0.82 | 0.74 | 0.73 |
8 | 0.97 | 0.83 | 0.83 |
16 | 1.18 | 1.01 | 1.05 |
32 | 1.50 | 1.38 | 1.43 |
64 | 1.83 | 1.75 | 1.76 |
Acknowledgements
This project builds on the work of several open-source projects, and we are grateful for their contributions.
References
If you find FlatQuant helpful, please cite our paper:
@article{sun2024flatquant,
  title={FlatQuant: Flatness Matters for LLM Quantization},
  author={Sun, Yuxuan and Liu, Ruikang and Bai, Haoli and Bao, Han and Zhao, Kang and Li, Yuening and Hu, Jiaxin and Yu, Xianzhi and Hou, Lu and Yuan, Chun and others},
  journal={arXiv preprint arXiv:2410.09426},
  year={2024}
}