Gated Delta Networks: Improving Mamba2 with Delta Rule
Official PyTorch implementation of Gated Delta Networks: Improving Mamba2 with Delta Rule.
Songlin Yang, Jan Kautz and Ali Hatamizadeh.
For business inquiries, please visit our website and submit the form: NVIDIA Research Licensing
🌟 Why Gated DeltaNet?
Gated DeltaNet introduces a novel approach to linear transformers by combining the following (see the recurrence sketch after this list):
- 🧠 Smart Memory Management: a gating mechanism that decides what to keep and what to forget
- ⚡ Precise Updates: Targeted memory updates that enhance model efficiency
- 💻 Hardware Efficiency: Optimized implementation for real-world deployment
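For intuition, here is a minimal, unoptimized sketch of a single recurrent step of the gated delta rule in plain PyTorch: a scalar gate `alpha` decays the whole memory (the forgetting), while a delta-rule write with strength `beta` replaces only the value associated with the current key (the precise update). The function name, shapes, and toy usage below are illustrative assumptions, not the repository's API; the actual training code uses chunkwise, hardware-efficient kernels.

```python
import torch

def gated_delta_step(S, q, k, v, alpha, beta):
    """One recurrent step of the gated delta rule (illustrative sketch only).

    S      : (d_v, d_k) associative memory, read out as o = S @ q
    q, k   : (d_k,) query / key vectors (key assumed normalized)
    v      : (d_v,) value vector
    alpha  : scalar gate in (0, 1) -- decays the whole memory (what to forget)
    beta   : scalar in (0, 1)      -- strength of the targeted delta update
    """
    # Decay old memory, erase the value currently bound to k, then write the new value.
    S = alpha * (S - beta * torch.outer(S @ k, k)) + beta * torch.outer(v, k)
    # Read the updated memory with the current query.
    return S, S @ q

# Toy usage with random tensors (shapes only, not the trained model).
d_k, d_v = 4, 8
S = torch.zeros(d_v, d_k)
q = torch.randn(d_k)
k = torch.nn.functional.normalize(torch.randn(d_k), dim=0)
v = torch.randn(d_v)
S, o = gated_delta_step(S, q, k, v, alpha=0.9, beta=0.5)
print(o.shape)  # torch.Size([8])
```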
Efficiency
Gated DeltaNet achieves higher training throughput than comparable models such as Mamba2 and Samba:
<p align="center"> <img src="https://github.com/user-attachments/assets/b5c96369-a998-442b-ad7c-2f9fb6979b44" width=62% height=62% class="center"> </p>

Language Modeling and Reasoning
Our model outperforms competitors of various types (e.g., Transformer, RNN, and hybrid) in terms of perplexity and zero-shot accuracy on reasoning benchmarks:
<p align="center"> <img src="https://github.com/user-attachments/assets/afaa4527-e974-4367-a784-6e19c21c8bc0" width=82% height=82% class="center"> </p>

Long-context
Gated DeltaNet also achieves favorable perplexity scores on long-context benchmarks:
<p align="center"> <img src="https://github.com/user-attachments/assets/64c307f4-3b30-4899-ab17-6507e6506c51" width=72% height=72% class="center"> </p>

📢 Latest Updates
12/09/2024: 🔥 Code Release: Train your own Gated DeltaNet on the SlimPajama dataset. Watch this space for more exciting updates!
🚀 Getting Started
Training Your Model
Launch your training with our streamlined command:
python ../pretrain.py \
--train_data_dir ${TRAIN_DATA} \
--val_data_dir ${VALIDATION_DATA} \
--output_root ${SAVE_DIR} \
--exp_name ${NAME} \
--model_name ${MODEL} \
--train_config ${CONFIG} \
--eval_iters ${EVAL_ITERS} \
--learning_rate ${LR} \
--micro_batch_size ${MICRO_BATCH_SIZE}
💡 Pro Tip: Add --interactive_job --debug for interactive debugging sessions!
Please see this slurm script for training the GatedDeltaNet_H1 model with 0.4B parameters on 15B tokens. The training requires 4 nodes and finishes in approximately 4 hours. For this run, the validation loss and perplexity curves (1x & 2x for length extrapolation) are expected as follows:
📜 License
Copyright © 2024, NVIDIA Corporation. All rights reserved.
Licensed under the NVIDIA Source Code License-NC. See LICENSE for details.
🙏 Acknowledgements
Built on the shoulders of giants.
⭐ Support Us
If you find this work useful, please consider:
- Starring the repository
- Citing our paper
- Contributing to the codebase
Join us in pushing the boundaries of linear transformers! 🚀
Citation
If you find Gated DeltaNet to be useful for your work, please consider citing our paper:
@article{yang2024gated,
  title={Gated Delta Networks: Improving Mamba2 with Delta Rule},
  author={Yang, Songlin and Kautz, Jan and Hatamizadeh, Ali},
  journal={arXiv preprint arXiv:2412.06464},
  year={2024}
}