MoEUT: Mixture-of-Experts Universal Transformers

Official implementation of our MoEUT model.

The implementation uses the CVMM Triton kernel from $\sigma$-MoE.

For the training code, please refer to https://github.com/RobertCsordas/moeut_training_code.

Example

import torch
import torch.nn.functional as F

import moeut
import profiles

vocab_size = 8000                                        # vocabulary size of your tokenizer (placeholder)
tokens = torch.randint(0, vocab_size, (8, 512)).cuda()   # [batch size, context length] token indices (placeholder data)

model = moeut.MoEUTLM(vocab_size, **profiles.MoEUT_244M).cuda()
out = model(tokens[:, :-1])                              # predict the next token for each position

loss = F.cross_entropy(out.outputs.view(-1, vocab_size), tokens[:, 1:].flatten())
(loss + out.reg_loss).backward()                         # add the regularization loss before the backward pass

A simple example can be found in example.py.

Usage

from moeut import MoEUTLM

The signature of the init function is as follows:

def __init__(self, n_tokens: int, d_model: int, n_layers: int, n_heads: int,
             ff_n_experts: int, att_n_experts: int, d_head: Optional[int] = None,
             group_size: int = 2, ff_k: int = 8, att_k: int = 2, ff_expert_dropout: float = 0.0,
             att_expert_dropout: float = 0.0, ff_expert_size: int = 128, dropout: float = 0.0,
             entropy_reg: float = 0.01, att_entropy_reg: float = 0.001, attention = SwitchHeadRope):

The meaning of the arguments:

- n_tokens: size of the vocabulary.
- d_model: width of the residual stream (model dimension).
- n_layers: total number of layers.
- n_heads: number of attention heads per layer.
- ff_n_experts: number of experts in the feedforward blocks.
- att_n_experts: number of experts in the SwitchHead attention blocks.
- d_head: size of a single attention head (optional).
- group_size: number of consecutive layers forming one recurrently reused group.
- ff_k: number of feedforward experts active per token.
- att_k: number of attention experts active per token.
- ff_expert_dropout: expert dropout for the feedforward expert selection.
- att_expert_dropout: expert dropout for the attention expert selection.
- ff_expert_size: size of a single feedforward expert.
- dropout: standard dropout used inside the model.
- entropy_reg: entropy regularization coefficient for the feedforward expert selection.
- att_entropy_reg: entropy regularization coefficient for the attention expert selection.
- attention: the attention implementation to use (SwitchHeadRope by default).
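
For example, a model can also be constructed directly from these arguments instead of using a predefined profile. The hyperparameter values below are purely illustrative and do not correspond to any configuration from the paper:

import moeut

# Illustrative hyperparameters only.
model = moeut.MoEUTLM(
    n_tokens=8000, d_model=512, n_layers=16, n_heads=8,
    ff_n_experts=128, att_n_experts=4,
    group_size=2, ff_k=8, att_k=2, ff_expert_size=128,
).cuda()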

The signature of the forward function:

def forward(self, x: torch.Tensor, mask: Optional[AttentionMask] = None,
            kv_cache: MultilayerKVCache = None) -> MoEUTOutput:

The meaning of the arguments:

- x: input token indices of shape [batch size, context length].
- mask: optional AttentionMask (see below).
- kv_cache: optional key-value cache returned by a previous forward call, used for incremental decoding.

The forward pass returns a MoEUTOutput object, which has 3 fields:

- outputs: output logits of shape [batch size, context length, n_tokens].
- reg_loss: the regularization loss. Add it to the main loss before calling backward().
- kv_cache: the updated key-value cache, which can be passed to the next forward call.
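
A minimal sketch of greedy incremental decoding with the cache. It assumes that the returned object exposes the updated cache as out.kv_cache and that only the newly generated token needs to be passed once a cache is provided; verify both against moeut/moeut.py:

import torch
import moeut
import profiles

vocab_size = 8000                                                  # placeholder vocabulary size
model = moeut.MoEUTLM(vocab_size, **profiles.MoEUT_44M).cuda().eval()

prompt = torch.randint(0, vocab_size, (1, 16)).cuda()              # placeholder prompt
with torch.no_grad():
    out = model(prompt)                                            # process the prompt, build the cache
    generated = prompt
    for _ in range(10):
        next_token = out.outputs[:, -1].argmax(-1, keepdim=True)   # greedy pick of the next token
        generated = torch.cat([generated, next_token], dim=1)
        out = model(next_token, kv_cache=out.kv_cache)             # feed only the new token plus the cache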

The AttentionMask has two optional boolean fields. Elements set to True are masked out (removed from attention). Fields left as None are ignored.
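
For illustration, a padding mask could be built roughly as follows. The import path and the field names (src_length_mask, position_mask) are assumptions here; check the AttentionMask definition in moeut/moeut.py before using this:

import torch
from moeut.moeut import AttentionMask                                # assumed location of AttentionMask

batch_size, context_length = 8, 512
lengths = torch.randint(1, context_length + 1, (batch_size,))        # true length of each sequence

# True marks positions to be removed from attention (here: padding beyond each sequence's length).
# Field names below are assumptions; verify them against the actual definition.
padding_mask = (torch.arange(context_length)[None, :] >= lengths[:, None]).cuda()
mask = AttentionMask(src_length_mask=padding_mask, position_mask=None)

out = model(tokens, mask=mask)                                       # model and tokens as in the example above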

If you wish to use MoEUT for something other than language modeling, use MoEUT instead of MoEUTLM. The constructor is identical, except that there is no n_tokens argument. The forward pass has the same format, except that the inputs and outputs have shape [batch size, context length, d_model].
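
As a sketch, the headless backbone can be combined with a custom embedding and task head, e.g. for sequence classification. The embedding, pooling, and classification head below are illustrative and not part of this repository:

import torch
import torch.nn as nn
import torch.nn.functional as F
import moeut
import profiles

d_model = profiles.MoEUT_44M["d_model"]                # the profiles are plain dicts, so d_model can be read out
n_classes = 10                                         # illustrative number of classes

embed = nn.Embedding(8000, d_model).cuda()             # custom input embedding (8000 = illustrative vocab size)
backbone = moeut.MoEUT(**profiles.MoEUT_44M).cuda()    # MoEUT without the LM head
head = nn.Linear(d_model, n_classes).cuda()            # custom classification head

tokens = torch.randint(0, 8000, (8, 128)).cuda()
labels = torch.randint(0, n_classes, (8,)).cuda()

out = backbone(embed(tokens))                          # inputs and outputs: [batch size, context length, d_model]
logits = head(out.outputs.mean(dim=1))                 # mean-pool over the sequence, then classify
loss = F.cross_entropy(logits, labels)
(loss + out.reg_loss).backward()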

Configurations used in the paper

We provide the configurations used in the paper in profiles.py. The following options are available: MoEUT_44M, MoEUT_126M, MoEUT_244M, MoEUT_318M, MoEUT_727M, MoEUT_1B. They are dicts of constructor parameters. Unpack them into the constructor, e.g. **profiles.MoEUT_1B.
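
Since they are ordinary dicts, a configuration can also be picked by name, e.g. (sketch; the vocabulary size is a placeholder):

import moeut
import profiles

profile_name = "MoEUT_126M"                            # any of the names listed above
model = moeut.MoEUTLM(8000, **getattr(profiles, profile_name)).cuda()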

Useful tips

Try disabling SwitchHead attention (att_n_experts=1) for higher speed. The degradation in predictive performance (perplexity) is minimal, and the model still outperforms the dense baseline. Tested at the 244M and 768M scales.
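
Since the profiles are plain dicts, this can be done by overriding the corresponding entries, e.g. (sketch; the resulting configuration is not one from the paper):

import moeut
import profiles

# Use a single attention expert, which disables SwitchHead routing. att_k is also set to 1
# so that the number of selected experts never exceeds the number of available experts.
config = {**profiles.MoEUT_244M, "att_n_experts": 1, "att_k": 1}
model = moeut.MoEUTLM(8000, **config).cuda()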

torch.compile() support

torch.compile() is supported with PyTorch >= 2.3.
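
For example (standard torch.compile usage; model and tokens as in the example above):

import torch

compiled_model = torch.compile(model)                  # requires PyTorch >= 2.3 for MoEUT
out = compiled_model(tokens[:, :-1])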

Project structure

├───moeut - the MoEUT model. Copy this to your project.
│    ├─  cvmm.py - the CVMM Triton kernel.
│    └─  moeut.py - the implementation of MoEUT
│
├───example.py - an example forward-backward pass.
├───profiles.py - default configurations used in the paper.
├───LICENSE - MIT License.
└───README.md - this documentation.

Known issues

Triton appears to be broken on Volta GPUs when using float16, starting with PyTorch 2.2 (see the corresponding GitHub issue). Until the PyTorch team fixes the issue, please downgrade to PyTorch 2.1 or disable AMP if you are using a Volta GPU.