# GRED
This repo contains the official code for our ICML 2024 paper: [Recurrent Distance Filtering for Graph Representation Learning](https://arxiv.org/abs/2312.01538).
## Setup
The code is developed using JAX and Flax (since we want to use `associative_scan`; see the sketch below):

```
python==3.10.12
jaxlib==0.4.14+cuda11.cudnn86
flax==0.7.2
optax==0.1.7  # optimizer lib
numpy==1.25.2
scipy==1.11.2
scikit-learn==1.3.0
```
Please refer to the JAX and Flax installation pages.
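For context, `jax.lax.associative_scan` evaluates an associative recurrence in parallel instead of step by step, which is why the model is written in JAX. Below is a minimal sketch of that pattern for a generic diagonal linear recurrence; it illustrates the primitive only and is not GRED's actual model code:

```python
import jax
import jax.numpy as jnp

# Minimal sketch (not GRED's model code): evaluate the diagonal linear
# recurrence h[t] = a[t] * h[t-1] + b[t] in parallel. The operator below
# is associative, which is what associative_scan exploits.
def binary_op(e_i, e_j):
    a_i, b_i = e_i
    a_j, b_j = e_j
    return a_j * a_i, a_j * b_i + b_j

T, d = 16, 8
a = jax.random.uniform(jax.random.PRNGKey(0), (T, d))  # per-step decay
b = jax.random.normal(jax.random.PRNGKey(1), (T, d))   # per-step input

_, h = jax.lax.associative_scan(binary_op, (a, b))
# h[t] equals a[t] * h[t-1] + b[t] with h[-1] = 0, computed without a
# sequential loop over t.
```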
To keep the data preprocessing consistent with other baselines, we load the datasets using `torch-geometric==2.3.1` and convert them into NumPy arrays. You need `torch-geometric` to run `preprocess.py` and `preprocess_peptides.py`, but you don't need it to run the training scripts.
## Data Preprocessing
To prepare MNIST, CIFAR10, ZINC, PATTERN, and CLUSTER, please run:
```
python preprocess.py
```
To prepare Peptides-func and Peptides-struct, please run:
```
python preprocess_peptides.py
```
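As a rough illustration of what this preprocessing involves (loading a dataset with `torch-geometric`, converting it to NumPy, and computing the shortest-path hop distances that distance filtering relies on), here is a hypothetical sketch for ZINC; the actual scripts may differ:

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import shortest_path
from torch_geometric.datasets import ZINC

# Hypothetical sketch, not the repo's actual preprocessing code.
dataset = ZINC(root='data/ZINC', subset=True, split='train')
graph = dataset[0]

# Convert the torch-geometric graph to NumPy arrays.
x = graph.x.numpy()                    # node features
edge_index = graph.edge_index.numpy()  # shape (2, num_edges)

# Build a sparse adjacency matrix and compute all-pairs shortest-path
# (hop) distances.
n = graph.num_nodes
adj = csr_matrix(
    (np.ones(edge_index.shape[1]), (edge_index[0], edge_index[1])),
    shape=(n, n),
)
dist = shortest_path(adj, method='D', unweighted=True)  # (n, n) hop distances
```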
## Training
For MNIST and CIFAR10:
```
python train_pixel.py --name MNIST --num_layers 4 --num_hops 3 --dim_h 128 --dim_v 96
python train_pixel.py --name CIFAR10 --num_layers 8 --num_hops 5 --dim_h 96 --dim_v 64
```
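The `--num_hops` flag caps the shortest-path distance up to which node features are aggregated. The snippet below is a hedged sketch of the per-hop aggregation idea (a mean over each k-hop ring, producing one vector per hop for the recurrence to consume); the function name and shapes are assumptions, not the repo's code:

```python
import jax.numpy as jnp

# Hypothetical illustration of hop-wise aggregation (names and shapes are
# assumptions). dist: (N, N) shortest-path distances; x: (N, d) node features.
def hop_aggregates(dist, x, num_hops):
    hops = []
    for k in range(num_hops + 1):
        mask = (dist == k).astype(x.dtype)                      # k-hop ring
        denom = jnp.clip(mask.sum(axis=1, keepdims=True), 1.0)  # avoid /0
        hops.append((mask @ x) / denom)                         # mean over ring
    return jnp.stack(hops, axis=1)  # (N, num_hops + 1, d): one vector per hop
```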
For ZINC:
```
python train_zinc.py
```
For CLUSTER and PATTERN:
```
python train_sbm.py --name CLUSTER --num_layers 16 --dim_h 64 --dim_v 64 --weight_decay 0.2 --r_min 0.9
python train_sbm.py --name PATTERN --num_layers 10 --dim_h 72 --dim_v 64 --weight_decay 0.1 --r_min 0.5
```
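`--r_min` plausibly sets the smallest eigenvalue magnitude of the diagonal linear recurrence, in the spirit of the LRU initialization of Orvieto et al. (2023): magnitudes are sampled from [r_min, r_max], which controls how quickly distant hops are discounted. A sketch under that assumption (not necessarily the repo's exact parameterization):

```python
import jax
import jax.numpy as jnp

# Hedged guess at what --r_min controls: initialize eigenvalue magnitudes
# of the diagonal recurrence uniformly over the ring r_min <= |lambda| <= r_max.
def init_eigen_magnitudes(key, dim, r_min=0.9, r_max=1.0):
    u = jax.random.uniform(key, (dim,))
    return jnp.sqrt(u * (r_max**2 - r_min**2) + r_min**2)
```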
For Peptides-func and Peptides-struct:
```
python train_peptides_func.py
python train_peptides_struct.py
```
If you have any questions regarding the code, please feel free to raise an issue. If you find our paper helpful in your research, please consider citing it:
```
@inproceedings{ding2024gred,
    title={Recurrent Distance Filtering for Graph Representation Learning},
    author={Yuhui Ding and Antonio Orvieto and Bobby He and Thomas Hofmann},
    booktitle={Forty-first International Conference on Machine Learning},
    year={2024}
}
```