# MambaInLlama

This repository contains the code and released models for our paper.

<img src="assets/mambainllama.png" alt="MambaInLlama" style="width:50%;">

Our goal is to distill a large Transformer into a (hybrid) Mamba model while preserving generation quality as much as possible. Reproducing our results typically requires only 8x80GB A100 GPUs and 3 to 4 days of training, so the approach is feasible even with limited compute. It applies to both base models and chat models.

## Changelog

## Released Models

### Hybrid Mamba distilled from Zephyr

| Teacher Model | Hybrid Mamba Model - SFT | Hybrid Mamba Model - DPO 1 | Hybrid Mamba Model - DPO 3 |
|---|---|---|---|
| Zephyr | Mamba (1/2 attention) | Mamba (1/2 attention) | Mamba (1/2 attention) |
| | Mamba (1/4 attention) | Mamba (1/4 attention) | Mamba (1/4 attention) |
| | Mamba (1/8 attention) | Mamba (1/8 attention) | Mamba (1/8 attention) |

| Model | MMLU <br> (5 shots) | AlpacaEval <br> (LC win against GPT-4) | MT-Bench <br> (scored by GPT-4) |
|---|---|---|---|
| Zephyr | 61.44 | 13.20 | 7.34 |
| Mamba DPO 1 (1/2 attention) | 55.23 | 20.66 | 7.12 |
| Mamba DPO 3 (1/2 attention) | 55.38 | 17.48 | 7.31 |
| Mamba DPO 1 (1/4 attention) | 50.94 | 17.16 | 7.03 |
| Mamba DPO 3 (1/4 attention) | 51.19 | 13.89 | 6.58 |
| Mamba DPO 1 (1/8 attention) | 48.35 | 15.32 | 6.39 |
| Mamba DPO 3 (1/8 attention) | 48.44 | 12.67 | 6.37 |

For reproduction, please follow the instructions here.

### Hybrid Mamba distilled from Llama3

| Teacher Model | Hybrid Mamba Model - DPO | Hybrid Mamba2 Model - DPO |
|---|---|---|
| Meta-Llama-3-8B-Instruct | Mamba (1/2 attention) | Mamba2 (1/2 attention) |
| | Mamba (1/4 attention) | Mamba2 (1/4 attention) |
| | Mamba (1/8 attention) | Mamba2 (1/8 attention) |
| | | Mamba2 (0 attention) |

| Model | MMLU <br> (5 shots) | AlpacaEval <br> (LC win against GPT-4) | MT-Bench <br> (scored by GPT-4) |
|---|---|---|---|
| Mamba (1/2 attention) | 59.26 | 29.61 | 7.35 |
| Mamba2 (1/2 attention) | 56.67 | 25.00 | 7.32 |
| Mamba (1/4 attention) | 52.68 | 25.85 | 6.86 |
| Mamba2 (1/4 attention) | 53.94 | 20.25 | 6.74 |
| Mamba (1/8 attention) | 49.20 | 20.76 | 6.46 |
| Mamba2 (1/8 attention) | 50.85 | 20.25 | 6.48 |
| Mamba2 (0 attention) | 43.19 | 14.49 | 5.64 |

For reproduction, please follow the instructions here.

## Usage

### Environment

We provide an environment file listing the exact Python package versions used in our experiments. For best reproducibility, we recommend using these versions, although the code should also run with other recent versions. The alignment-handbook version we use is here. The following script installs CUDA 11.8.0 and the dependencies for `mamba-ssm==2.2.2`.

```bash
# CUDA >= 11.6 is needed for `mamba-ssm` and `causal-conv1d`.
conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
# Install PyTorch (with CUDA 11.8) before everything else; the wheels below assume cu118.
pip install torch --index-url https://download.pytorch.org/whl/cu118

pip install causal-conv1d==1.4.0
pip install flash-attn==2.6.3
```
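
After these steps, a quick sanity check (our own snippet, not a script shipped with this repo) confirms that the CUDA build of PyTorch and the pinned packages installed cleanly:

```python
# Environment sanity check (illustrative; not part of this repo).
from importlib.metadata import version

import torch

print("torch:", torch.__version__, "| CUDA:", torch.version.cuda,
      "| GPU available:", torch.cuda.is_available())
for pkg in ("causal-conv1d", "flash-attn"):
    print(pkg, version(pkg))
```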

If you install mamba-ssm via `pip install mamba-ssm==2.2.2`, you will need to manually replace `CONDA_ENV_PATH/site-packages/mamba_ssm/modules/mha.py` with this version to support GQA, since Llama3 uses GQA. The mamba-ssm used in our experiments is from this commit.

Alternatively, you can build mamba-ssm from source, but make sure the commit is after this one, which fixes the GQA bugs in generation.
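
For context, GQA (grouped-query attention) shares each key/value head across a group of query heads, so the attention module must expand the KV heads before computing attention. A minimal sketch of that expansion step (our illustration of the idea, not code from mamba-ssm):

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_kv_heads, seqlen, head_dim) so that each
    key/value head serves n_rep query heads, as GQA requires."""
    if n_rep == 1:
        return x
    b, kv_heads, s, d = x.shape
    x = x[:, :, None, :, :].expand(b, kv_heads, n_rep, s, d)
    return x.reshape(b, kv_heads * n_rep, s, d)

# Llama3-8B has 32 query heads and 8 KV heads, so n_rep = 4.
k = torch.randn(1, 8, 16, 128)
print(repeat_kv(k, 4).shape)  # torch.Size([1, 32, 16, 128])
```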

### Generation Example

Mamba:

```python
import torch
from transformers import AutoTokenizer
from mamba_inference.hybrid_wrapper import MambaTransformerHybridModelWrapper

pretrained_model_name = "JunxiongWang/MambaInLlama_0_50" # change the model that you want to test here
model = MambaTransformerHybridModelWrapper.from_pretrained(pretrained_model_name, torch_dtype=torch.bfloat16)
model.eval()

messages = [[
    {
        "role": "user",
        "content": "## Question: \n\nFarmer Brown has 20 animals on his farm, all either chickens or cows. They have a total of 70 legs, all together. How many of the animals are chickens? ## Instruction \n\nPlease answer this question by first reasoning and then providing your answer.",
    },
]]

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
formatted_prompts = [
    tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in messages
]

prompts = [
    tokenizer.encode(formatted_prompt, return_tensors="pt", truncation=True, max_length=200)
    for formatted_prompt in formatted_prompts
]
batch_prompts = torch.cat(prompts, dim=0).cuda()

outputs = model.generate(
    input_ids=batch_prompts,
    max_length=1000,
    cg=True,
    return_dict_in_generate=True,
    output_scores=True,
    enable_timing=True,
    top_k=1,
    eos_token_id=tokenizer.eos_token_id
)

generated_text = tokenizer.batch_decode(outputs.sequences.tolist())
print(generated_text[0])

# Output:
# Let's denote the number of chickens as C and the number of cows as K. We know that:
# 1. There are 20 animals in total: C + K = 20.
# 2. Chickens have 2 legs and cows have 4 legs. Together, they have 70 legs: 2C + 4K = 70.
# Now, we can solve these equations step by step:
# From the first equation, we can express K in terms of C:
# K = 20 - C
# Substitute this expression into the second equation:
# 2C + 4(20 - C) = 70
# Simplify and solve for C:
# 2C + 80 - 4C = 70
# -2C = -10
# C = 5
# So, there are 5 chickens on Farmer Brown's farm.
```
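
Because `outputs.sequences` includes the prompt tokens, you may want to decode only the completion; a small convenience sketch (ours, not part of the repo):

```python
# Decode only the newly generated tokens (illustrative snippet).
prompt_len = batch_prompts.shape[1]
completions = tokenizer.batch_decode(
    outputs.sequences[:, prompt_len:].tolist(), skip_special_tokens=True
)
print(completions[0])
```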

Mamba2:

```python
import torch
from transformers import AutoTokenizer
from mamba2_inference.hybrid_wrapper import MambaTransformerHybridModelWrapper

pretrained_model_name = "JunxiongWang/Mamba2InLlama_0_50" # change the model that you want to test here
model = MambaTransformerHybridModelWrapper.from_pretrained(pretrained_model_name, torch_dtype=torch.bfloat16)
model.eval()

messages = [[
    {
        "role": "user",
        "content": "## Question: \n\nFarmer Brown has 20 animals on his farm, all either chickens or cows. They have a total of 70 legs, all together. How many of the animals are chickens? ## Instruction \n\nPlease answer this question by first reasoning and then providing your answer.",
    },
]]

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
formatted_prompts = [
    tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in messages
]

prompts = [
    tokenizer.encode(formatted_prompt, return_tensors="pt", truncation=True, max_length=200)
    for formatted_prompt in formatted_prompts
]
batch_prompts = torch.cat(prompts, dim=0).cuda()

outputs = model.generate(
    input_ids=batch_prompts,
    max_length=1000,
    cg=True,
    return_dict_in_generate=True,
    output_scores=True,
    enable_timing=True,
    top_k=1,
    eos_token_id=tokenizer.eos_token_id
)

generated_text = tokenizer.batch_decode(outputs.sequences.tolist())
print(generated_text[0])

# Output:
# **Step 1:** Let's denote the number of chickens as C and the number of cows as K.
# **Step 2:** Chickens have 2 legs each, and cows have 4 legs each. We can create an equation based on the total number of legs:
# 2C (chicken legs) + 4K (cow legs) = 70 legs
# **Step 3:** We also know that there are 20 animals in total:
# C + K = 20
# **Step 4:** Solve the second equation for one variable, let's say K:
# K = 20 - C
# **Step 5:** Substitute the expression for K in the first equation:
# 2C + 4(20 - C) = 70
# **Step 6:** Simplify and solve for C:
# 2C + 80 - 4C = 70
# -2C = -10
# C = 5
# **Conclusion:** There are 5 chickens on Farmer Brown's farm.
```
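
In both examples, `top_k=1` makes decoding greedy. The generation utilities in the mamba-ssm 2.2.x line also accept `temperature` and `top_p`; assuming the hybrid wrapper forwards these keyword arguments (an assumption worth verifying against your installed version), sampled decoding would look like:

```python
# Sampled decoding instead of greedy (assumes the wrapper forwards these
# kwargs to mamba-ssm's decoding loop; verify against your installed version).
outputs = model.generate(
    input_ids=batch_prompts,
    max_length=1000,
    cg=True,
    return_dict_in_generate=True,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    eos_token_id=tokenizer.eos_token_id,
)
```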

## Evaluation

Please follow the instructions here.

## Citation

If you use this codebase, or otherwise find our work valuable, please cite:

```bibtex
@article{junxiongdaniele2024mambainllama,
  title   = {The Mamba in the Llama: Distilling and Accelerating Hybrid Models},
  author  = {Junxiong Wang and Daniele Paliotta and Avner May and Alexander M. Rush and Tri Dao},
  journal = {arXiv preprint arXiv:2408.15237},
  year    = {2024}
}
```