<h1 align="center"> Discovering Invariant Rationales for Graph Neural Networks 🔥 </h1> <div align="center"> </div>
Overview
DIR (ICLR 2022) aims to train intrinsically interpretable Graph Neural Networks that are robust and generalize to out-of-distribution datasets. The core of this work lies in the construction of interventional distributions, from which causal features are identified. See the quick lead-in below.
- Q: What are interventional distributions?
They are the distributions obtained when we intervene on one variable or a set of variables in the data generation process. For example, we could intervene on the base graph (highlighted in green or blue), which gives us multiple distributions:
<figure> <img src="figures/interventional-distributions.png" height="220"></figure>
- Q: How to construct the interventional distributions?
<figure> <img src="figures/framework.gif" height="350"></figure>
We design the model structure shown above to perform the intervention in the representation space, where the distribution intervener is in charge of sampling one subgraph from the non-causal pool and fixing it at one end of the rationale generator.
- Q: How can these interventional distributions help us approach the causal features for rationalization?
Here is the simple philosophy: no matter what values we assign to the non-causal part, the class label is invariant as long as we observe the causal part. Intuitively, interventional distributions offer us "multiple eyes" to discover the features that make the label invariant upon interventions. We propose the DIR objective to achieve this goal:
<figure> <img src="figures/dir-objective.png" height="50"></figure>
See our paper for the formal description and the principle behind it.
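
The invariance idea in the Q&A above can be illustrated with a toy sketch. It is not the training code in this repository: it assumes the objective combines the mean and the variance of the risk across interventions, uses plain floats as stand-ins for graph representations, and measures risk with squared error and the population variance. The names `dir_risk`, `predict`, and `lam` are hypothetical.

```python
from statistics import fmean, pvariance
from typing import Callable, Sequence


def dir_risk(
    causal_part: float,
    non_causal_pool: Sequence[float],
    label: float,
    predict: Callable[[float, float], float],
    lam: float = 1.0,
) -> float:
    """Mean risk over interventions plus lam times the variance of that risk.

    Assumption: this mean-plus-variance form is a simplified reading of the
    DIR objective; see the paper for the formal definition.
    """
    # One intervention per pooled non-causal sample: pair the sample with the
    # fixed causal part and score the joint prediction with squared error.
    risks = [(predict(causal_part, s) - label) ** 2 for s in non_causal_pool]
    return fmean(risks) + lam * pvariance(risks)


def invariant_predictor(causal: float, non_causal: float) -> float:
    # Ignores the non-causal part, so its risk is the same under every intervention.
    return causal


def leaky_predictor(causal: float, non_causal: float) -> float:
    # Leaks the non-causal part, so its risk changes with the intervention.
    return causal + 0.5 * non_causal


pool = [-1.0, 0.0, 1.0]  # hypothetical non-causal pool
print(dir_risk(1.0, pool, 1.0, invariant_predictor))
print(dir_risk(1.0, pool, 1.0, leaky_predictor))
```

In this toy setting the predictor that ignores the non-causal part incurs no variance penalty, while the leaky one is penalized, which is the behavior the "multiple eyes" intuition describes.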
Installation
Note that we require `1.7.0 <= torch_geometric <= 2.0.2`. Simply run the commands below to install the Python environment (you may want to change `cudatoolkit` to match your CUDA version), or see `requirements.txt` for the packages.
sh setup_env.sh
conda activate dir
<!-- - Main packages: PyTorch >= 1.5.0, 2.0.2 >= Pytorch Geometric >= 1.7.0, OGB >= 1.3.0.
- See `requirements.txt` for other packages. -->
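
If you want to confirm that an installed version falls inside the required range, a small check along these lines may help. It is a sketch under assumptions: `parse_version` and `is_supported` are hypothetical helpers (not part of this repository), and only plain numeric versions such as `2.0.2` or `2.0.2+cu113` are handled.

```python
def parse_version(version: str) -> tuple:
    """Turn '2.0.2' (or '2.0.2+cu113') into the tuple (2, 0, 2)."""
    core = version.split("+")[0]  # drop a local build suffix if present
    return tuple(int(part) for part in core.split(".")[:3])


def is_supported(version: str, low: str = "1.7.0", high: str = "2.0.2") -> bool:
    """Return True if low <= version <= high, compared component by component."""
    return parse_version(low) <= parse_version(version) <= parse_version(high)


print(is_supported("2.0.2"))
print(is_supported("2.0.3"))
```

With the environment activated, the installed version string could be passed in as `is_supported(torch_geometric.__version__)`.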
Data download
- Spurious-Motif: this dataset can be generated via `spmotif_gen/spmotif.ipynb`.
- Graph-SST2: this dataset can be downloaded here.
- MNIST-75sp: this dataset can be downloaded here. Download `mnist_75sp_train.pkl`, `mnist_75sp_test.pkl`, and `mnist_75sp_color_noise.pt` to the directory `data/MNISTSP/raw/`.
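
Before training on MNIST-75sp, it can be useful to confirm that the three files listed above are in place. The following is a minimal sketch; `missing_mnistsp_files` is a hypothetical helper, not something shipped with this repository, and it only checks that the files exist, not that their contents are valid.

```python
from pathlib import Path

# File names and target directory are taken from the download instructions above.
REQUIRED_FILES = (
    "mnist_75sp_train.pkl",
    "mnist_75sp_test.pkl",
    "mnist_75sp_color_noise.pt",
)


def missing_mnistsp_files(raw_dir="data/MNISTSP/raw"):
    """Return the required file names that are not present in raw_dir."""
    raw = Path(raw_dir)
    return [name for name in REQUIRED_FILES if not (raw / name).is_file()]


if __name__ == "__main__":
    missing = missing_mnistsp_files()
    print("missing:", missing if missing else "none")
```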
Run DIR
The hyper-parameters used to train the intrinsically interpretable models are set as defaults in the `argparse.ArgumentParser` in the training files. Feel free to change them if needed. We use a separate file to train on each dataset.
Simply run `python -m train.{dataset}_dir` to reproduce the results in the paper.
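
For readers unfamiliar with the pattern, here is a small sketch of how `argparse` defaults behave. The option names `--epoch` and `--lr` and their values are hypothetical placeholders; the actual options and defaults are the ones defined in each training file.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Toy training options")
    # Hypothetical options: the real names and defaults live in the training files.
    parser.add_argument("--epoch", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-3)
    return parser


# No flags given: every option falls back to its default.
defaults = build_parser().parse_args([])
# One flag given: only that option is overridden.
custom = build_parser().parse_args(["--lr", "0.01"])
print(defaults.epoch, defaults.lr)
print(custom.epoch, custom.lr)
```

Running a training file without extra flags therefore uses the defaults, and any flag the file defines can be overridden on the command line or by editing its `default=` value.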
Common Questions:
How does the Rationale Generator update its parameters?: https://github.com/Wuyxin/DIR-GNN/issues/7
Reference
@inproceedings{
wu2022dir,
title={Discovering Invariant Rationales for Graph Neural Networks},
author={Ying-Xin Wu and Xiang Wang and An Zhang and Xiangnan He and Tat-seng Chua},
booktitle={ICLR},
year={2022},
}