Codebase for MoEUT

The official training repository for our paper "MoEUT: Mixture-of-Experts Universal Transformers". This is the codebase we used to develop the model, and it is quite messy.

If you are looking for an easy-to-use, short, cleaned-up version, please take a look at https://github.com/robertcsordas/moeut.

Installation

This project requires Python >= 3.10 and PyTorch >= 2.2.

pip3 install -r requirements.txt

Create a Weights and Biases account and run

wandb login

More information on setting up Weights and Biases can be found at https://docs.wandb.com/quickstart.
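
If you need to authenticate non-interactively (e.g. on a cluster node), the snippet below is a minimal sketch using the wandb Python API instead of the CLI; it assumes your API key is available in the WANDB_API_KEY environment variable.

# Sketch: non-interactive alternative to the "wandb login" command.
import wandb

# Picks up WANDB_API_KEY from the environment if set, otherwise prompts for the key.
wandb.login()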

For plotting, LaTeX is required (to avoid Type 3 fonts and to render symbols). Installation is OS specific.

Pretrained Model Checkpoints

We released checkpoints for all of the MoEUT models from our paper. They can be found at https://huggingface.co/robertcsordas/moeut_training_checkpoints.

NOTE: These are not production-quality pretrained models, but only a proof of concept. Because of our limited resources, they were trained on only 6.5B tokens, which is very little by modern standards.
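
One convenient way to fetch the whole checkpoint repository is the huggingface_hub package. The snippet below is only a sketch: it assumes you want cache/ and checkpoints/ downloaded straight into the repository root; adjust local_dir if you prefer another location.

# Sketch: download the released tokenizer cache and checkpoints from Hugging Face.
# Assumes huggingface_hub is installed (pip3 install huggingface_hub).
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="robertcsordas/moeut_training_checkpoints",
    local_dir=".",  # assumption: run from the repository root so cache/ and checkpoints/ land here
)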

The structure of the checkpoint repository:

├───cache - Tokenizers for the different datasets
└───checkpoints - Model checkpoints

The cache folder contains our tokenizers and must be copied to this folder (the repository root) to avoid the minor differences that can occur if a different version of SentencePiece is used than ours. When a specific task is run, it automatically downloads the necessary data and tokenizes it. It only tokenizes the amount of data actually needed for training/evaluation, to avoid wasting space and time.
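
If you want to inspect one of the shipped tokenizers outside the training code, the sketch below uses the sentencepiece package directly. The .model path is a placeholder; point it at the actual file inside cache/ for your dataset.

# Sketch: load a shipped SentencePiece model directly (path below is a placeholder).
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="cache/<path_to_tokenizer>.model")
print(sp.encode("Mixture-of-Experts Universal Transformers", out_type=str))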

The checkpoints folder contains all the model checkpoints (without the optimizer state, which we removed to save space). A checkpoint can be loaded with --restore checkpoints/C4_44M.ckpt. It automatically restores all configurations used for training.
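
If you only want to peek inside a checkpoint file without going through main.py, a rough sketch is shown below. It assumes the checkpoint is a regular torch.save file; the exact key layout is whatever the framework stored, so treat the printed names as informational.

# Sketch: inspect a released checkpoint outside the training framework.
# Assumes a plain torch.save pickle; on PyTorch >= 2.6 you may need weights_only=False
# because the checkpoint contains more than bare tensors.
import torch

ckpt = torch.load("checkpoints/C4_44M.ckpt", map_location="cpu")
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys()))  # top-level entries saved by the framework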

In order to run a simple validation pass, you can run:

python3 main.py -restore checkpoints/C4_44M.ckpt -test_only 1 -log tb -name test -reset 1 -lm.eval.enabled 0 -stop_after 0

The flag -log tb switches to TensorBoard logging instead of W&B (which was used for the training run), -lm.eval.enabled 0 disables the costly downstream evaluations, and -stop_after 0 is a hack to avoid wasting an excessive amount of time tokenizing training data that will not be used for evaluation anyway (sorry, this could be handled better). For the other flags, see the details at the end of this doc.

Usage

The code makes use of Weights and Biases for experiment tracking. In the "sweeps" directory, we provide sweep configurations for all experiments we have performed.

To reproduce our results, start a sweep for each of the YAML files in the "sweeps" directory. Run wandb agent for each of them in the main directory. This will run all the experiments, and they will be displayed on the W&B dashboard.
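
As a concrete (hypothetical) example, the sketch below registers one sweep with the W&B Python SDK and prints the agent command to run from the main directory; the YAML file name, entity, and project names are placeholders, and running wandb sweep on the YAML file from the command line achieves the same thing.

# Sketch: register a sweep from one of the YAML configs (names are placeholders).
import yaml
import wandb

with open("sweeps/<some_sweep>.yaml") as f:
    sweep_config = yaml.safe_load(f)

sweep_id = wandb.sweep(sweep=sweep_config, project="moeut")
print(f"Now run from the main directory: wandb agent <entity>/moeut/{sweep_id}")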

ClusterTool

The code is designed to work with ClusterTool.

If used with ClusterTool, W&B sweeps, run preemption, file synchronization, etc. will be handled automatically.

Re-creating plots from the paper

Edit the config file "paper/moe_universal/config.json". Enter your project name in the field "wandb_project" (e.g. "username/moeut"). Copy the checkpoints of your runs to paper/moe_universal/checkpoints/<run_id>/model.ckpt. Then run "paper/moe_universal/run_tests.py" to run additional validations on zero-shot downstream tasks. This will take a long time.
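
If you prefer to script these steps, the sketch below is one hypothetical way to do it; the run id and the source checkpoint path are placeholders, and config.json is assumed to already exist with its other fields left untouched.

# Sketch: set "wandb_project" in the plotting config and copy a run checkpoint.
# <run_id> and the source checkpoint path are placeholders.
import json
import shutil
from pathlib import Path

cfg_path = Path("paper/moe_universal/config.json")
cfg = json.loads(cfg_path.read_text())
cfg["wandb_project"] = "username/moeut"  # your own entity/project
cfg_path.write_text(json.dumps(cfg, indent=4))

dst = Path("paper/moe_universal/checkpoints/<run_id>/model.ckpt")
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy("<path_to_your_run_checkpoint>.ckpt", dst)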

To reproduce a specific plot or table, run the script of interest within the "paper" directory. For example:

cd paper/moe_universal
python3 main_result_table.py

Structure

├───cache - temporary files automatically generated by this code
├───framework - reusable library for running experiments
│    ├─  datasets - a large collection of diverse datasets
│    ├─  visualize - universal plotting functions working for TF and W&B
│    ├─  helpers - helper routines for cluster, wandb and training setup
│    ├─  utils - useful utils (downloaders, samplers, multiprocessing)
│    ├─  layers - useful layers
│    ├─  tasks - main training loop, etc.
│    └─  * - all kinds of reusable components
│
├───save - saved checkpoints and training state
├───sweeps - Weights and Biases experiment configs
├───tasks - experiments. Add new experiments as new files here; they will be picked up automatically.
├───main.py - initialization code
└───cluster.json - configuration for the ClusterTool

Useful built-in arguments

There are many other useful default arguments, defined in framework/task/task.py, framework/task/simple_task.py and framework/helpers/training_helper.py.

Known issues

Triton seems to be broken on Volta GPUs when using float16 from PyTorch 2.2 onwards (see the related GitHub issue). Until the PyTorch team fixes the issue, please downgrade to PyTorch 2.1 or disable AMP if you have Volta GPUs.