DeciMamba: Exploring the Length Extrapolation Potential of Mamba

<p align="center">

<a href="https://assafbk.github.io/website/">Assaf Ben-Kish</a>, <a href="https://itamarzimm.github.io/">Itamar Zimerman</a>, <a href="https://scholar.google.com/citations?user=FZYAWe4AAAAJ&hl=en">Shady Abu-Hussein</a>, <a href="https://scholar.google.co.il/citations?user=DmzoCRMAAAAJ&hl=en">Nadav Cohen</a>, <a href="https://scholar.google.com/citations?user=5JserkUAAAAJ&hl=en">Amir Globerson</a>, <a href="https://scholar.google.co.il/citations?user=UbFrXTsAAAAJ&hl=en">Lior Wolf</a>, <a href="https://www.giryes.sites.tau.ac.il/">Raja Giryes</a>

<a href="https://arxiv.org/abs/2406.14528"><img src="https://img.shields.io/badge/arXiv-2406.14528-b31b1b.svg"></a>

We present DeciMamba (Decimating-Mamba), the first context-extension method for Mamba. On synthetic tasks, as well as on real-world long-range NLP tasks, DeciMamba extrapolates to sequences that are orders of magnitude longer than those seen during training, while requiring fewer computational resources and no retraining:

<img src="etc/doc_ret.jpeg" width="90%"/> <br> <img src="etc/niah.png" width="90%"/>

</p> <br>

Release Updates

<br>

Setup

Clone Project

git clone https://github.com/assafbk/DeciMamba.git
cd DeciMamba

Create Environment

To set up our environment, please run:

conda env create -f environment.yml
conda activate decimamba

Install Mamba:

pip install causal-conv1d==1.1.1
pip install mamba-ssm==1.1.1
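
Optionally, you can verify that the Mamba kernels were installed correctly with a minimal sanity check (adapted from the mamba-ssm README; note that mamba-ssm requires a CUDA-capable GPU):

import torch
from mamba_ssm import Mamba

# Toy forward pass on random data; requires a CUDA device.
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(d_model=dim, d_state=16, d_conv=4, expand=2).to("cuda")
y = model(x)
assert y.shape == x.shape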

Additional Requirements - Passkey Retrieval

Fetch and initialize the required submodule via:

git submodule init
git submodule update
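
Equivalently, the two commands above can be combined into one:

git submodule update --init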

Additional Requirements - Language Modeling

To train or evaluate on the Language Modeling task, the PG-19 dataset must first be tokenized. This can be done with the following script:

python ./custom_datasets/tokenize_pg19.py
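
For reference, a minimal sketch of what this tokenization step amounts to; the repository script is authoritative, and the tokenizer choice below is an assumption based on the official Mamba-130m release, which uses the GPT-NeoX-20B tokenizer:

import datasets
from transformers import AutoTokenizer

# Assumption: the GPT-NeoX-20B tokenizer, as used by the official Mamba-130m model.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# Stream PG-19 to avoid downloading the full corpus at once.
pg19 = datasets.load_dataset("pg19", split="train", streaming=True)

for i, book in enumerate(pg19):
    token_ids = tokenizer(book["text"])["input_ids"]
    print(f"book {i}: {len(token_ids)} tokens")
    if i == 2:  # inspect only the first few books
        break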
<br>

Evaluate DeciMamba

We uploaded the weights of the best DeciMamba and Mamba models for each task:

| Eval ID | Task | Model type | Checkpoint |
|---------|------|------------|------------|
| 0 | Document Retrieval | DeciMamba-130m | 🤗 assafbk/decimamba-130m-squad-doc-ret |
| 1 | Document Retrieval | Mamba-130m | 🤗 assafbk/mamba-130m-squad-doc-ret |
| 2 | Passkey Retrieval | DeciMamba-130m | 🤗 assafbk/decimamba-130m-niah |
| 3 | Passkey Retrieval | Mamba-130m | 🤗 assafbk/mamba-130m-niah |
| 4 | Language Modeling | DeciMamba-130m | 🤗 assafbk/decimamba-130m-pg19 |
| 5 | Language Modeling | Mamba-130m | 🤗 assafbk/mamba-130m-pg19 |
| 6 | Passkey Retrieval - Save Data For Mamba Attn Maps | Mamba-130m | 🤗 assafbk/mamba-130m-niah |
<br>

To run the evaluation script:

python finetune_ssm.py --eval <eval_id> --device <device_id>
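
For example, to evaluate the Document Retrieval DeciMamba checkpoint (Eval ID 0 in the table above) on GPU 0, assuming <device_id> is a CUDA device index:

python finetune_ssm.py --eval 0 --device 0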

Arguments:

<br>

Train DeciMamba

To run the training script:

python finetune_ssm.py

All training metrics are logged to the wandb (Weights & Biases) web page.
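
If wandb has not been set up on your machine yet, you will likely need to authenticate first (assuming a standard wandb installation):

wandb login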

The configuration file is ./configs/finetune_ssm_config.json. <br> General configurations:

Decimation configurations:

Note that decimation is automatically disabled when mamba_arch != "deci".

Additional configurations:

Check out ./configs/finetune_ssm_config.json for more configurations.

Train for Document Retrieval

In ./configs/finetune_ssm_config.json set:

Special configurations:

Train for Passkey Retrieval

First, make sure that the additional submodule was cloned (see 'Additional Requirements - Passkey Retrieval' above).

Then, in ./configs/finetune_ssm_config.json set:

Special configurations:

Train for Language Modeling

First, make sure that the PG-19 dataset was tokenized (see 'Additional Requirements - Language Modeling' above).

Next, in ./configs/finetune_ssm_config.json set:

Special configurations:

Notes and Tips:

<br>

Calculate Mamba Hidden Attention Maps

To calculate Mamba's hidden attention maps (this uses the 'Save Data For Mamba Attn Maps' checkpoint, Eval ID 6 in the table above):

python finetune_ssm.py --eval 6 --device <device_id>
<br>

Acknowledgments

We thank the authors of Mamba for their amazing work and contribution: https://github.com/state-spaces/mamba

Additionally, we thank the authors of BABILong, as we use a version of their code for the passkey retrieval task: https://github.com/booydar/babilong

Citation

If you find this work useful, please cite the following:

@misc{benkish2024decimambaexploringlengthextrapolation,
      title={DeciMamba: Exploring the Length Extrapolation Potential of Mamba}, 
      author={Assaf Ben-Kish and Itamar Zimerman and Shady Abu-Hussein and Nadav Cohen and Amir Globerson and Lior Wolf and Raja Giryes},
      year={2024},
      eprint={2406.14528},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2406.14528}, 
}