
AGDIFF: Attention-Enhanced Diffusion for Molecular Geometry Prediction

License: MIT


This repository contains the official implementation of the work "AGDIFF: Attention-Enhanced Diffusion for Molecular Geometry Prediction".

AGDIFF introduces a novel approach that enhances diffusion models with attention mechanisms and an improved SchNet architecture, achieving state-of-the-art performance in predicting molecular geometries.

Unique Features of AGDIFF

<p align="center"> <img src="assets/agdiff_framework.png" alt="photo not available" width="80%" height="100%"> </p>

https://github.com/user-attachments/assets/78feda75-3a20-422a-9b3f-f96fceea69cc

Contents

  1. Environment Setup
  2. Dataset
  3. Training
  4. Generation
  5. Evaluation
  6. Acknowledgment
  7. Citation

Environment Setup

Install dependencies via Conda/Mamba

conda env create -f agdiff.yml
conda activate agdiff
pip install torch_geometric
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
pip install torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
pip install torch-cluster -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
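The wheel index URL in the commands above is pinned to PyTorch 2.4.0 with CUDA 12.1. If your environment resolves to a different PyTorch or CUDA build, the `-f` URL has to match it. The sketch below derives that URL from version strings. It assumes the `https://data.pyg.org/whl/torch-<version>+<cuda tag>.html` naming shown above, and it assumes PyG publishes one index per minor PyTorch release (x.y.0). `pyg_wheel_url` is a hypothetical helper, not part of AGDIFF.

```python
from typing import Optional


def pyg_wheel_url(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build a PyG wheel index URL (hypothetical helper, not part of AGDIFF).

    torch_version: e.g. "2.4.0" or "2.4.1+cu121" (any local "+..." suffix is dropped).
    cuda_version: e.g. "12.1", or None for a CPU-only PyTorch build.
    """
    base = torch_version.split("+")[0]
    # Assumption: PyG hosts one wheel index per minor release, named x.y.0.
    major, minor = base.split(".")[:2]
    cuda_tag = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return f"https://data.pyg.org/whl/torch-{major}.{minor}.0+{cuda_tag}.html"


if __name__ == "__main__":
    try:
        import torch

        # torch.version.cuda is None on CPU-only builds.
        print(pyg_wheel_url(torch.__version__, torch.version.cuda))
    except ImportError:
        print("PyTorch is not installed in this environment.")
```

Pass the printed URL to the `-f` option of the torch-scatter, torch-sparse, and torch-cluster installs.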

Once all dependencies are installed, install the package locally in editable mode:

pip install -e .
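A quick way to confirm that the environment is complete is to check that each dependency installed above can be found on the import path. This is a minimal sketch using only the standard library. The list covers the packages from the commands above and deliberately omits AGDIFF's own package, whose import name is not stated here. `missing_packages` is a hypothetical helper.

```python
import importlib.util
from typing import List, Sequence

# Import names of the dependencies installed in the steps above.
REQUIRED = ["torch", "torch_geometric", "torch_scatter", "torch_sparse", "torch_cluster"]


def missing_packages(names: Sequence[str] = REQUIRED) -> List[str]:
    """Return the names in `names` that cannot be found (nothing is imported)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All dependencies found.")
```

`find_spec` only locates a module and does not execute it. A package can therefore be found and still fail at import time, for example when its CUDA build does not match PyTorch.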

Dataset

Official Dataset

The preprocessed GEOM datasets provided by GEODIFF are available in this [Google Drive folder]. After downloading and unzipping them, place the data in the folder specified by the dataset variable in the configuration files (./configs/*.yml). The same link also provides a pretrained model you may want to use.

The official raw GEOM dataset is also available [here].
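Before training, it can help to confirm that the paths in a config point at the unzipped files. The sketch below assumes that the `dataset` section of a config maps split names (for example `train`, `val`, `test`) to file paths, as in GEODIFF-style configs. Check this against your own ./configs/*.yml. `check_dataset_paths` and `load_config` are hypothetical helpers, and loading YAML requires PyYAML.

```python
from pathlib import Path
from typing import Dict


def check_dataset_paths(config: dict) -> Dict[str, bool]:
    """Map each string entry of config["dataset"] to whether that path exists.

    Assumption: config["dataset"] is a dict of split name -> file path.
    Non-string values (e.g. nested options) are skipped.
    """
    dataset = config.get("dataset", {})
    return {
        split: Path(path).exists()
        for split, path in dataset.items()
        if isinstance(path, str)
    }


def load_config(path: str) -> dict:
    import yaml  # PyYAML; imported here so check_dataset_paths works without it

    with open(path) as f:
        return yaml.safe_load(f)


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        for split, ok in check_dataset_paths(load_config(sys.argv[1])).items():
            print(f"{split}: {'found' if ok else 'MISSING'}")
    else:
        print("usage: python check_paths.py ./configs/qm9_default.yml")
```

Run it once per config, for example with ./configs/qm9_default.yml and ./configs/drugs_default.yml.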

Training

AGDIFF's training details and hyper-parameters are provided in the config files (./configs/*.yml). Feel free to tune these parameters as needed.

To train the model on the QM9 or Drugs dataset, run the corresponding command:

python scripts/train.py ./configs/qm9_default.yml
python scripts/train.py ./configs/drugs_default.yml

Model checkpoints, configuration YAML files, and training logs are saved to the directory specified by the --logdir argument of train.py.
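The generation command in the next section expects a checkpoint path of the form ./logs/path/to/checkpoints/${iter}.pt. A small helper can select the most recent checkpoint automatically. This sketch assumes that checkpoint files are named `<iteration>.pt` inside a checkpoints folder, matching that path pattern. `latest_checkpoint` is a hypothetical helper, not part of AGDIFF.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(ckpt_dir: str) -> Optional[Path]:
    """Return the `<iteration>.pt` file with the highest iteration number.

    Files whose stem is not an integer (e.g. "best.pt") are ignored.
    Returns None if the directory is missing or holds no matching files.
    """
    directory = Path(ckpt_dir)
    if not directory.is_dir():
        return None
    candidates = [p for p in directory.glob("*.pt") if p.stem.isdigit()]
    if not candidates:
        return None
    # Compare numerically so that 200.pt ranks above 35.pt.
    return max(candidates, key=lambda p: int(p.stem))


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        print(latest_checkpoint(sys.argv[1]))
    else:
        print("usage: python latest_ckpt.py <logdir>/checkpoints")
```

Numeric comparison matters because a plain string sort would rank 35.pt above 200.pt.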

Generation

To generate conformations for all or part of a test set, use:

python scripts/test.py ./logs/path/to/checkpoints/${iter}.pt ./configs/qm9_default.yml \
    --start_idx 0 --end_idx 200

Here, --start_idx and --end_idx select the range of test-set indices to sample. To reproduce the paper's results, use 0 and 200, respectively. All sampling-related hyper-parameters can be set in test.py. When testing the QM9 model, you can also add --w_global 0.3, which empirically gives slightly better results.
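The options above can be assembled programmatically, which helps when scripting runs for both datasets. This is a minimal sketch that only builds the argument list for scripts/test.py using the flags documented above. It assumes the config naming ./configs/<dataset>_default.yml from the training commands. `build_sampling_command` is a hypothetical helper, and the checkpoint path in the example is a placeholder.

```python
import shlex
from typing import List


def build_sampling_command(
    checkpoint: str, dataset: str, start_idx: int = 0, end_idx: int = 200
) -> List[str]:
    """Build the scripts/test.py argument list (hypothetical helper).

    dataset: "qm9" or "drugs", matching ./configs/<dataset>_default.yml.
    Defaults reproduce the paper's range (0 to 200). For QM9, --w_global 0.3
    is appended, as suggested above.
    """
    cmd = [
        "python", "scripts/test.py", checkpoint, f"./configs/{dataset}_default.yml",
        "--start_idx", str(start_idx), "--end_idx", str(end_idx),
    ]
    if dataset == "qm9":
        cmd += ["--w_global", "0.3"]
    return cmd


if __name__ == "__main__":
    # Placeholder checkpoint path; substitute your own ${iter}.pt file.
    print(shlex.join(build_sampling_command("./logs/path/to/checkpoints/ITER.pt", "qm9")))
```

Returning a list rather than a string lets you pass it directly to `subprocess.run` without shell quoting issues.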

We also provide an example of conformation generation for a specific molecule (alanine dipeptide) in the examples folder. To generate conformations for alanine dipeptide, use:

python examples/test_alanine_dipeptide.py ./logs/path/to/checkpoints/${iter}.pt ./configs/qm9_default.yml 

Evaluation

After generating conformations, evaluate the results of benchmark tasks using the following commands.

Task 1. Conformation Generation

Calculate COV and MAT scores on the GEOM datasets with:

python scripts/evaluation/eval_covmat.py path/to/samples/sample_all.pkl
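COV (coverage) and MAT (matching) compare generated conformers with reference conformers. The following sketch illustrates what these scores measure. It is not the code in eval_covmat.py. The sketch starts from a precomputed RMSD matrix for one molecule, so computing the aligned RMSDs is out of scope. The threshold delta is a parameter; the GEODIFF paper used 0.5 Å for QM9 and 1.25 Å for Drugs. Treating an RMSD exactly equal to delta as covered is an assumed boundary convention.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def cov_mat_recall(rmsd: Matrix, delta: float) -> Tuple[float, float]:
    """Recall-style COV and MAT for one molecule.

    rmsd[i][j] is the RMSD between reference conformer i and generated
    conformer j. Each reference conformer is matched to its closest generated
    one. COV is the fraction of reference conformers whose closest RMSD is
    <= delta (assumed boundary convention). MAT is the mean closest RMSD.
    """
    if not rmsd or not rmsd[0]:
        raise ValueError("need at least one reference and one generated conformer")
    closest = [min(row) for row in rmsd]
    cov = sum(d <= delta for d in closest) / len(closest)
    mat = sum(closest) / len(closest)
    return cov, mat


def cov_mat_precision(rmsd: Matrix, delta: float) -> Tuple[float, float]:
    """Precision variant: swap roles so each generated conformer is matched instead."""
    transposed = [list(col) for col in zip(*rmsd)]
    return cov_mat_recall(transposed, delta)


if __name__ == "__main__":
    # Toy matrix: 2 reference conformers (rows) x 3 generated conformers (columns).
    rmsd = [[0.3, 0.9, 1.0], [1.2, 0.6, 0.4]]
    print("recall:", cov_mat_recall(rmsd, delta=0.5))
    print("precision:", cov_mat_precision(rmsd, delta=0.5))
```

Higher COV and lower MAT are better. Per-molecule scores are then summarized across the test set; GEODIFF reports the mean and median.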

Acknowledgment

Our implementation is based on GEODIFF, PyTorch, PyG, and SchNet.

Citation

If you use our code or method in your work, please consider citing the following:

@misc{wyzykowskiAGDIFFAttentionEnhancedDiffusion2024,
  title = {{{AGDIFF}}: {{Attention-Enhanced Diffusion}} for {{Molecular Geometry Prediction}}},
  shorttitle = {{{AGDIFF}}},
  author = {Wyzykowski, Andr{\'e} Brasil Vieira and Fathi Niazi, Fatemeh and Dickson, Alex},
  year = {2024},
  month = oct,
  publisher = {ChemRxiv},
  doi = {10.26434/chemrxiv-2024-wrvr4},
  urldate = {2024-10-09},
  archiveprefix = {ChemRxiv},
  langid = {english},
  keywords = {attention,conformer,diffusion models,generative,GNN,graph neural network,machine learning,structure}
}

Please direct any questions to André Wyzykowski (abvwmc@gmail.com) and Alex Dickson (alexrd@msu.edu).