ProGen - (wip)

Implementation and replication of <a href="https://arxiv.org/abs/2004.03497">ProGen</a>, Language Modeling for Protein Generation, in PyTorch and JAX (the weights will be made easily transferable between the two). You can think of this as GPT for protein sequences.

Requirements

We use <a href="https://github.com/python-poetry/poetry">Poetry</a> to manage the dependencies for this project, so first install it using the <a href="https://github.com/python-poetry/poetry#osx--linux--bashonwindows-install-instructions">one-liner bash command</a>.
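
At the time of writing, the official installer is the following one-liner (double-check the linked instructions in case it has changed)

$ curl -sSL https://install.python-poetry.org | python3 -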

Next, git clone the project and install the dependencies

$ git clone git@github.com:lucidrains/progen
$ cd progen
$ poetry install

For training on GPUs, you may need to reinstall jax with a build that matches your CUDA version. You can follow the instructions <a href="https://github.com/google/jax#pip-installation-gpu-cuda">here</a>

# ex. CUDA 11.1
$ pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
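
You can verify that JAX sees your GPU by printing the visible devices

$ poetry run python -c "import jax; print(jax.devices())"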

Note that every script in this repository is invoked with poetry run prepended, so that it executes inside the Poetry-managed virtual environment.

Usage

from jax import random
from haiku import PRNGSequence
from progen_transformer import ProGen

model = ProGen(
    num_tokens = 256,
    dim = 512,
    seq_len = 1024,
    window_size = 256,       # local attention window size
    depth = 12,              # number of layers
    heads = 8,               # attention heads
    dim_head = 64,           # dimension per head
    ff_glu = True,           # use GLU in feedforward, from Noam's paper
    global_mlp_depth = 2     # last N global gmlp layers
)

rng = PRNGSequence(42)
seq = random.randint(next(rng), (1024,), 0, 256)

params = model.init(next(rng), seq)
logits = model.apply(params, next(rng), seq) # (1024, 256)
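
The returned logits are unnormalized scores over the 256 token ids at each of the 1024 positions. As a minimal sketch (not part of the library), greedily picking the token predicted to follow the last position looks like

import jax.numpy as jnp

# logits has shape (1024, 256): one row of token scores per sequence position
next_token = int(jnp.argmax(logits[-1]))  # greedy choice for the next token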

Training

Download UniRef50 from <a href="https://www.uniprot.org/downloads">UniProt</a> and place uniref50.fasta in the root directory, then run the data generation script

$ poetry run python generate_data.py

You should see a lot of green if everything succeeds. Then start training with

$ poetry run python train.py

By default, the script will checkpoint and resume automatically, but if you wish to clear your progress and restart, just add a --new flag

$ poetry run python train.py --new

Model checkpoints will be saved periodically to ./ckpts

Finally, to sample from your checkpoint, just do

$ poetry run python sample.py

You can pass a prime with --prime: either the annotations followed by #, to get a generated sequence, or a sequence (also followed by #) to get generated annotations

$ poetry run python sample.py --prime "[Tax=Mammalia] #"
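
Conversely, you can prime with a protein sequence to generate annotations (the sequence below is just a made-up placeholder)

$ poetry run python sample.py --prime "MKVLAAGIVLLLSAAR #"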

Mixed Precision

To use mixed precision training, you'll need to install the latest Haiku with the following command

$ pip install git+https://github.com/deepmind/dm-haiku

Then make sure to set the --mixed_precision flag when invoking the training script

$ poetry run python train.py --mixed_precision
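
For context, Haiku's mixed precision support works by attaching a policy from DeepMind's jmp library to a module class. The snippet below only illustrates that mechanism on a toy module, and is not how train.py configures it

import haiku as hk
import jmp

class ToyModule(hk.Module):  # placeholder module, for illustration only
    def __call__(self, x):
        return hk.Linear(16)(x)

# compute in float16 while keeping parameters and outputs in float32
policy = jmp.get_policy('params=float32,compute=float16,output=float32')
hk.mixed_precision.set_policy(ToyModule, policy)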

Todo

Acknowledgements

Many thanks go out to <a href="https://github.com/kingoflolz">Ben Wang</a>, who showed that this type of large-scale training can be achieved with <a href="https://github.com/kingoflolz/mesh-transformer-jax">GPT-J</a>

Citations

@misc{madani2020progen,
    title   = {ProGen: Language Modeling for Protein Generation}, 
    author  = {Ali Madani and Bryan McCann and Nikhil Naik and Nitish Shirish Keskar and Namrata Anand and Raphael R. Eguchi and Po-Ssu Huang and Richard Socher},
    year    = {2020},
    eprint  = {2004.03497},
    archivePrefix = {arXiv},
    primaryClass = {q-bio.BM}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    eprint  = {2002.05202},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}