JODO


The implementation of *Learning Joint 2D & 3D Diffusion Models for Complete Molecule Generation*.

Molecules are represented as a 3D point cloud together with a 2D bonding graph:

<p align="left"> <img src="assets/exp_geom_ver.png" width="800"/> </p>
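As a minimal sketch of this two-view representation, a molecule can be stored as a per-atom coordinate array plus a symmetric bond-order adjacency matrix. The example below uses NumPy with an illustrative formaldehyde geometry (the atom ordering, coordinates, and encoding are assumptions for illustration, not the repository's internal data format):

```python
import numpy as np

# Toy molecule: formaldehyde (CH2O) with explicit hydrogens.
atom_types = ['C', 'O', 'H', 'H']

# 3D point cloud: one xyz coordinate per atom (illustrative geometry, angstroms).
positions = np.array([
    [ 0.000,  0.000, 0.000],   # C
    [ 0.000,  1.210, 0.000],   # O
    [ 0.940, -0.540, 0.000],   # H
    [-0.940, -0.540, 0.000],   # H
])

# 2D bonding graph: symmetric adjacency matrix of bond orders
# (0 = no bond, 1 = single bond, 2 = double bond).
bonds = np.zeros((4, 4), dtype=int)
bonds[0, 1] = bonds[1, 0] = 2  # C=O double bond
bonds[0, 2] = bonds[2, 0] = 1  # C-H single bond
bonds[0, 3] = bonds[3, 0] = 1  # C-H single bond

# The two views describe the same atoms: one row/column per atom.
assert positions.shape == (len(atom_types), 3)
assert np.array_equal(bonds, bonds.T)
```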

The generative diffusion process:

<p align="left"> <img src="assets/sampling_exp.png" width="1200"/> </p>

Visualization of molecules generated by JODO trained on the GEOM-Drugs dataset:

<p align="left"> <img src="assets/geom_vis_3_3d.png" width="1500"/> </p> <p align="left"> <img src="assets/geom_vis_3_2d.png" width="1500"/> </p> <p align="left"> <img src="assets/geom_vis_2_3d.png" width="1500"/> </p> <p align="left"> <img src="assets/geom_vis_2_2d.png" width="1500"/> </p>

Visualization of molecules generated by JODO trained on the QM9 dataset with explicit hydrogen atoms:

<p align="left"> <img src="assets/qm9_vis_5_3d.png" width="1500"/> </p> <p align="left"> <img src="assets/qm9_vis_5_2d.png" width="1500"/> </p> <p align="left"> <img src="assets/qm9_vis_4_3d.png" width="1500"/> </p> <p align="left"> <img src="assets/qm9_vis_4_2d.png" width="1500"/> </p>

Dependencies

Dataset

We recommend using our processed dataset files provided here.

Download datasets:

# 718MB
wget https://zenodo.org/record/7966493/files/data.zip
unzip data.zip

If you want to construct the GEOM-Drugs dataset from scratch:

Generated Molecules

We provide pickle files of 10,000 molecules generated by JODO on the different datasets in ./rdkit_mols. Molecules are saved as RDKit Mol objects; simply load the list of molecules and perform further analysis.

# Example: load molecules generated by JODO trained on the GEOM-Drugs dataset.
import pickle

with open('rdkit_mols/geom_jodo_ancestral_ckpt_35.pkl', 'rb') as f:
    mol_list = pickle.load(f)

Evaluation

We construct a comprehensive evaluation pipeline for molecule generation, including 2D molecular graph metrics, 3D geometry metrics, and substructure geometry alignment metrics.

To evaluate your own models with this pipeline, save your generated molecules as a pickled list of RDKit Mol objects and run eval_rdkit_pkl.py.
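The expected input can be produced with a plain pickle dump. In the sketch below, `mols` stands in for your list of RDKit Mol objects (mocked here with plain strings so the snippet runs without RDKit), and `my_generated_mols.pkl` is a hypothetical output path you would pass via --pkl_path:

```python
import pickle

# In practice `mols` is a list of rdkit.Chem.Mol objects produced by your model;
# plain strings stand in here so the snippet is self-contained.
mols = ['mol_0', 'mol_1', 'mol_2']

# Dump the list; pass this file to eval_rdkit_pkl.py via --pkl_path.
with open('my_generated_mols.pkl', 'wb') as f:
    pickle.dump(mols, f)

# Load it back to verify the round trip.
with open('my_generated_mols.pkl', 'rb') as f:
    loaded = pickle.load(f)
assert loaded == mols
```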

Take QM9 as an example:

# Molecules with 3D positions and atom types, without bonds
python eval_rdkit_pkl.py --dataset_name qm9 --type 3D --root_path YOUR_DATASET_PATH --pkl_path YOUR_MOL_PATH

# Molecules with atom and bond types, without 3D positions
python eval_rdkit_pkl.py --dataset_name qm9 --type 2D --root_path YOUR_DATASET_PATH --pkl_path YOUR_MOL_PATH

# Molecules with atom types, bond types and 3D positions
python eval_rdkit_pkl.py --dataset_name qm9 --type both --sub_geometry=True --root_path YOUR_DATASET_PATH --pkl_path YOUR_MOL_PATH

Checkpoint

Our checkpoints are provided here.

Download checkpoints:

# Unconditional Generation: QM9, GEOM-Drugs (2.8GB)
wget https://zenodo.org/record/8002902/files/exp_uncond.zip
unzip exp_uncond.zip

# Conditional Generation: single quantum property on QM9 (3.1GB)
wget https://zenodo.org/record/8002902/files/exp_cond.zip 
unzip exp_cond.zip

# Conditional Generation: multi properties (1.6GB)
wget https://zenodo.org/record/8002902/files/exp_cond_multi.zip 
unzip exp_cond_multi.zip

# Molecular Graph Generation: ZINC250k, MOSES (3.9GB)
wget https://zenodo.org/record/8002902/files/exp_2d.zip 
unzip exp_2d.zip

Unconditional Generation

QM9 Training Example:

CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_qm9_jodo

QM9 Sampling Example:

# sample from our pretrained checkpoint
CUDA_VISIBLE_DEVICES=2 python main.py --config configs/vpsde_qm9_uncond_jodo.py --mode eval --workdir exp_uncond/vpsde_qm9_jodo --config.eval.ckpts '30' --config.eval.batch_size 2500 --config.sampling.steps 1000

GEOM-Drugs Training Example:

# Base
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_geom_jodo_base --config.model.n_layers 6 --config.model.nf 128

# Medium
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_geom_jodo_media

# Large
CUDA_VISIBLE_DEVICES=0,1 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_geom_jodo_large --config.model.nf 384 --config.training.n_iters 1500000

GEOM-Drugs Sampling Example:

# Base
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode eval --workdir exp_uncond/vpsde_geom_jodo_base --config.model.n_layers 6 --config.model.nf 128 --config.eval.ckpts '30' --config.eval.batch_size 800 --config.sampling.steps 1000

# Medium
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode eval --workdir exp_uncond/vpsde_geom_jodo_media --config.eval.ckpts '30' --config.eval.batch_size 1000 --config.sampling.steps 1000

# Large
CUDA_VISIBLE_DEVICES=0,1 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode eval --workdir exp_uncond/vpsde_geom_jodo_large --config.model.nf 384 --config.eval.ckpts '30' --config.eval.batch_size 500 --config.sampling.steps 1000

The simplified DGT without extra attention heads also achieves relatively good performance:

# QM9 Training
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_qm9_jodo_sim --config.model.name DGT_concat_sim

# GEOM-Drugs Medium Training
CUDA_VISIBLE_DEVICES=2,3 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_geom_jodo_media_sim --config.model.name DGT_concat_sim

Conditional Generation

# Training
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_gap --config.cond_property gap
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_homo --config.cond_property homo
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_lumo --config.cond_property lumo
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_mu --config.cond_property mu
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_Cv --config.cond_property Cv
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_alpha --config.cond_property alpha

# Sampling
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_gap --config.cond_property gap --config.eval.ckpts '40'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_homo --config.cond_property homo --config.eval.ckpts '40'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_lumo --config.cond_property lumo --config.eval.ckpts '40'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_mu --config.cond_property mu --config.eval.ckpts '40'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_Cv --config.cond_property Cv --config.eval.ckpts '40'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_alpha --config.cond_property alpha --config.eval.ckpts '40'
# Training
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode train --workdir exp_cond_multi/vpsde_qm9_cond_jodo_Cv_mu --config.cond_property1 Cv --config.cond_property2 mu
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode train --workdir exp_cond_multi/vpsde_qm9_cond_jodo_gap_mu --config.cond_property1 gap --config.cond_property2 mu
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode train --workdir exp_cond_multi/vpsde_qm9_cond_jodo_alpha_mu --config.cond_property1 alpha --config.cond_property2 mu

# Sampling
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode eval --workdir exp_cond_multi/vpsde_qm9_cond_jodo_Cv_mu --config.cond_property1 Cv --config.cond_property2 mu --config.eval.ckpts '50'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode eval --workdir exp_cond_multi/vpsde_qm9_cond_jodo_gap_mu --config.cond_property1 gap --config.cond_property2 mu --config.eval.ckpts '50'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode eval --workdir exp_cond_multi/vpsde_qm9_cond_jodo_alpha_mu --config.cond_property1 alpha --config.cond_property2 mu --config.eval.ckpts '50'

Molecular Graph Generation

ZINC250k:

# Training
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_zinc_2d_jodo.py --mode train --workdir exp_2d/vpsde_zinc_2d_jodo --config.model.nf 1024 --config.model.n_heads 64 --config.model.n_layers 6 --config.training.snapshot_freq 300000

# Sampling
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_zinc_2d_jodo.py --mode eval --workdir exp_2d/vpsde_zinc_2d_jodo --config.model.nf 1024 --config.model.n_heads 64 --config.model.n_layers 6 --config.training.snapshot_freq 300000 --config.eval.ckpts '5'

MOSES:

# Training
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_moses_2d_jodo.py --mode train --workdir exp_2d/vpsde_moses_2d_jodo --config.model.nf 1024 --config.model.n_heads 64 --config.model.n_layers 6 --config.training.snapshot_freq 300000

# Sampling
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_moses_2d_jodo.py --mode eval --workdir exp_2d/vpsde_moses_2d_jodo --config.model.nf 1024 --config.model.n_heads 64 --config.model.n_layers 6 --config.training.snapshot_freq 300000 --config.eval.ckpts '4'

Training CDGS on QM9 and GEOM-Drugs:

# QM9
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_2d_cdgs.py --mode train --workdir exp_2d/vpsde_qm9_2d_cdgs

# GEOM-Drugs
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_geom_2d_cdgs.py --mode train --workdir exp_2d/vpsde_geom_2d_cdgs

Citation

@article{huang2023learning,
  title={Learning Joint 2D \& 3D Diffusion Models for Complete Molecule Generation},
  author={Huang, Han and Sun, Leilei and Du, Bowen and Lv, Weifeng},
  journal={arXiv preprint arXiv:2305.12347},
  year={2023}
}

@article{huang2023conditional,
  title={Conditional Diffusion Based on Discrete Graph Structures for Molecular Graph Generation},
  author={Huang, Han and Sun, Leilei and Du, Bowen and Lv, Weifeng},
  journal={arXiv preprint arXiv:2301.00427},
  year={2023}
}