Gated Delta Networks: Improving Mamba2 with Delta Rule

Official PyTorch implementation of Gated Delta Networks: Improving Mamba2 with Delta Rule.

Songlin Yang, Jan Kautz and Ali Hatamizadeh.

For business inquiries, please visit our website and submit the form: NVIDIA Research Licensing

🌟 Why Gated DeltaNet?

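Gated DeltaNet introduces a novel approach to linear transformers by combining two complementary memory mechanisms: the gating of Mamba2, which enables rapid memory erasure, and the delta rule, which makes targeted updates to individual key-value associations.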
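The core of Gated DeltaNet is the gated delta rule, which the paper writes as S_t = S_{t-1} α_t (I − β_t k_t k_tᵀ) + β_t v_t k_tᵀ with read-out o_t = S_t q_t. Here α_t ∈ (0, 1) is the gate that decays the whole memory and β_t ∈ (0, 1) is the delta-rule writing strength. Below is a minimal, naive sketch of that recurrence in plain Python (no PyTorch), under these assumptions: the function names are hypothetical, keys are L2-normalized, and α_t and β_t are passed in as plain scalars instead of being produced by learned projections. The paper trains with a hardware-efficient chunkwise algorithm; this loop processes one token at a time and is meant only to show the math.

```python
from typing import List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]  # memory state S with shape (d_v, d_k)


def matvec(S: Matrix, x: Sequence[float]) -> Vector:
    """Return S @ x for a (d_v, d_k) matrix and a length-d_k vector."""
    return [sum(s_ij * x_j for s_ij, x_j in zip(row, x)) for row in S]


def gated_delta_step(S: Matrix, k: Sequence[float], v: Sequence[float],
                     alpha: float, beta: float) -> Matrix:
    """One update S_t = S_{t-1} * alpha * (I - beta k k^T) + beta v k^T.

    Expanded: S_t = alpha * S_{t-1} + beta * (v - alpha * S_{t-1} k) k^T.
    The gate alpha decays the whole memory; the delta-rule error term then
    moves the value stored under key k toward the new value v.
    """
    pred = matvec(S, k)  # value the memory currently returns for key k
    new_S: Matrix = []
    for i, row in enumerate(S):
        err = v[i] - alpha * pred[i]  # delta-rule error after gating
        new_S.append([alpha * s_ij + beta * err * k_j
                      for s_ij, k_j in zip(row, k)])
    return new_S


def gated_delta_rule(qs: Sequence[Sequence[float]], ks: Sequence[Sequence[float]],
                     vs: Sequence[Sequence[float]], alphas: Sequence[float],
                     betas: Sequence[float], d_k: int, d_v: int
                     ) -> Tuple[List[Vector], Matrix]:
    """Token-by-token recurrence; returns outputs o_t = S_t q_t and the final state."""
    S: Matrix = [[0.0] * d_k for _ in range(d_v)]  # empty memory
    outputs: List[Vector] = []
    for q, k, v, a, b in zip(qs, ks, vs, alphas, betas):
        S = gated_delta_step(S, k, v, a, b)
        outputs.append(matvec(S, q))
    return outputs, S


if __name__ == "__main__":
    key = [1.0, 0.0]  # unit-norm key, written to twice
    outs, _ = gated_delta_rule(
        qs=[key, key], ks=[key, key],
        vs=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        alphas=[1.0, 1.0], betas=[1.0, 1.0], d_k=2, d_v=3)
    print(outs[-1])  # [4.0, 5.0, 6.0]
```

With α = β = 1 and a repeated unit-norm key, the second write replaces the first. A purely additive linear-attention update (S ← S + v kᵀ) would instead return the sum [5.0, 7.0, 9.0]. Setting α_t < 1 additionally decays everything stored before step t, which is the erasure behavior the gate contributes.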

Architecture Overview

Efficiency

Gated DeltaNet achieves strong training throughput compared with models such as Mamba2 and Samba:

<p align="center"> <img src="https://github.com/user-attachments/assets/b5c96369-a998-442b-ad7c-2f9fb6979b44" width=62% height=62% class="center"> </p>

Language Modeling and Reasoning

Our model outperforms competitors of various types (e.g., Transformer, RNN, and hybrid models) in terms of perplexity and zero-shot accuracy on reasoning benchmarks:

<p align="center"> <img src="https://github.com/user-attachments/assets/afaa4527-e974-4367-a784-6e19c21c8bc0" width=82% height=82% class="center"> </p>

Long-context

Gated DeltaNet also achieves favorable perplexity scores on long-context benchmarks:

<p align="center"> <img src="https://github.com/user-attachments/assets/64c307f4-3b30-4899-ab17-6507e6506c51" width=72% height=72% class="center"> </p>

📢 Latest Updates

🚀 Getting Started

Training Your Model

Launch your training with our streamlined command:

python ../pretrain.py \
--train_data_dir ${TRAIN_DATA} \
--val_data_dir ${VALIDATION_DATA} \
--output_root ${SAVE_DIR} \
--exp_name ${NAME} \
--model_name ${MODEL} \
--train_config ${CONFIG} \
--eval_iters ${EVAL_ITERS} \
--learning_rate ${LR} \
--micro_batch_size ${MICRO_BATCH_SIZE}

💡 Pro Tip: Add --interactive_job --debug for interactive debugging sessions!
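The flags above can also be assembled programmatically, which is convenient for small sweeps (for example over the learning rate). This is a minimal sketch, not a tool shipped with the repository: the helper name build_pretrain_cmd, the paths, the config name, the model-name string, and all numeric values are hypothetical placeholders. It only prints the commands (a dry run) so you can inspect them before launching.

```python
import shlex
from typing import Dict, List


def build_pretrain_cmd(args: Dict[str, object], debug: bool = False,
                       script: str = "../pretrain.py") -> List[str]:
    """Turn {flag_name: value} into an argv list for the pretraining script.

    Intended for the flags documented in this README only. When debug=True,
    append --interactive_job --debug for an interactive debugging session.
    """
    cmd = ["python", script]
    for name, value in args.items():
        cmd += [f"--{name}", str(value)]
    if debug:
        cmd += ["--interactive_job", "--debug"]
    return cmd


if __name__ == "__main__":
    base = {
        "train_data_dir": "/path/to/train",  # hypothetical paths
        "val_data_dir": "/path/to/val",
        "output_root": "/path/to/save",
        "model_name": "GatedDeltaNet_H1",    # assumed string; check the repo's model names
        "train_config": "my_config",         # hypothetical config name
        "eval_iters": 15,                    # hypothetical value
        "micro_batch_size": 8,               # hypothetical value
    }
    for lr in (1e-3, 4e-4):  # hypothetical learning-rate sweep
        run_args = {**base, "exp_name": f"gdn_lr{lr}", "learning_rate": lr}
        print(shlex.join(build_pretrain_cmd(run_args)))
        # After checking the output: subprocess.run(build_pretrain_cmd(run_args), check=True)
```

Building an argument list (rather than one long string) avoids shell-quoting problems with paths, and shlex.join is used only to display a copy-pasteable version of each command.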

Please see this Slurm script for training the GatedDeltaNet_H1 model with 0.4B parameters on 15B tokens. Training requires 4 nodes and takes approximately 4 hours. For this run, the validation loss and perplexity curves (evaluated at 1x and 2x the training context length, to measure length extrapolation) are expected to look as follows:

[Figure: expected validation loss and perplexity curves for the GatedDeltaNet_H1 run]
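When comparing your own run against these curves, note that perplexity is the exponential of the mean per-token cross-entropy loss. This sketch assumes the training script reports loss in nats (natural log), as PyTorch's cross-entropy does; the function names and the loss values in the usage example are hypothetical.

```python
import math
from typing import Iterable, List


def perplexity_from_loss(mean_nll: float) -> float:
    """Perplexity = exp(mean per-token negative log-likelihood, in nats)."""
    return math.exp(mean_nll)


def perplexity_from_token_nlls(token_nlls: Iterable[float]) -> float:
    """Average per-token NLLs first, then exponentiate once."""
    nlls: List[float] = list(token_nlls)
    if not nlls:
        raise ValueError("need at least one token")
    return math.exp(sum(nlls) / len(nlls))


if __name__ == "__main__":
    for loss in (3.0, 2.5):  # hypothetical validation losses
        print(f"loss {loss:.2f} -> perplexity {perplexity_from_loss(loss):.2f}")
```

Averaging the per-token losses before exponentiating matches the corpus-level definition. Averaging per-batch perplexities instead would overstate the result, because exp is convex.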

📜 License

Copyright © 2024, NVIDIA Corporation. All rights reserved.

Licensed under the NVIDIA Source Code License-NC. See LICENSE for details.

🙏 Acknowledgements

Built on the shoulders of giants:

⭐ Support Us

If you find this work useful, please consider:

Join us in pushing the boundaries of linear transformers! 🚀

Citation

If you find Gated DeltaNet useful for your work, please consider citing our paper:

@article{yang2024gated,
  title={Gated Delta Networks: Improving Mamba2 with Delta Rule},
  author={Yang, Songlin and Kautz, Jan and Hatamizadeh, Ali},
  journal={arXiv preprint arXiv:2412.06464},
  year={2024}
}
