esmdiff

Installation

conda create -n py310 python=3.10 -y
conda activate py310
pip install -r requirements.txt
pip install -e .

The first time you run ESM3:

To download the weights, you must accept the non-commercial license. The weights are stored on HuggingFace Hub under HuggingFace/EvolutionaryScale/esm3. Please create a HuggingFace account and accept the license.

from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig

# Explains how to create a HuggingFace Hub access token (API key); make one with "Read" permission.
login()

Process data

(Optional) Download the whole PDB database and extract/process the structures into pickle (.pkl) files:

# download the whole PDB database (mmCIF files); this is a shell script, so run it with bash
bash scripts/download_pdb_mmcif.sh ./pdb_mmcif

# workaround: mmcif_parsing requires biopython==1.79
pip install biopython==1.79
python scripts/pdb/preprocess.py --mmcif_dir pdb_data/pdb_mmcif --output_dir pdb_data/processed_chains --per_chain --strip_array

# switch back to biopython 1.84 for everything else
pip install biopython==1.84
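Before encoding, it can help to spot-check what preprocess.py wrote. The exact contents of each chain pickle are defined by scripts/pdb/preprocess.py and are not documented here, so this is only a minimal sketch assuming that each .pkl file holds one pickled Python object; `summarize_pickles` is a hypothetical helper, not part of the repo.

```python
import pickle
from pathlib import Path


def summarize_pickles(directory, limit=3):
    """Load up to `limit` .pkl files (sorted by name) and describe each one.

    Returns a list of (file name, type name, sorted top-level keys or None).
    Only unpickle files you produced yourself: pickle can execute arbitrary code.
    """
    root = Path(directory)
    if not root.is_dir():
        raise FileNotFoundError(f"not a directory: {root}")
    summaries = []
    for path in sorted(root.glob("*.pkl"))[:limit]:
        with path.open("rb") as f:
            obj = pickle.load(f)
        # Dicts are summarized by their keys; any other object only by its type.
        keys = sorted(str(k) for k in obj) if isinstance(obj, dict) else None
        summaries.append((path.name, type(obj).__name__, keys))
    return summaries
```

For example, `summarize_pickles("pdb_data/processed_chains")` lists the first three chain files with their object types and, for dicts, their keys.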

For training, the VQ-VAE structure encodings must be pre-computed:

# encode the processed pickle files (from the step above) into VQ-VAE encodings
python scripts/dump.py pdb_data/processed_chains pdb_data/processed_chains_encoding pkl
# or, if you already have a dataset of PDB files
python scripts/dump.py pdb_data/raw_pdb pdb_data/raw_pdb_encoding pdb

Training

sbatch train.sh experiment=jlm paths.data_dir=pdb_data/raw_pdb_encoding data.batch_size=16 logger=csv 
sbatch train.sh experiment=clm paths.data_dir=pdb_data/raw_pdb_encoding data.batch_size=16 logger=csv 
sbatch train.sh experiment=mdlm paths.data_dir=pdb_data/raw_pdb_encoding data.batch_size=16 logger=csv 
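The three commands differ only in the `experiment` override. As a minimal sketch, assuming train.sh forwards its key=value arguments unchanged to the training entry point (the syntax looks like Hydra-style overrides, but check the repo's configs), the submissions can be generated from Python; `build_train_commands` is a hypothetical helper, not part of the repo.

```python
import shlex

EXPERIMENTS = ("jlm", "clm", "mdlm")


def build_train_commands(data_dir, batch_size=16, logger="csv", experiments=EXPERIMENTS):
    """Return one sbatch command (as an argument list) per experiment."""
    commands = []
    for exp in experiments:
        commands.append([
            "sbatch", "train.sh",
            f"experiment={exp}",
            f"paths.data_dir={data_dir}",
            f"data.batch_size={batch_size}",
            f"logger={logger}",
        ])
    return commands


if __name__ == "__main__":
    for cmd in build_train_commands("pdb_data/raw_pdb_encoding"):
        # Print only; pass cmd to subprocess.run(cmd, check=True) to actually submit.
        print(shlex.join(cmd))
```

Keeping each command as an argument list (rather than one string) avoids shell-quoting problems if a data path ever contains spaces.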

Inference

Sample from HuggingFace-based models (T5, GPT2), for example:

python slm/sample_hf.py ckpt_path=logs/ConditionalLanguageModeling/runs/dev_exp_name/checkpoints/epoch_999.ckpt inference.input=data/targets/bpti inference.output=outputs/inference inference.batch_size=32 inference.n_samples=100
# or, specify the target by name instead of by path
python slm/sample_hf.py ckpt_path=logs/ConditionalLanguageModeling/runs/dev_exp_name/checkpoints/epoch_999.ckpt inference.target=bpti inference.output=outputs/inference inference.batch_size=32 inference.n_samples=100

Sample from ESMDiff (masked diffusion fine-tuned ESM3):

python slm/sample_esmdiff.py --input data/targets/bpti --output outputs/inference_esmdiff --num_steps 25 --num_samples 100 --ckpt logs/MaskedDiffusionLanguageModeling/runs/dev_exp_name/checkpoints/epoch_999.ckpt --mode ddpm
# inpainting: generate only the positions listed in --mask_ids, keeping the rest fixed
python slm/sample_esmdiff.py --input data/targets/bpti --output outputs/inference_esmdiff --num_steps 25 --num_samples 100 --mask_ids 1,2,3,4,5
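Masked diffusion sampling generally starts from masked tokens and reveals them over a fixed number of steps (here `--num_steps`); with inpainting, only the selected positions start masked. The toy sketch below illustrates that general idea only and is not the repo's sampler. Assumptions: a uniformly random residue stands in for the fine-tuned ESM3 denoiser, tokens are revealed on a linear schedule, and `--mask_ids` is treated as 1-based residue positions (check slm/sample_esmdiff.py for the actual convention). `MASK`, `parse_mask_ids`, and `toy_masked_diffusion` are hypothetical names.

```python
import random

MASK = "_"  # hypothetical mask token
ALPHABET = "ACDEFGHIKLMNPQRSTVWY"


def parse_mask_ids(mask_ids: str, length: int) -> list[int]:
    """Turn "1,2,3" into sorted unique 0-based indices (ASSUMES 1-based input)."""
    idx = sorted({int(tok) - 1 for tok in mask_ids.split(",") if tok.strip()})
    if any(i < 0 or i >= length for i in idx):
        raise ValueError("mask id out of range")
    return idx


def toy_masked_diffusion(template: str, masked: list[int], num_steps: int, seed: int = 0) -> str:
    """Mask the given positions, then reveal them over num_steps steps.

    Unmasked positions are never touched, which is what makes this inpainting.
    A random residue replaces the model prediction a real denoiser would make
    from the partially revealed sequence.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be >= 1")
    rng = random.Random(seed)
    seq = list(template)
    for i in masked:
        seq[i] = MASK
    remaining = list(masked)
    for step in range(num_steps):
        # Linear schedule: after step s (0-based), about (s + 1) / num_steps
        # of the masked positions are revealed; the last step reveals the rest.
        target_left = len(masked) * (num_steps - step - 1) // num_steps
        n_reveal = len(remaining) - target_left
        rng.shuffle(remaining)
        for i in remaining[:n_reveal]:
            seq[i] = rng.choice(ALPHABET)
        remaining = remaining[n_reveal:]
    return "".join(seq)


if __name__ == "__main__":
    template = "MKTAYIAKQR"  # hypothetical 10-residue sequence
    masked = parse_mask_ids("1,2,3,4,5", len(template))
    # Positions 6-10 ("IAKQR") are kept fixed; positions 1-5 are regenerated.
    print(toy_masked_diffusion(template, masked, num_steps=25))
```

With more steps than masked positions, many steps reveal nothing; a real sampler may use a different schedule or remasking strategy.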

Evaluation of samples

See ./analysis.

LICENSE

The source code and model can be used for non-commercial purposes only. For any parts related to ESM3, please strictly follow the EvolutionaryScale Community License Agreement https://www.evolutionaryscale.ai/policies/community-license-agreement.

Remarks

This repo is still a work in progress, and more features will be added in the near future.