<img src="./st-moe.png" width="450px"></img>

ST-MoE - Pytorch

A PyTorch implementation of <a href="https://arxiv.org/abs/2202.08906">ST-MoE</a>, the latest incarnation of mixture of experts after years of research at Google Brain. It will largely be a transcription of the <a href="https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py">official Mesh Tensorflow implementation</a>. If you know of any papers that should be added while my attention is on mixture of experts, please open an issue.

This should be state of the art for mixture of experts in autoregressive transformers. GPT-4 is rumored to use 16 experts with top-2 gating.

For non-autoregressive models, I would recommend the simpler and better <a href="https://github.com/lucidrains/soft-moe-pytorch">Soft MoE</a>.

Install

$ pip install st-moe-pytorch

Appreciation

Usage

import torch
from st_moe_pytorch import MoE

moe = MoE(
    dim = 512,
    num_experts = 16,               # increase the experts (# parameters) of your model without increasing computation
    gating_top_n = 2,               # default to top 2 gating, but can also be more (3 was tested in the paper with a lower threshold)
    threshold_train = 0.2,          # at what threshold to accept a token to be routed to second expert and beyond - 0.2 was optimal for 2 expert routing, and apparently should be lower for 3
    threshold_eval = 0.2,
    capacity_factor_train = 1.25,   # experts have fixed capacity per batch. we need some extra capacity in case gating is not perfectly balanced.
    capacity_factor_eval = 2.,      # capacity_factor_* should be set to a value >=1
    balance_loss_coef = 1e-2,       # multiplier on the auxiliary expert balancing loss
    router_z_loss_coef = 1e-3,      # loss weight for router z-loss
)

inputs = torch.randn(4, 1024, 512)
out, total_aux_loss, balance_loss, router_z_loss = moe(inputs) # (4, 1024, 512), (1,), (1,), (1,)

# for the entire mixture of experts block, in context of transformer

from st_moe_pytorch import SparseMoEBlock

moe_block = SparseMoEBlock(
    moe,
    add_ff_before = True,
    add_ff_after = True
)

out, total_aux_loss, balance_loss, router_z_loss = moe_block(inputs) # (4, 1024, 512), (1,), (1,), (1,)

# the total auxiliary loss will need to be summed and then added to the main loss

# the other two losses are the unweighted breakdown for logging purposes
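
To make the capacity factor and the loss coefficients above more concrete, here is a minimal plain-Python sketch of the quantities as they are usually defined in the ST-MoE and Switch Transformer papers. The function names are hypothetical and are not part of st_moe_pytorch. Rounding up with ceil, the way tokens are grouped, and the exact weighting of the total auxiliary loss are assumptions; the library may compute these differently.

```python
import math

def expert_capacity(tokens_per_group, num_experts, capacity_factor):
    # Even share of tokens per expert, scaled by the capacity factor.
    # Rounding up is an assumption made for this sketch.
    return math.ceil(tokens_per_group / num_experts * capacity_factor)

def router_z_loss(router_logits):
    # router_logits: one list of per-expert logits for each token.
    # Mean over tokens of the squared log-sum-exp of the logits.
    total = 0.0
    for logits in router_logits:
        m = max(logits)  # subtract the max for numerical stability
        lse = m + math.log(sum(math.exp(x - m) for x in logits))
        total += lse ** 2
    return total / len(router_logits)

def weighted_aux_loss(balance_loss, z_loss,
                      balance_loss_coef=1e-2, router_z_loss_coef=1e-3):
    # Assumed combination: each unweighted loss times its coefficient.
    return balance_loss_coef * balance_loss + router_z_loss_coef * z_loss

def loss_with_aux(main_loss, aux_losses):
    # Sum the total auxiliary loss of every MoE block, then add the main loss.
    return main_loss + sum(aux_losses)

# 1024 tokens, 16 experts, as in the configuration above
train_capacity = expert_capacity(1024, 16, 1.25)
eval_capacity = expert_capacity(1024, 16, 2.0)
```

With these assumptions, a group of 1024 tokens spread over 16 experts gives each expert room for 80 tokens during training and 128 during evaluation. Tokens routed to an expert that is already full are dropped, which is why a capacity factor above 1 leaves slack for imperfectly balanced gating.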

Todo

Citations

@inproceedings{Zoph2022STMoEDS,
    title   = {ST-MoE: Designing Stable and Transferable Sparse Expert Models},
    author  = {Barret Zoph and Irwan Bello and Sameer Kumar and Nan Du and Yanping Huang and Jeff Dean and Noam M. Shazeer and William Fedus},
    year    = {2022}
}