Linformer for Pytorch
An implementation of Linformer in Pytorch. Linformer comes with two deficiencies: (1) it does not work for the auto-regressive case, and (2) it assumes a fixed sequence length. However, if benchmarks show it performs well enough, it will be added to <a href="https://github.com/lucidrains/linear-attention-transformer">this repository</a> as a self-attention layer to be used in the encoder.
Linformer has been <a href="https://ai.facebook.com/blog/how-facebook-uses-super-efficient-ai-models-to-detect-hate-speech/">put into production</a> by Facebook!
Install
$ pip install linformer
Usage
Linformer language model
import torch
from linformer import LinformerLM
model = LinformerLM(
num_tokens = 20000,
dim = 512,
seq_len = 4096,
depth = 12,
heads = 8,
dim_head = 128, # dimension of each head in multi-head attention
k = 256, # this is the k that the key/values are projected to along the sequence dimension
one_kv_head = True, # share one key/value head across all heads
share_kv = False, # share the same projection for keys and values
reversible = True # make network reversible, like Reformer
)
x = torch.randint(0, 20000, (1, 4096))
model(x) # (1, 4096, 20000)
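Because the sequence length is fixed (deficiency (2) above), shorter inputs need to be brought up to seq_len before being fed to the model. Below is a minimal training-step sketch, not part of the library: the small depth, the right-padding with token 0, and the toy reconstruction objective are all illustrative assumptions.

import torch
import torch.nn.functional as F
from linformer import LinformerLM

# toy configuration, kept small for illustration
model = LinformerLM(
    num_tokens = 20000,
    dim = 512,
    seq_len = 4096,
    depth = 1,
    heads = 8,
    k = 256
)

tokens = torch.randint(0, 20000, (1, 1024))                     # shorter than the fixed seq_len
tokens = F.pad(tokens, (0, 4096 - tokens.shape[1]), value = 0)  # right-pad up to seq_len (token 0 used as padding here)

logits = model(tokens)                                          # (1, 4096, 20000)
loss = F.cross_entropy(logits.transpose(1, 2), tokens)          # placeholder objective: reconstruct the input
loss.backward()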
Linformer
import torch
from linformer import Linformer
model = Linformer(
dim = 512,
seq_len = 4096,
depth = 12,
heads = 8,
k = 256,
one_kv_head = True,
share_kv = True
)
x = torch.randn(1, 4096, 512)
model(x) # (1, 4096, 512)
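The Linformer encoder above operates on feature tensors rather than token ids, so it can be embedded inside a larger model. A minimal sketch of a sequence classifier, assuming a hypothetical LinformerClassifier wrapper; the embedding layer, mean pooling, and linear head are this sketch's own additions, not part of the package.

import torch
from torch import nn
from linformer import Linformer

class LinformerClassifier(nn.Module):
    def __init__(self, num_tokens, dim, seq_len, num_classes):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.encoder = Linformer(
            dim = dim,
            seq_len = seq_len,
            depth = 2,
            heads = 8,
            k = 256
        )
        self.to_logits = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.token_emb(x)                   # (batch, seq_len, dim)
        x = self.encoder(x)                     # (batch, seq_len, dim)
        return self.to_logits(x.mean(dim = 1))  # mean pool over the sequence, then classify

model = LinformerClassifier(num_tokens = 20000, dim = 512, seq_len = 4096, num_classes = 10)
x = torch.randint(0, 20000, (1, 4096))
model(x) # (1, 10)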
Single Self-Attention layer
import torch
from linformer import LinformerSelfAttention
attn = LinformerSelfAttention(
dim = 512,
seq_len = 4096,
heads = 8,
k = 256,
one_kv_head = True,
share_kv = True
)
x = torch.randn(1, 4096, 512)
attn(x) # (1, 4096, 512)
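The single self-attention layer can also be dropped into a custom transformer block. A minimal sketch of a pre-norm residual block, assuming the usual layernorm and feedforward wrapping; only LinformerSelfAttention comes from this package, the rest is a common pattern assumed here.

import torch
from torch import nn
from linformer import LinformerSelfAttention

class LinformerBlock(nn.Module):
    def __init__(self, dim, seq_len, heads = 8, k = 256):
        super().__init__()
        self.attn_norm = nn.LayerNorm(dim)
        self.attn = LinformerSelfAttention(dim = dim, seq_len = seq_len, heads = heads, k = k)
        self.ff_norm = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x):
        x = x + self.attn(self.attn_norm(x))  # pre-norm attention with residual
        x = x + self.ff(self.ff_norm(x))      # pre-norm feedforward with residual
        return x

block = LinformerBlock(dim = 512, seq_len = 4096)
x = torch.randn(1, 4096, 512)
block(x) # (1, 4096, 512)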
The self-attention layer above can also receive contextual keys. In that case, the sequence length is validated against the length of the contextual keys rather than the source sequence.
import torch
from linformer import LinformerSelfAttention
attn = LinformerSelfAttention(
dim = 512,
seq_len = 8192,
heads = 8,
k = 256,
one_kv_head = True,
share_kv = True
)
x = torch.randn(1, 2048, 512)
context = torch.randn(1, 8192, 512)
attn(x, context) # (1, 2048, 512)
Citations
@misc{wang2020linformer,
title={Linformer: Self-Attention with Linear Complexity},
author={Sinong Wang and Belinda Z. Li and Madian Khabsa and Han Fang and Hao Ma},
year={2020},
eprint={2006.04768},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
@inproceedings{kitaev2020reformer,
title = {Reformer: The Efficient Transformer},
author = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
booktitle = {International Conference on Learning Representations},
year = {2020},
url = {https://openreview.net/forum?id=rkgNKkHtvB}
}