<p align="center"> <img src="assets/flute-logo.png" alt="" width="40%" align="top" style="border-radius: 10px; padding-left: 120px; padding-right: 120px; background-color: white;"> </p> <p align="center"> <em><strong>FLUTE</strong>: Flexible Lookup Table Engine for LUT-quantized LLMs <br></em> </p> <div align="center">

<a href="https://pypi.org/project/flute-kernel/">Version</a> <a href="https://arxiv.org/abs/2407.10960">arXiv</a>

</div> <div align="center">

[Background] [Benchmarks] [Getting Started] [Compatibility] [Model Zoo]

</div>

Installation

Install FLUTE with pip or from source:

# For CUDA 12.1
pip install flute-kernel
# For CUDA 11.8
pip install flute-kernel -i https://flute-ai.github.io/whl/cu118

Head over to Getting Started and try it out!

Background

Uniform quantization converts full-precision weights to lower-precision intervals of equal size. Lookup table (LUT) quantization is a flexible variant of non-uniform quantization that can map intervals to arbitrary values via a lookup table.

<table align="center"> <tr> <th>Uniform (Integer) Quantization</th> <th>Lookup Table Quantization</th> </tr> <tr> <td align="center">

$$\widehat{\mathbf{W}} = \mathtt{float}(\mathbf{Q}) \cdot \mathbf{s}$$

</td> <td align="center">

$$\widehat{\mathbf{W}} = \mathtt{tableLookup}(\mathbf{Q}, \mathtt{table}) \cdot \mathbf{s}$$

</td> </tr> </table>

where $\mathbf{Q}$ denotes the quantized weight, $\mathbf{s}$ the (group-wise) scales, and $\widehat{\mathbf{W}}$ the de-quantized weight. Here are some examples of the lookup tables supported in FLUTE.

<table align="center"> <tr> <th>Examples</th> <th>Notes</th> </tr> <tr> <td align="left">

int4, int3, int2

</td> <td align="left">

recovers uniform/integer quantization

</td> </tr> <tr> <td align="left">

fp4, fp3, fp2

</td> <td align="left"> </td> </tr> <tr> <td align="left">

nf4, nf3, nf2

</td> <td align="left">

generalizes the nf4 data-format introduced in QLoRA

</td> </tr> <tr> <td align="left">

any arbitrary table

</td> <td align="left">

you could even learn it!

</td> </tr> </table>
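
To make the de-quantization step concrete, here is a minimal PyTorch sketch of group-wise lookup-table de-quantization. This is only a reference implementation of the formula above, not the FLUTE kernel; the layout of `Q` and the grouping axis are assumptions chosen for illustration.

import torch

def lut_dequantize(Q, table, scales, group_size):
    """Reference de-quantization: W_hat = tableLookup(Q, table) * s."""
    # Q:      integer codes, shape (out_features, in_features)
    # table:  lookup table with 2**num_bits entries (e.g., the nf4 values)
    # scales: group-wise scales, shape (out_features, in_features // group_size)
    values = table[Q.long()]                            # map each code to its table value
    values = values.view(Q.shape[0], -1, group_size)    # split columns into groups
    w_hat = values * scales.unsqueeze(-1)               # apply one scale per group
    return w_hat.view(Q.shape)

With a table of evenly spaced (integer) values this recovers uniform/integer quantization; with the nf4 values it recovers QLoRA's data format; and the table can be anything else, including a learned one.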

New Models Powered by FLUTE

The flexibility of the kernel could lead to new quantization algorithms. As a proof of concept, we are releasing a few models quantized using Learned Normal Float (NFL), a simple extension of the nf4 data format introduced in QLoRA. NFL initializes the lookup table and the scales with those from NF quantization, and then uses calibration data to learn the scales via straight-through estimation of the gradient with respect to the scales.
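
As a rough, conceptual illustration of the straight-through trick (this is not the actual NFL training code in flute.integrations.learnable, and the toy table and reconstruction objective below are placeholders): the forward pass quantizes the weights against the table using the current scales, while the backward pass treats the non-differentiable lookup as the identity so that gradients reach the scales.

import torch

def fake_lut_quantize(w, scales, table):
    """Fake-quantize w against a lookup table, with a straight-through
    estimator so that gradients flow to the (learnable) scales."""
    w_scaled = w / scales
    idx = (w_scaled.unsqueeze(-1) - table).abs().argmin(dim=-1)  # nearest entry (non-differentiable)
    q = table[idx]
    q_ste = w_scaled + (q - w_scaled).detach()  # forward: q; backward: identity
    return q_ste * scales

# toy usage: learn one scale for a single weight group
w = torch.randn(64)
table = torch.linspace(-1.0, 1.0, 16)                # stand-in for an nf4-style table
scales = w.abs().max().clone().requires_grad_(True)  # NF-style initialization
optimizer = torch.optim.Adam([scales], lr=1e-2)
for _ in range(100):
    loss = (fake_lut_quantize(w, scales, table) - w).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

NFL itself initializes the table and scales from NF quantization and learns the scales on calibration data; the reconstruction loss above is only a stand-in objective for illustration.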

Benchmarks

For additional benchmarks, detailed breakdowns, and corresponding instruction-tuned models, please refer to the paper and the model zoo.

<p align="center"> <img src="assets/intro-figure.jpg" /> </p>

LLaMA-3.1

|  | Wiki PPL | C4 PPL | LLM Eval Avg. |  | Wiki PPL | C4 PPL | LLM Eval Avg. |
| :--- | :---: | :---: | :---: | :--- | :---: | :---: | :---: |
| LLaMA-3.1 (8B) | 6.31 | 9.60 | 69.75 | LLaMA-3.1 (70B) | 2.82 | 7.18 | 75.45 |
| + NFL W4G64 | 6.24 | 10.06 | 69.13 | + NFL W4G64 | 3.09 | 7.53 | 74.84 |
| + NFL W3G64 | 7.23 | 11.83 | 65.66 | + NFL W3G64 | 4.29 | 8.91 | 72.65 |

Gemma-2

|  | Wiki PPL | C4 PPL | LLM Eval Avg. |  | Wiki PPL | C4 PPL | LLM Eval Avg. |
| :--- | :---: | :---: | :---: | :--- | :---: | :---: | :---: |
| Gemma-2 (9B) | 6.88 | 10.12 | 73.12 | Gemma-2 (27B) | 5.70 | 8.98 | 75.71 |
| + NFL W4G64 | 6.49 | 10.35 | 72.50 | + NFL W4G64 | 5.69 | 9.31 | 74.11 |

Getting Started

FLUTE + vLLM

FLUTE-quantized models (Model Zoo) can be directly served using existing frameworks such as vLLM.

- python -m vllm.entrypoints.openai.api_server \
+ python -m flute.integrations.vllm vllm.entrypoints.openai.api_server \
    --model [MODEL] \
    --revision [REVISION] \
    --tensor-parallel-size [TP_SIZE] \
+   --quantization flute

For example, the following command runs the FLUTE-quantized LLaMA-3.1 (8B) on a single GPU.

python -m flute.integrations.vllm vllm.entrypoints.openai.api_server \
    --model radi-cho/Meta-Llama-3.1-8B-FLUTE \
    --quantization flute

We can then query the vLLM server as usual.

curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "radi-cho/Meta-Llama-3.1-8B-FLUTE",
        "prompt": "San Francisco is a",
        "max_tokens": 7,
        "temperature": 0
    }'
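
The same endpoint can also be queried from Python; below is a small sketch using the requests package (any HTTP client that speaks the OpenAI-compatible API works):

import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "radi-cho/Meta-Llama-3.1-8B-FLUTE",
        "prompt": "San Francisco is a",
        "max_tokens": 7,
        "temperature": 0,
    },
)
print(response.json()["choices"][0]["text"])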

FLUTE + HuggingFace

FLUTE also runs out of the box with HuggingFace and its accelerate extension. This integration is mostly experimental and not optimized; performance-sensitive users should use the vLLM integration instead.

  1. Loading a pre-quantized FLUTE model.
import flute.integrations.huggingface

- model = AutoModelForCausalLM.from_pretrained(
+ model = flute.integrations.huggingface.from_pretrained(
    "radi-cho/Meta-Llama-3.1-8B-FLUTE",
    # all of your favorite HF flags will be forwarded
    device_map="auto")
  2. Loading and quantizing a dense model.
import flute.integrations.base
flute.integrations.base.prepare_model_flute(
    name="model.model.layers",
    module=model.model.layers,  # for LLaMA-3 and Gemma-2
    num_bits=num_bits,
    group_size=group_size,
    fake=False,
    handle_hooks=True)  # for `accelerate` hooks

After this, the model can be used as normal. Please check out the quantization guide for more information.
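
For example, a minimal generation call with the model loaded above might look like this (assuming the quantized repo also hosts the tokenizer; this is the standard transformers API, nothing FLUTE-specific):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("radi-cho/Meta-Llama-3.1-8B-FLUTE")
inputs = tokenizer("San Francisco is a", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=7, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))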

Support and Compatibility

Kernel

| Description | Supported (via pip) | Supported (build from source) |
| :--- | :--- | :--- |
| Input dtypes | torch.float16, torch.bfloat16 |  |
| Bits | 4-bit, 3-bit | 2-bit |
| Group sizes | 32, 64, 128, 256 |  |
| GPUs | A100, A6000, RTX 4090 | H100 (unoptimized) |

[!WARNING] In the current release, we noticed that torch.bfloat16 is slower than torch.float16. This is likely due to a lack of tuning, and to the fact that Ampere GPUs lack hardware acceleration for bfloat16 vectorized atomic-add.

[!WARNING] We noticed several numerically unstable situations when using bits=4, group-size=256, GPU=A100, though these are relatively rare (8 of 9360 test cases failed). We also noticed correctness issues in some situations with bits=4, group-size=256, dtype=bfloat16, GPU=RTX4090 (1 of 52 test cases failed). We will be looking into this, but for now we suggest avoiding these particular configurations (W4G256).

Models

[!NOTE] As of the current release, the kernel is shape-specialized for legacy reasons (i.e., we tune tile sizes, etc., for each matrix shape). Please see the chart below for the supported use cases, as different platforms and tensor-parallel sizes change the matrix shapes. We plan to add support for a broader range of shapes in the near future. In the meantime, please let us know if you have any specific models in mind and we will be happy to add support for them.

| Model | Single GPU / Pipeline Parallel | Tensor Parallel |
| :--- | :---: | :--- |
| LLaMA-3/3.1 (8B) | ✅ |  |
| LLaMA-3/3.1 (70B) | ✅ | 2 or 4 GPUs |
| LLaMA-3.1 (405B) | ✅ | 4 or 8 GPUs |
| Gemma-2 (9B) | ✅ |  |
| Gemma-2 (27B) | ✅ | 2 or 4 GPUs |

Model Zoo

[!NOTE] The models we release here are trained on more data and hence different from those in the paper.

[!TIP] The HuggingFace Hub links are for NFL W4G64 quantization by default. To use the NFL W3G64 quantization, add --revision nfl_w3g64.
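
For instance, through the HuggingFace integration the revision is just another forwarded flag (a sketch, assuming the branch name from the tip above):

import flute.integrations.huggingface

model = flute.integrations.huggingface.from_pretrained(
    "radi-cho/Meta-Llama-3.1-8B-FLUTE",
    revision="nfl_w3g64",  # the default branch hosts the NFL W4G64 weights
    device_map="auto")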

LLaMA-3.1 (8B)

|  | Wiki | C4 | PIQA | ARC-E | ARC-C | HellaSwag | Wino | Avg. |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Unquantized | 6.31 | 9.60 | 79.16 | 82.20 | 52.65 | 60.71 | 74.03 | 69.75 |
| NFL W4G64 | 6.24 | 10.06 | 79.38 | 81.61 | 51.54 | 59.57 | 73.56 | 69.13 |
| NFL W3G64 | 7.23 | 11.83 | 77.91 | 76.98 | 46.33 | 56.74 | 70.32 | 65.66 |

LLaMA-3.1 (70B)

|  | Wiki | C4 | PIQA | ARC-E | ARC-C | HellaSwag | Wino | Avg. |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Unquantized | 2.82 | 7.18 | 82.81 | 85.31 | 59.64 | 67.49 | 82.00 | 75.45 |
| NFL W4G64 | 3.09 | 7.53 | 83.03 | 85.52 | 58.19 | 67.04 | 80.43 | 74.84 |
| NFL W3G64 | 4.29 | 8.91 | 82.04 | 83.29 | 54.78 | 64.99 | 78.14 | 72.65 |

LLaMA-3.1 (405B)

Note that the weights are in the branch nf_w4g64 and thus --revision nf_w4g64 is needed since these are not on the default branch.

LLaMA-3.1 Instruct (8B)

|  | Wiki | C4 |
| :--- | :---: | :---: |
| NFL W4G64 | 6.78 | 11.11 |
| NFL W3G64 | 7.73 | 12.83 |

LLaMA-3.1 Instruct (70B)

|  | Wiki | C4 |
| :--- | :---: | :---: |
| NFL W4G64 | 4.15 | 9.18 |
| NFL W3G64 | 4.74 | 9.48 |

LLaMA-3.1 Instruct (405B)

Note that the weights are in the branch nf_w4g64 and thus --revision nf_w4g64 is needed since these are not on the default branch.

LLaMA-3 (8B)

|  | Wiki | C4 | PIQA | ARC-E | ARC-C | HellaSwag | Wino | Avg. |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Unquantized | 6.1 | 9.2 | 79.9 | 80.1 | 50.4 | 60.2 | 72.8 | 68.6 |
| NFL W4G64 | 6.11 | 9.38 | 79.33 | 79.79 | 49.74 | 59.22 | 73.95 | 68.41 |
| NFL W3G64 | 7.13 | 11.06 | 78.78 | 76.22 | 44.37 | 56.69 | 70.32 | 65.28 |

LLaMA-3 (70B)

|  | Wiki | C4 | PIQA | ARC-E | ARC-C | HellaSwag | Wino | Avg. |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Unquantized | 2.9 | 6.9 | 82.4 | 86.9 | 60.3 | 66.4 | 80.6 | 75.3 |
| NFL W4G64 | 3.03 | 7.03 | 82.15 | 85.98 | 57.85 | 66.17 | 79.79 | 74.39 |
| NFL W3G64 | 4.15 | 8.10 | 80.74 | 83.71 | 55.29 | 64.05 | 78.45 | 72.45 |

LLaMA-3 Instruct (8B)

|  | Wiki | C4 |
| :--- | :---: | :---: |
| NFL W4G64 | 6.78 | 10.61 |
| NFL W3G64 | 7.75 | 12.28 |

LLaMA-3 Instruct (70B)

|  | Wiki | C4 |
| :--- | :---: | :---: |
| NFL W4G64 | 3.67 | 7.95 |
| NFL W3G64 | 4.90 | 10.86 |

Gemma-2 (9B)

|  | Wiki | C4 | PIQA | ARC-E | ARC-C | HellaSwag | Wino | Avg. |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Unquantized | 6.88 | 10.12 | 81.39 | 87.37 | 61.35 | 61.23 | 74.27 | 73.12 |
| NFL W4G64 | 6.49 | 10.35 | 81.28 | 86.24 | 59.30 | 60.40 | 75.30 | 72.50 |
| NFL W3G64 | 7.06 | 11.14 | 80.52 | 83.16 | 55.46 | 58.28 | 72.69 | 70.02 |

Gemma-2 (27B)

|  | Wiki | C4 | PIQA | ARC-E | ARC-C | HellaSwag | Wino | Avg. |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Unquantized | 5.70 | 8.98 | 83.24 | 87.84 | 62.88 | 65.35 | 79.24 | 75.71 |
| NFL W4G64 | 5.69 | 9.31 | 82.53 | 86.45 | 59.22 | 64.13 | 78.21 | 74.11 |

Gemma-2 Instruct (9B)

|  | Wiki | C4 |
| :--- | :---: | :---: |
| NFL W4G64 | 6.88 | 11.02 |
| NFL W3G64 | 7.35 | 11.72 |

Gemma-2 Instruct (27B)

|  | Wiki | C4 |
| :--- | :---: | :---: |
| NFL W4G64 | 5.91 | 9.71 |

Quantizing Your Own Models

We provide two APIs for quantizing custom models. The easiest way is to use the command-line interface.

Simple Normal Float Quantization

python -m flute.integrations.base \
    --pretrained_model_name_or_path meta-llama/Meta-Llama-3-70B-Instruct \
    --save_directory Meta-Llama-3-70B-Instruct-NF4 \
    --num_bits 4 \
    --group_size 128

The CLI essentially wraps around the following Python API,

from transformers import (
    LlamaForCausalLM,
    Gemma2ForCausalLM,
    AutoModelForCausalLM)
import flute.integrations.base

model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    device_map="cpu",
    torch_dtype="auto")

if isinstance(model, (LlamaForCausalLM, Gemma2ForCausalLM)):
    flute.integrations.base.prepare_model_flute(
        name="model.model.layers",
        module=model.model.layers,
        num_bits=num_bits,
        group_size=group_size,
        fake=False)
else:
    # more models to come
    raise NotImplementedError

Converting bitsandbytes Model into FLUTE Model

While FLUTE has its own Normal Float (NF) implementation, you can also convert an existing HuggingFace model quantized via bitsandbytes into the FLUTE format. To do so, add two arguments to the prepare_model_flute call:

flute.integrations.base.prepare_model_flute(
    name="model.model.layers",
    module=model.model.layers,
    num_bits=num_bits,
    group_size=group_size,
    fake=False,
+   prepare_bnb_layers=True,
+   default_bnb_dtype=torch.float16,
)

It's worth noting that we do not support double quantization, and the conversion will materialize the first-level scales.
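
Putting the pieces together, an end-to-end conversion might look like the following sketch (the bitsandbytes flags below are the standard transformers 4-bit loading path and are assumptions for illustration; the bit-width and group size must match the bnb configuration):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import flute.integrations.base

# load an existing bitsandbytes NF4 model (no double quantization, per the note above)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False),
    torch_dtype="auto")

# convert the bnb layers into FLUTE format in place
flute.integrations.base.prepare_model_flute(
    name="model.model.layers",
    module=model.model.layers,
    num_bits=4,
    group_size=64,  # bitsandbytes NF4 uses a block size of 64
    fake=False,
    prepare_bnb_layers=True,
    default_bnb_dtype=torch.float16)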

Learned Normal Float Quantization (NFL)

NFL initializes the lookup table and the scales with those from NF quantization, and then uses calibration data to learn the scales via straight-through estimation of the gradient with respect to the scales.

To use NFL quantization, call the following function before prepare_model_flute. We also provide an example Jupyter notebook that illustrates the entire process.

import flute.integrations.learnable

flute.integrations.learnable.learn_scales(
    model=model,
    tokenizer=tokenizer,
    num_bits=num_bits,
    group_size=group_size,
    custom_corpora=list_of_corpora,
    samples=num_samples,
)
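
For context, here is a sketch of how this call fits into the overall flow (the corpora and sample-count variables are the same placeholders used above; see the example notebook for how to construct them):

from transformers import AutoModelForCausalLM, AutoTokenizer
import flute.integrations.base
import flute.integrations.learnable

model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

# 1. learn the scales on calibration data
flute.integrations.learnable.learn_scales(
    model=model,
    tokenizer=tokenizer,
    num_bits=num_bits,
    group_size=group_size,
    custom_corpora=list_of_corpora,
    samples=num_samples)

# 2. then pack the weights into FLUTE format
flute.integrations.base.prepare_model_flute(
    name="model.model.layers",
    module=model.model.layers,
    num_bits=num_bits,
    group_size=group_size,
    fake=False)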

Build From Source

  1. Clone the CUTLASS library.
# Unfortunately, the path is hard-coded as of now. If you install CUTLASS
# in a different directory, please make sure the corresponding path in
# `setup.py` is updated.
cd /workspace

git clone https://github.com/NVIDIA/cutlass.git
cd cutlass
git checkout v3.4.1
  2. Build.
git clone https://github.com/HanGuo97/flute
cd flute
pip install -e .

Note: the build process requires that the local CUDA version (nvcc --version) match the CUDA version used by PyTorch. If the build throws an error related to a CUDA version mismatch, try adding --no-build-isolation to the pip command.

Acknowledgement and Citation

Special thanks to Dmytro Ivchenko, Yijie Bei, and the Fireworks AI team for helpful discussion. If you find any of the models or code in this repo useful, please feel free to cite:

@article{flute2024,
  title={Fast Matrix Multiplications for Lookup Table-Quantized LLMs},
  author={Guo, Han and Brandon, William and Cholakov, Radostin and Ragan-Kelley, Jonathan and Xing, Eric P and Kim, Yoon},
  journal={arXiv preprint arXiv:2407.10960},
  year={2024}
}