
mamba-minimal-jax

Simple, minimal implementation of the Mamba SSM in one file of JAX.

Plan:

  1. Finish model.py (done).
  2. Convert the PyTorch weights to JAX weights (done); a conversion sketch follows this list.
  3. Check that greedy generation produces the same results as the PyTorch implementation (done).
  4. Implement an associative scan so that the state update is faster (done in the speedup branch); see the scan sketch after this list and the discussion at https://github.com/srush/annotated-mamba/issues/1.
  5. Pay attention to weight initialization so that the model can be trained from scratch.
  6. Implement the step function for recurrent Mamba inference.
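
Step 2 is mostly a dtype/array conversion. A hypothetical sketch is below; the real parameter names and any reshaping depend on how model.py lays out its weights.

import jax.numpy as jnp

def torch_to_jax(state_dict):
    # Detach each torch tensor, move it to CPU, and wrap it as a JAX array,
    # keeping the original parameter names as dictionary keys.
    return {name: jnp.asarray(tensor.detach().cpu().numpy())
            for name, tensor in state_dict.items()}

For step 4, the key observation is that the diagonal SSM recurrence h_t = a_t * h_{t-1} + b_t is a chain of affine updates, and composing affine updates is associative, so jax.lax.associative_scan can compute every state in parallel over the sequence. A minimal sketch (not the repo's actual code, and with hypothetical shapes):

import jax
import jax.numpy as jnp

def combine(elem_i, elem_j):
    # Compose two affine maps h -> a*h + b, applying elem_i first and elem_j second.
    a_i, b_i = elem_i
    a_j, b_j = elem_j
    return a_j * a_i, a_j * b_i + b_j

def scan_ssm(a, b):
    # a, b: (seq_len, d_state) discretized decay and input terms.
    # Returns h with h[t] = a[t] * h[t-1] + b[t], assuming h[-1] = 0.
    _, h = jax.lax.associative_scan(combine, (a, b), axis=0)
    return h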

From mamba-minimal

Featuring:

Does NOT include:

Demo

See demo.ipynb for examples of prompt completions.

from model import Mamba
from transformers import AutoTokenizer

model = Mamba.from_pretrained('state-spaces/mamba-370m')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

generate(model, tokenizer, 'Mamba is the')

Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)

150 meters... 🫢 scary!
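
The generate call above comes from demo.ipynb. A hedged sketch of a greedy-decoding helper is below; it assumes model(input_ids) returns logits of shape (batch, seq_len, vocab_size), which may differ from the notebook's actual helper.

import jax.numpy as jnp

def generate(model, tokenizer, prompt, n_tokens_to_generate=50):
    input_ids = jnp.asarray(tokenizer(prompt, return_tensors='np').input_ids)
    for _ in range(n_tokens_to_generate):
        logits = model(input_ids)                        # (1, seq_len, vocab_size)
        next_id = jnp.argmax(logits[:, -1, :], axis=-1)  # greedy pick at the last position
        input_ids = jnp.concatenate([input_ids, next_id[:, None]], axis=1)
    return tokenizer.decode([int(t) for t in input_ids[0]])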

References

The Mamba architecture was introduced in Mamba: Linear-Time Sequence Modeling with Selective State Spaces by Albert Gu and Tri Dao.

The official implementation is here: https://github.com/state-spaces/mamba

The minimal PyTorch implementation is here: https://github.com/johnma2006/mamba-minimal