mamba.py 🐍 : a simple and efficient Mamba implementation

A straightforward implementation of Mamba in PyTorch, with a simple parallel scan implementation that offers a major speedup over a sequential implementation, since the parallel scan parallelizes the computation over the time dimension. It combines readability with good training performance. A few other features are implemented as well, such as Jamba, Vision Mamba and muP.

Updates


Overview

speed comparison

This graph shows the training time (forward and backward pass) of a single Mamba layer (d_model=16, d_state=16) using 3 different methods: CUDA, which is the official Mamba implementation; mamba.py, which is this repo; and sequential, which is a sequential (RNN-like) implementation of the selective scan.
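
To give an intuition of why the parallel scan helps, here is a toy, self-contained sketch of the idea (it is not the repo's actual parallel scan code, just an illustration): the selective scan is a first-order linear recurrence h_t = a_t * h_{t-1} + b_t, and two consecutive steps (a1, b1) and (a2, b2) compose into (a2 * a1, a2 * b1 + b2). Because this composition is associative, the whole sequence can be combined in O(log L) parallel steps instead of an O(L) sequential loop.

```python
import torch

def sequential_scan(a, b):
    # reference: h_t = a_t * h_{t-1} + b_t, computed step by step; a, b: (L, D)
    h = torch.zeros_like(b[0])
    out = []
    for t in range(a.shape[0]):
        h = a[t] * h + b[t]
        out.append(h)
    return torch.stack(out)

def parallel_scan(a, b):
    # Hillis-Steele-style inclusive scan over the associative operator
    # (a1, b1) o (a2, b2) = (a2 * a1, a2 * b1 + b2).
    # More total work than the sequential loop, but only O(log L) steps,
    # each of which is parallel over the time dimension.
    a, b = a.clone(), b.clone()
    offset = 1
    while offset < a.shape[0]:
        a_prev, b_prev = a[:-offset], b[:-offset]
        b[offset:] = a[offset:] * b_prev + b[offset:]  # combine with step t - offset
        a[offset:] = a[offset:] * a_prev               # accumulate the decay factors
        offset *= 2
    return b

L, D = 64, 16
a = torch.rand(L, D) * 0.9  # decay factors
b = torch.randn(L, D)       # inputs
assert torch.allclose(sequential_scan(a, b), parallel_scan(a, b), atol=1e-5)
```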

This repo contains a simple and readable code implementing the Mamba architecture in pure PyTorch as well as MLX. You can also play around with the Jamba model, which combines Mamba and attention layers. The primary goal of this repo is educational.

<p align="center"> <img src="assets/logo.png" alt="a python and a mamba" width="300" height="300" alt="python mamba"/> </p>

<u>The repo is organized as follows:</u>

muP is implemented and compatible with both Mamba models (Mamba and Mamba-2; see below for more details).

Usage

You can either download this repo or install it with pip install mambapy.

The most basic usage is the Mamba object (mamba.py), which implements a simple Mamba model given a configuration. There is no embedding and no head: the input is (B, L, D) and the output is (B, L, D) as well.

import torch
from mambapy.mamba import Mamba, MambaConfig

config = MambaConfig(d_model=16, n_layers=2)
model = Mamba(config)

B, L, D = 2, 64, 16
x = torch.randn(B, L, D)
y = model(x)

assert y.shape == x.shape

You can also use Mamba-2 by importing the Mamba2Config and Mamba2 objects from mamba2.py.
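
Assuming the Mamba-2 objects expose a similar interface to the Mamba ones (the exact set of required configuration fields may differ, so check mamba2.py), usage would look like this:

```python
import torch
from mambapy.mamba2 import Mamba2, Mamba2Config

# hypothetical configuration values; check mamba2.py for the exact required fields
config = Mamba2Config(d_model=16, n_layers=2)
model = Mamba2(config)

x = torch.randn(2, 64, 16)  # (B, L, D)
y = model(x)                # (B, L, D)
```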

The class LM (lm.py) builds on the Mamba or Mamba-2 objects and offers a classic API for language models. It can be used as follows:

import torch
from mambapy.lm import LM, MambaConfig

config = MambaConfig(d_model=16, n_layers=4) # core model
model = LM(config, vocab_size=32000) # encapsulate it in a LM

x = torch.randint(high=32000, size=(16, 64)) # (B, L) token ids
logits = model(x) # (B, L, vocab_size)

It simply encapsulates a Mamba(-2) object with an embedding layer, a final normalization and a language modeling head.
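
For reference, here is a minimal sketch of that structure (it is not the actual lm.py code; for instance, the real class may use RMSNorm or tied weights):

```python
import torch
import torch.nn as nn
from mambapy.mamba import Mamba, MambaConfig

class TinyMambaLM(nn.Module):
    # sketch: token embedding -> Mamba core -> final norm -> LM head over the vocab
    def __init__(self, config: MambaConfig, vocab_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, config.d_model)
        self.core = Mamba(config)
        self.norm = nn.LayerNorm(config.d_model)  # the real lm.py may use RMSNorm
        self.lm_head = nn.Linear(config.d_model, vocab_size, bias=False)

    def forward(self, tokens):                    # tokens: (B, L) of token ids
        x = self.embedding(tokens)                # (B, L, d_model)
        x = self.norm(self.core(x))               # (B, L, d_model)
        return self.lm_head(x)                    # (B, L, vocab_size)

model = TinyMambaLM(MambaConfig(d_model=16, n_layers=2), vocab_size=100)
logits = model(torch.randint(high=100, size=(2, 32)))
assert logits.shape == (2, 32, 100)
```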

You can use it off the shelf with a pretrained Mamba model:

from mambapy.lm import from_pretrained
from transformers import AutoTokenizer

model = from_pretrained('state-spaces/mamba-130m').to("cuda")
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

output = model.generate(tokenizer, "Mamba is a type of")

This is the structure of the mamba.py modules:

<p align="center"> <img src="assets/mamba_structure.jpg" width="737" height="429" alt="mamba structure"/> </p>

Jamba

You can also train and run inference on Jamba models. Take a look at the jamba.py file, which defines the Jamba object: it interleaves Mamba layers (from mamba.py) with attention layers.
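
Conceptually, the interleaving looks something like the following hypothetical sketch (the function name and the attention period here are made up for illustration; the real configuration and block definitions live in jamba.py):

```python
import torch.nn as nn
from mambapy.mamba import Mamba, MambaConfig

def interleaved_blocks(d_model: int, n_layers: int, attn_every: int = 4):
    # every `attn_every`-th block is attention, the others are Mamba blocks
    blocks = []
    for i in range(n_layers):
        if (i + 1) % attn_every == 0:
            blocks.append(nn.MultiheadAttention(d_model, num_heads=4, batch_first=True))
        else:
            blocks.append(Mamba(MambaConfig(d_model=d_model, n_layers=1)))
    return nn.ModuleList(blocks)
```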

This is the structure of the modules found in jamba.py:

<p align="center"> <img src="assets/jamba_structure.jpg" width="737" height="429'" alt="mamba structure"/> </p> <p align="center"> <img src="assets/jamba_modules.jpg" width="602" height="343" alt="mamba structure"/> </p>

The API is the same as with the Mamba and LM models. You can load a pretrained Jamba model like so:

from mambapy.jamba_lm import from_pretrained
from transformers import AutoTokenizer

model = from_pretrained('TechxGenus/Mini-Jamba').to("cuda")
tokenizer = AutoTokenizer.from_pretrained('TechxGenus/Mini-Jamba')

output = model.generate(tokenizer, "def min(arr):")

📁 examples

There are two basic examples available (some may be outdated):

If you want a full training example (like in llama2.c), you can check out the othello_mamba repo I've put together. With it, you can train a Mamba or a Jamba model from scratch, use bfloat16, easily swap it for a Transformer, bring your own data, and so on.

muP

muP is a technique that allows transferring hyperparameters (like the learning rate) from small to very large models. For example, it is possible to transfer (i.e., reuse) the learning rate from a 2M-parameter model to a 10B-parameter model. This is extremely useful in practice when doing hyperparameter search: you simply run sweeps to find the best HPs on your small model, which is fast and inexpensive, and you automatically have the best-performing HPs for your large model.

muP makes this possible by initializing the weights of the model and scaling their learning rates in a specific way. This is the result of these modifications:

<p align="center"> <img src="assets/coord_check_mamba2_mup.png" alt="a python and a mamba" width="1200" height="200" alt="python mamba"/> </p>

Without muP, what we get is :

<p align="center"> <img src="assets/coord_check_mamba2_no_mup.png" alt="a python and a mamba" width="1200" height="200" alt="python mamba"/> </p>

What we see here is the scale of the activations for various widths (d_model), from t=1 (initialization) to t=5 (after 5 steps of training). With SP (standard parametrization), the activations of the network vary greatly with width, whereas under muP they stay constant across widths. Intuitively, if the activations (the "signals") of the network behave the same no matter the width, one can easily imagine that the optimal HPs are then independent of the width.
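
If you want to reproduce this kind of coordinate check yourself, here is a rough sketch (a simplified probe, not the repo's actual script; it assumes the Mamba module exposes its blocks as model.layers and that each block returns a tensor, so adapt if that differs):

```python
import torch
from mambapy.mamba import Mamba, MambaConfig

def activation_scales(width, L=64):
    # record the mean absolute value of each block's output for a random input;
    # under muP these numbers should stay roughly flat as the width grows
    model = Mamba(MambaConfig(d_model=width, n_layers=2))
    scales = []
    hooks = [layer.register_forward_hook(
                 lambda _mod, _inp, out: scales.append(out.abs().mean().item()))
             for layer in model.layers]  # assumes the blocks live in model.layers
    model(torch.randn(4, L, width))
    for h in hooks:
        h.remove()
    return scales

for width in (16, 64, 256):
    print(width, activation_scales(width))
```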

And this is what we observe in practice when we sweep for the optimal LR:

<p align="center"> <img src="assets/sweep_mamba2.png" alt="a python and a mamba" width="900" height="340" alt="python mamba"/> </p>

Under SP, the optimal LR shifts as the model gets bigger, whereas with muP it stays roughly constant. The smallest model here has only 172k parameters, while the largest has over 100M!
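
To make the "scaling the learning rates" part more concrete, here is a generic sketch of how such per-parameter learning-rate scaling can be wired up with optimizer parameter groups. It illustrates the muP idea only; the actual scaling used for Mamba is derived in the PR linked below, and the 2D-weight heuristic here is a deliberate simplification.

```python
import torch

def mup_param_groups(model, base_lr, base_width, width):
    # generic muP-style grouping (a sketch, not the repo's exact rules):
    # matrix-like ("hidden") weights get their LR scaled by base_width / width,
    # so the base LR tuned on a small model can be reused on a wider one
    ratio = base_width / width
    hidden, other = [], []
    for _name, p in model.named_parameters():
        (hidden if p.ndim == 2 else other).append(p)  # crude heuristic for the sketch
    return [
        {"params": hidden, "lr": base_lr * ratio},
        {"params": other, "lr": base_lr},  # e.g. biases and gains keep the base LR
    ]

# usage sketch:
# optimizer = torch.optim.AdamW(mup_param_groups(model, base_lr=3e-3, base_width=128, width=1024))
```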

For more information about muP in general, you can take a look at the paper, and to see my derivation of the muP implementation for Mamba, and what it changes concretely in the code, please see the associated PR.


Performances

This section provides a more comprehensive performance comparison between mamba.py and the official Mamba implementation. Overall, as the first graph of this file shows, both have approximately the same asymptotic performance with respect to the sequence length. You can think of mamba.py as a regular Transformer implementation, while the official Mamba implementation is more like FlashAttention v1. Both have their own advantages.

That being said, do the two implementations have the same asymptotic performance with respect to the other parameters?

d_model asymptotic performance
<p align="center"> <img src="assets/training_vs_d_model.png" alt="a python and a mamba" width="800" height="413" alt="python mamba"/> </p>

We can see that both implementations behave the same way as d_model increases. The gap between the two stays roughly constant (mamba.py is roughly 2x slower overall).

d_state asymptotic performance
<p align="center"> <img src="assets/training_vs_d_state.png" alt="a python and a mamba" width="800" height="413" alt="python mamba"/> </p>

This graph is important. We see that here, the asymptotic performance is not the same as we increase d_state. As a reminder, d_state, or $N$ in the paper, is the state expansion factor: each channel of the input is expanded into $N$ channels of the hidden state.

<i>Note: the CUDA version doesn't seem to be affected by the increase of d_state. This is because the benchmark was done with a batch size of 1: the GPU was not at full capacity, so the impact of a larger d_state isn't visible. The same happens with a small model or a small input length. See this issue.</i>

Does it matter in practice? As of now, all the pretrained Mamba models (up to 2.8B parameters) use d_state=16, so this difference in performance over d_state isn't important in that case. And since d_state is not something that is supposed to grow (contrary to the sequence length or d_model), this isn't a catastrophic result, but it is something to keep in mind.

However, it is interesting to relate this observation to the claim made in the Mamba paper by Albert Gu and Tri Dao: <i>The main idea is to leverage properties of modern accelerators (GPUs) to <b>materialize the state ℎ only in more efficient levels of the memory hierarchy.</b></i> They also describe (Appendix D) the main data movements of their selective scan: by working mainly in SRAM, they can reduce the memory reads/writes by a factor of $O(N)$. This explains the different asymptotic behaviors we see here.

With d_state=16 (as in state-spaces/mamba-2.8b-slimpj), the gap between the two is relatively small, but with d_state=64 (currently not used in any models), the gap widens. (note the OOM on the second graph)

<p align="center"> <img src="assets/training_vs_seqlen_d_state_var.png" alt="a python and a mamba" width="1152" height="240" alt="python mamba"/> </p>

All the previous graphs were computed with a batch size of 1, on an A100 80GB. They measure both the forward and backward pass of a single Mamba block.
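
If you want to run a similar measurement yourself, here is a rough sketch of this kind of benchmark (not the actual script used for the graphs): it times the forward and backward pass of a single mamba.py block at batch size 1 for a few sequence lengths.

```python
import time
import torch
from mambapy.mamba import Mamba, MambaConfig

def time_fwd_bwd(model, x, n_iters=20):
    # average wall-clock time of forward + backward over n_iters runs
    if x.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        model(x).sum().backward()
    if x.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Mamba(MambaConfig(d_model=16, n_layers=1)).to(device)
for L in (256, 1024, 4096):
    x = torch.randn(1, L, 16, device=device)
    print(L, f"{time_fwd_bwd(model, x) * 1e3:.2f} ms")
```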

The previous analysis showed the importance of kernel fusion, which reduces the memory accesses by a factor of $O(N)$ and makes the whole process faster.

But memory requirements should also be considered: the official Mamba implementation uses <b>recomputation</b> in the backward pass. Rather than keeping in memory the activations computed during the forward pass, it simply recomputes them in the backward pass when needed. This greatly reduces the memory requirements of the Mamba model during training. This is not implemented in this repo.
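
A coarse-grained version of the same compute-for-memory trade-off can be approximated at the module level with PyTorch's gradient checkpointing. This is not the official kernel's fused recomputation, just a sketch of the general idea:

```python
import torch
from torch.utils.checkpoint import checkpoint
from mambapy.mamba import Mamba, MambaConfig

model = Mamba(MambaConfig(d_model=16, n_layers=2))
x = torch.randn(1, 1024, 16)

# activations inside `model` are not stored during the forward pass;
# they are recomputed when backward() needs them
y = checkpoint(model, x, use_reentrant=False)
y.sum().backward()
```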

Hence, this repo implements one of the three techniques mentioned in the Mamba paper that form the so-called "hardware-aware selective scan": the parallel scan. We saw how kernel fusion impacts speed, while recomputation impacts memory requirements.


Sources and where to learn more

TODOs

performance-related:

Citation

If you find this project useful in your research and wish to cite it, please use the following BibTex entry:

@software{mambapy,
  author = {Alexandre Torres--Leguet},
  title = {mamba.py: A simple, hackable and efficient Mamba implementation in pure PyTorch and MLX.},
  url = {https://github.com/alxndrTL/mamba.py},
  version = {1.0},
  year = {2024},
}