Speculative Decoding with Mamba

This repository provides a Python implementation of speculative decoding for Mamba models. Speculative decoding accelerates autoregressive generation by using a small "draft" model to propose tokens and a larger "target" model to verify them in parallel, reducing wall-clock generation latency.
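The draft-then-verify loop can be sketched as below, using greedy decoding and toy stand-in "models" (plain Python functions, not the Mamba models used by this repo) so the control flow is easy to follow:

```python
def draft_model(tokens):
    # Toy cheap model: usually predicts the next integer,
    # but errs on multiples of 3.
    nxt = tokens[-1] + 1
    return nxt + 1 if nxt % 3 == 0 else nxt

def target_model(tokens):
    # Toy "ground truth" model: always predicts the next integer.
    return tokens[-1] + 1

def speculative_decode(prompt, n_tokens, K):
    """Greedy speculative decoding: the draft proposes K tokens per step;
    the target keeps the longest prefix it agrees with, plus one free token."""
    tokens = list(prompt)
    accepted = proposed = 0
    while len(tokens) - len(prompt) < n_tokens:
        # 1) Draft K tokens autoregressively with the cheap model.
        ctx, draft = list(tokens), []
        for _ in range(K):
            ctx.append(draft_model(ctx))
            draft.append(ctx[-1])
        # 2) Verify with the target (a single batched forward pass in practice).
        n_ok = 0
        while n_ok < K and target_model(tokens + draft[:n_ok]) == draft[n_ok]:
            n_ok += 1
        proposed += K
        accepted += n_ok
        tokens.extend(draft[:n_ok])
        # 3) The verification pass also yields one target token for free.
        tokens.append(target_model(tokens))
    return tokens[len(prompt):len(prompt) + n_tokens], accepted / proposed
```

With these toy models, `speculative_decode([0], 8, 3)` emits the tokens 1 through 8 with an acceptance rate of 2/3: the target model runs far fewer sequential steps than the eight it would need on its own.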

This code accompanies the NeurIPS 2024 paper The Mamba in the Llama: Distilling and Accelerating Hybrid Models.


Features


Getting Started

Installation

Follow the steps below to set up the environment and install dependencies:

# Create a conda environment
conda create --name specmamba python=3.11
conda activate specmamba

# Install PyTorch with CUDA support
conda install pytorch==2.2.1 pytorch-cuda=12.1 -c pytorch -c nvidia

# Install required Python packages
pip install causal_conv1d==1.4.0
pip install transformers
pip install flash_attn

# Install the repository
pip install -e .

Usage

You can run the decoding script by specifying the prompt and other generation parameters:

python speculative_mamba/run.py \
    --prompt "Italy is a country" \
    --n_tokens_to_generate 64 \
    --K 3 \
    --model_target state-spaces/mamba-2.8b \
    --model_draft state-spaces/mamba-130m \
    --dtype float16 \
    --top_k 50 \
    --top_p 0.8 \
    --temperature 0.8 \
    --cg

Parameters

The flags in the command above are:

- `--prompt`: input prompt to start generation from
- `--n_tokens_to_generate`: number of new tokens to generate
- `--K`: number of tokens the draft model proposes per verification step
- `--model_target`: Hugging Face name of the larger target model
- `--model_draft`: Hugging Face name of the smaller draft model
- `--dtype`: compute dtype (e.g. `float16`)
- `--top_k`, `--top_p`, `--temperature`: sampling parameters
- `--cg`: enable CUDA graph capture for faster decoding

Example

Run the script with the default settings:

python speculative_mamba/run.py

Output:

Decoding...
Prompt processing + decoding time: 4364ms
Acceptance rate: 68.25%
Italy is a country that has always had an important role in international affairs, both in the economic and in the political sphere.

But in the last years, the country has been going through a period of great political instability.

In the last year, the country has had three different Prime Ministers: Mario Monti, En

Note: your output and acceptance rate will vary.
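As a rough rule of thumb, the standard speculative decoding analysis says that with per-token acceptance rate α and K drafted tokens per step, each target forward pass emits an expected (1 − α^(K+1)) / (1 − α) tokens. This assumes acceptances are independent; it is a back-of-the-envelope estimate, not a quantity the script reports:

```python
def expected_tokens_per_step(alpha, K):
    """Expected tokens emitted per target forward pass, assuming each of
    the K drafted tokens is accepted independently with probability alpha."""
    if alpha == 1.0:
        return K + 1  # limit of the geometric series
    return (1 - alpha ** (K + 1)) / (1 - alpha)

# For the run above (acceptance rate 68.25%, K=3), roughly 2.5 tokens
# come out of every target forward pass.
print(expected_tokens_per_step(0.6825, 3))
```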


TODO


Citation

If you use this repository, please cite:

@article{junxiongdaniele2024mambainllama,
  title   = {The Mamba in the Llama: Distilling and Accelerating Hybrid Models},
  author  = {Junxiong Wang and Daniele Paliotta and Avner May and Alexander M. Rush and Tri Dao},
  journal = {arXiv preprint arXiv:2408.15237},
  year    = {2024}
}