Joint Learning of Label and Environment Causal Independence for Graph Out-of-Distribution Generalization
<a href="https://openreview.net/forum?id=z3HACY5CMa"> <img alt="License" src="https://img.shields.io/static/v1?label=Pub&message=NeurIPS%2723&color=blue"> </a>
This is the official implementation of "Joint Learning of Label and Environment Causal Independence for Graph Out-of-Distribution Generalization", which was accepted at NeurIPS 2023. :smile:
Overview
In this work, we propose to simultaneously incorporate label and environment causal independence (LECI) to fully exploit pre-collected environment information in graph tasks, thereby addressing the challenges prior methods face in identifying causal/invariant subgraphs. We further develop an adversarial training strategy to jointly optimize these two properties for causal subgraph discovery with theoretical guarantees.
Installation
Conda dependencies
LECI depends on PyTorch (>=1.6.0), PyG (>=2.0), and RDKit (>=2020.09.5). For more details, see the conda environment.
We currently test on PyTorch (==1.10.1), PyG (==2.0.3), and RDKit (==2020.09.5), so we strongly encourage installing these versions.
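The version requirements above can be checked before installing the project. The following is a minimal sketch, not part of LECI: the helper names are hypothetical, and the distribution names `torch`, `torch-geometric`, and `rdkit` are assumptions that may not match how the packages were installed (conda-installed packages in particular may not expose metadata under these names).

```python
import re
from importlib import metadata

# Minimum versions quoted in this README; the distribution names are assumptions.
MINIMUMS = {"torch": "1.6.0", "torch-geometric": "2.0", "rdkit": "2020.09.5"}


def parse_version(text):
    """Turn '1.10.1+cu113' into (1, 10, 1); any local/build suffix is ignored."""
    core = re.split(r"[+\-]", text, maxsplit=1)[0]
    return tuple(int(part) for part in re.findall(r"\d+", core))


def meets_minimum(installed, minimum):
    """Compare two dotted version strings component by component."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))  # pad so (2, 0) compares like (2, 0, 0)
    b += (0,) * (width - len(b))
    return a >= b


def check_environment(minimums=MINIMUMS):
    """Return {name: (installed_version_or_None, ok)} without importing the packages."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = (None, False)
            continue
        report[name] = (installed, meets_minimum(installed, minimum))
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_environment().items():
        print(f"{name}: installed={installed} ok={ok}")
```

A package reported as missing here may still be importable, so treat the output as a hint rather than a definitive check.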
Project installation
git clone https://github.com/divelab/LECI.git && cd LECI
pip install -e .
Run LECI
goodtg --config_path final_configs/GOODHIV/scaffold/covaraite/LECI.yaml --exp_round [1/2/3/4/5/6/7/8/9/10] --gpu_idx [0..9]
goodtg --config_path final_configs/GOODHIV/size/covaraite/LECI.yaml --exp_round [1/2/3/4/5/6/7/8/9/10] --gpu_idx [0..9]
goodtg --config_path final_configs/LBAPcore/assay/covaraite/LECI.yaml --exp_round [1/2/3/4/5/6/7/8/9/10] --gpu_idx [0..9]
goodtg --config_path final_configs/GOODMotif/basis/covaraite/LECI.yaml --exp_round [1/2/3/4/5/6/7/8/9/10] --gpu_idx [0..9]
goodtg --config_path final_configs/GOODMotif/size/covaraite/LECI.yaml --exp_round [1/2/3/4/5/6/7/8/9/10] --gpu_idx [0..9]
goodtg --config_path final_configs/GOODCMNIST/color/covaraite/LECI.yaml --exp_round [1/2/3/4/5/6/7/8/9/10] --gpu_idx [0..9]
goodtg --config_path final_configs/GOODSST2/length/covaraite/LECI.yaml --exp_round [1/2/3/4/5/6/7/8/9/10] --gpu_idx [0..9]
goodtg --config_path final_configs/GOODTwitter/length/covaraite/LECI.yaml --exp_round [1/2/3/4/5/6/7/8/9/10] --gpu_idx [0..9]
To run the code without installing the project, please replace goodtg with python -m GOOD.kernel.main.
Explanations of the arguments can be found in this file.
How to train LECI?
Valid LECI: The training of LECI is valid only when the optimal discriminator of Proposition 3.2 is approximately learned; e.g., when adversarial training is not applied (or is weak), the environment branch loss should at least be better than that of a random prediction. Note that the adversarial intensity increases from 0 to $\lambda_{EA}$ as training proceeds, which is controlled by self.config.train.alpha in the code.
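To make the validity condition concrete, here is a minimal sketch of two ideas from the paragraph above: an intensity that ramps from 0 to $\lambda_{EA}$, and a check that the environment branch beats random guessing. The linear ramp shape, the function names, and the `margin` parameter are assumptions for illustration; the actual schedule in the code is governed by self.config.train.alpha and may differ. The only fixed fact used is that the cross-entropy of a uniform prediction over $K$ classes is $\ln K$.

```python
import math


def adversarial_intensity(epoch, ramp_epochs, lambda_max):
    """Linear ramp from 0 at epoch 0 to lambda_max at ramp_epochs (assumed schedule)."""
    if ramp_epochs <= 0:
        return lambda_max
    progress = min(max(epoch / ramp_epochs, 0.0), 1.0)
    return lambda_max * progress


def random_guess_loss(num_classes):
    """Cross-entropy of a uniform prediction over num_classes classes: ln(K)."""
    return math.log(num_classes)


def branch_is_informative(loss, num_classes, margin=0.05):
    """True if the loss is below the uniform-guess baseline by a relative margin."""
    return loss < (1.0 - margin) * random_guess_loss(num_classes)


if __name__ == "__main__":
    for epoch in (0, 50, 100, 200):
        print(epoch, adversarial_intensity(epoch, ramp_epochs=100, lambda_max=0.5))
    # With 10 environments, a loss near ln(10) means the branch is guessing.
    print(branch_is_informative(loss=1.5, num_classes=10))
```

If the environment branch loss stays near `random_guess_loss(num_envs)` while the intensity is still close to 0, the discriminator has not been learned and later adversarial training is unlikely to be meaningful.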
How to select the right learning rate? Since the environment labels $E$ are noisier than normal classification labels, LECI starts with lower learning rates than general GNNs.
How to select the valid hyperparameters? If the EA/LA loss never decreases (invalid LECI), please try decreasing $\lambda_{EA}$ and $\lambda_{LA}$.
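The hyperparameter advice above can be turned into a simple monitoring rule. This is a hypothetical sketch, not part of the LECI code: `loss_never_decreased`, `shrink_lambdas`, the 1% threshold, and the halving factor are all illustrative assumptions, and the loss values are assumed to be positive (as with cross-entropy).

```python
def loss_never_decreased(history, min_rel_drop=0.01):
    """True if no later loss is at least min_rel_drop (relative) below the first one."""
    if len(history) < 2:
        return False  # not enough evidence yet
    first = history[0]
    best_later = min(history[1:])
    return best_later > first * (1.0 - min_rel_drop)


def shrink_lambdas(lambda_ea, lambda_la, factor=0.5):
    """Scale both adversarial weights down by the same factor for the next run."""
    return lambda_ea * factor, lambda_la * factor


if __name__ == "__main__":
    ea_losses = [0.69, 0.70, 0.69, 0.70]  # made-up values for illustration
    lambda_ea, lambda_la = 1.0, 0.5
    if loss_never_decreased(ea_losses):
        lambda_ea, lambda_la = shrink_lambdas(lambda_ea, lambda_la)
    print(lambda_ea, lambda_la)
```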
For more details, please refer to the appendix of our paper.
Citing LECI
If you find this repository helpful, please cite our paper/preprint.
@inproceedings{gui2023joint,
title={Joint Learning of Label and Environment Causal Independence for Graph Out-of-Distribution Generalization},
author={Gui, Shurui and Liu, Meng and Li, Xiner and Luo, Youzhi and Ji, Shuiwang},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
url={https://openreview.net/forum?id=z3HACY5CMa}
}
License
- The GOOD datasets are under the MIT license.
- The DrugOOD dataset is under the GPLv3 license.
- The LECI code is under the GPLv3 license, since the code architecture is based on GOOD.
Discussion
Please submit new issues or start a new discussion for any technical or other questions.
Contact
Please feel free to contact Shurui Gui or Shuiwang Ji!
Acknowledgements
This work was supported in part by National Science Foundation grants IIS-2006861 and IIS-1908220.