<img src="./gmlp.png" width="400px"></img>
## gMLP - Pytorch
Implementation of <a href="https://arxiv.org/abs/2105.08050">gMLP</a>, an all-MLP replacement for Transformers, in Pytorch
## Install

```bash
$ pip install g-mlp-pytorch
```
## Usage
For masked language modelling
```python
import torch
from torch import nn
from g_mlp_pytorch import gMLP

model = gMLP(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 256,
    circulant_matrix = True,      # use circulant weight matrix for a linear increase in parameters with respect to sequence length
    act = nn.Tanh()               # activation for the spatial gate (defaults to identity)
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x) # (1, 256, 20000)
```
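The logits can then be fed to any masked-token objective. Below is a minimal, illustrative training sketch (not part of the library), continuing from the example above; the mask token id, masking ratio, and loss are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

MASK_TOKEN_ID = 0                                 # assumed mask token id, purely for illustration

tokens = torch.randint(1, 20000, (1, 256))
mask = torch.rand(tokens.shape) < 0.15            # mask roughly 15% of positions
inp = tokens.masked_fill(mask, MASK_TOKEN_ID)

logits = model(inp)                               # (1, 256, 20000)

loss = F.cross_entropy(
    logits[mask],                                 # predictions at masked positions, (num_masked, 20000)
    tokens[mask]                                  # original token ids as targets, (num_masked,)
)
loss.backward()
```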
For image classification
```python
import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6
)

img = torch.randn(1, 3, 256, 256)
logits = model(img) # (1, 1000)
```
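For training, the class logits can go straight into a standard cross entropy loss. Here is a minimal training-step sketch continuing from the example above; the labels, optimizer, and learning rate are assumptions for illustration, not part of the library.

```python
import torch
import torch.nn.functional as F
from torch.optim import Adam

opt = Adam(model.parameters(), lr = 3e-4)   # assumed optimizer and learning rate

labels = torch.randint(0, 1000, (1,))       # dummy class labels
logits = model(img)                         # (1, 1000)

loss = F.cross_entropy(logits, labels)
loss.backward()
opt.step()
opt.zero_grad()
```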
You can also add a tiny amount of attention (one-headed) to boost performance, as mentioned in the paper as aMLP, with the addition of one extra keyword `attn_dim`. This applies to both `gMLPVision` and `gMLP`
```python
import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    attn_dim = 64
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)
```
Non-square images and patch sizes
```python
import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = (256, 128),
    patch_size = (16, 8),
    num_classes = 1000,
    dim = 512,
    depth = 6,
    attn_dim = 64
)

img = torch.randn(1, 3, 256, 128)
pred = model(img) # (1, 1000)
```
## Experimental
An independent researcher proposes using a multi-headed approach for gMLPs in <a href="https://zhuanlan.zhihu.com/p/395005917">a blogpost on Zhihu</a>. To do so, just set `heads` to be greater than 1.
```python
import torch
from torch import nn
from g_mlp_pytorch import gMLP

model = gMLP(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 256,
    causal = True,
    circulant_matrix = True,
    heads = 4 # 4 heads
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x) # (1, 256, 20000)
```
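Since `causal = True` restricts the spatial gating so each position only mixes with earlier positions, the model can also be trained autoregressively. A minimal next-token-prediction sketch, continuing from the example above (the loss setup is illustrative, not part of the library):

```python
import torch
import torch.nn.functional as F

tokens = torch.randint(0, 20000, (1, 256))
logits = model(tokens)                           # (1, 256, 20000)

loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, 20000),           # predictions for positions 0 .. 254
    tokens[:, 1:].reshape(-1)                    # targets are the tokens shifted one step left
)
loss.backward()
```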
## Citations
```bibtex
@misc{liu2021pay,
    title         = {Pay Attention to MLPs},
    author        = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year          = {2021},
    eprint        = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}
```
```bibtex
@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = aug,
    year      = 2021,
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}
```