<h1 align="center">CIGA: Causality Inspired Invariant Graph LeArning</h1> <p align="center"> <a href="https://arxiv.org/abs/2202.05441"><img src="https://img.shields.io/badge/arXiv-2202.05441-b31b1b.svg" alt="Paper"></a> <a href="https://github.com/LFhase/CIGA"><img src="https://img.shields.io/badge/-Github-grey?logo=github" alt="Github"></a> <!-- <a href="https://colab.research.google.com/drive/1t0_4BxEJ0XncyYvn_VyEQhxwNMvtSUNx?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab"></a> --> <a href="https://openreview.net/forum?id=A6AFK_JwrIW"> <img alt="License" src="https://img.shields.io/static/v1?label=Pub&message=NeurIPS%2722&color=blue"> </a> <a href="https://github.com/LFhase/CIGA/blob/main/LICENSE"> <img alt="License" src="https://img.shields.io/github/license/LFhase/CIGA?color=blue"> </a> <a href="https://neurips.cc/virtual/2022/poster/54643"> <img src="https://img.shields.io/badge/Video-grey?logo=Kuaishou&logoColor=white" alt="Video"></a> <a href="https://lfhase.win/files/slides/CIGA.pdf"> <img src="https://img.shields.io/badge/Slides-grey?&logo=MicrosoftPowerPoint&logoColor=white" alt="Slides"></a> <!-- <a href="https://icml.cc/media/PosterPDFs/ICML%202022/a8acc28734d4fe90ea24353d901ae678.png"> <img src="https://img.shields.io/badge/Poster-grey?logo=airplayvideo&logoColor=white" alt="Poster"></a> --> </p>

This repo contains the sample code for reproducing the results of our NeurIPS 2022 paper: Learning Causally Invariant Representations for Out-of-Distribution Generalization on Graphs, which was also presented at the ICML 2022 SCIS Workshop under the title Invariance Principle Meets Out-of-Distribution Generalization on Graphs. 😆😆😆

TODO items:

Introduction

Despite recent success in using the invariance principle for out-of-distribution (OOD) generalization on Euclidean data (e.g., images), studies on graph data are still limited. Different from images, the complex nature of graphs poses unique challenges to adopting the invariance principle:

  1. Distribution shifts on graphs can appear in <ins>a variety of forms</ins>:

    • Node attributes;
    • Graph structure;
    • A mixture of both;
  2. Each distribution shift can spuriously correlate with the label in <ins>different modes</ins>. We divide the modes into FIIF and PIIF, according to whether the latent causal feature $C$ fully determines the label $Y$, i.e., whether $(S,E)\perp\mkern-9.5mu\perp Y|C$:

    • Fully Informative Invariant Features (FIIF): $Y\leftarrow C\rightarrow S\leftarrow E$;
    • Partially Informative Invariant Features (PIIF): $C\rightarrow Y\leftarrow S \leftarrow E$;
    • Mixed Informative Invariant Features (MIIF): mixed with both FIIF and PIIF;
  3. <ins>Domain or environment partitions</ins>, which are often required by OOD methods on Euclidean data, can be highly expensive to obtain for graphs.

<p align="center"><img src="./data/arch.png" width=50% height=50%></p> <p align="center"><em>Figure 1.</em> The architecture of CIGA.</p>

This work addresses the above challenges by generalizing the causal invariance principle to graphs and instantiating it as CIGA. As shown in Figure 1, CIGA is powered by an information-theoretic objective that extracts the subgraphs that maximally preserve the invariant intra-class information. Under certain assumptions, CIGA provably identifies the underlying invariant subgraphs (shown as the orange subgraphs). Learning with these subgraphs is immune to distribution shifts.

We implement CIGA using an interpretable GNN architecture, where the featurizer $g$ is designed to extract the invariant subgraph and a classifier $f_c$ is designed to classify the extracted subgraph. The objective is imposed as an additional contrastive penalty that enforces the invariance of the extracted subgraphs in a latent sphere space (CIGAv1); a minimal sketch of this two-branch design is given after the list below.

  1. When the size of underlying invariant subgraph $G_c$ is known and fixed across different graphs and environments, CIGAv1 is able to identify $G_c$.
  2. While it is often the case that the underlying $G_c$ varies, we further incorporate an additional penalty that maximizes $I(G_s;Y)$ to absorb potential spurious parts in the estimated $G_c$ (CIGAv2).
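
For concreteness, here is a minimal, hypothetical sketch of the two-branch design described above. The module and method names are ours, not the repo's, and the actual implementation lives in main.py: the featurizer splits each graph, one head classifies the estimated invariant subgraph, and an auxiliary head on the remainder supplies the extra penalty used by CIGAv2.

import torch.nn as nn

# A rough, hypothetical sketch of the two-branch interpretable GNN used by CIGA;
# names are illustrative only, see main.py for the actual implementation.
class CIGASketch(nn.Module):
    def __init__(self, featurizer, causal_classifier, spurious_classifier):
        super().__init__()
        self.featurizer = featurizer                     # g: scores edges and extracts the subgraph
        self.causal_classifier = causal_classifier       # f_c: predicts Y from the invariant subgraph
        self.spurious_classifier = spurious_classifier   # auxiliary head on G_s (used by CIGAv2)

    def forward(self, graph):
        # g splits the input graph into an estimated invariant subgraph G_c
        # and the remaining, potentially spurious, part G_s
        g_c, g_s, causal_rep = self.featurizer(graph)
        causal_pred = self.causal_classifier(g_c)        # drives the main classification loss
        spu_pred = self.spurious_classifier(g_s)         # drives the hinge penalty (CIGAv2)
        return causal_pred, spu_pred, causal_rep         # causal_rep feeds the contrastive loss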

Extensive experiments on $16$ synthetic and real-world datasets, including DrugOOD, a challenging setting from AI-aided drug discovery, validate the superior OOD generalization ability of CIGA.

Use CIGA in Your Code

CIGA consists of two key regularization terms: a contrastive loss that maximizes $I(\widehat{G}_c;\widetilde{G}_c|Y)$, and a hinge loss that maximizes $I(\widehat{G}_s;Y)$.

The contrastive loss is implemented via a simple call (line 480 in main.py):

get_contrast_loss(causal_rep, label)

which requires two key inputs:

  • causal_rep: the latent representations of the extracted (estimated invariant) subgraphs;
  • label: the class labels, which determine the positive (same-class) and negative (different-class) pairs.
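
For intuition, a supervised-contrastive-style loss over these two inputs could look like the sketch below. This is a rough illustration with an assumed temperature and masking scheme; the actual get_contrast_loss in main.py may differ in its details.

import torch
import torch.nn.functional as F

def contrast_loss_sketch(causal_rep, labels, temperature=1.0):
    # hypothetical sketch: same-class subgraph representations form positive pairs,
    # different-class ones form negatives
    rep = F.normalize(causal_rep, dim=-1)              # map representations onto the unit sphere
    sim = rep @ rep.t() / temperature                  # pairwise cosine similarities
    n = rep.size(0)
    self_mask = torch.eye(n, dtype=torch.bool, device=rep.device)
    pos_mask = (labels.view(-1, 1) == labels.view(1, -1)) & ~self_mask
    # log-probability of each pair against all non-self pairs
    log_prob = sim - torch.logsumexp(sim.masked_fill(self_mask, float('-inf')), dim=1, keepdim=True)
    # average over the positive (same-class) pairs of each anchor, then over the batch
    pos_counts = pos_mask.sum(1).clamp(min=1)
    return -((log_prob * pos_mask.float()).sum(1) / pos_counts).mean()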

The hinge loss is implemented in lines 430 to 445 of main.py:

# a simple implementation of the hinge loss:
# a sample contributes only if the spurious head's loss exceeds
# the invariant (causal) head's loss on that sample
spu_loss_weight = torch.zeros(spu_pred_loss.size()).to(device)
spu_loss_weight[spu_pred_loss > pred_loss] = 1.0
# average the spurious loss over the selected samples (1e-6 avoids division by zero)
spu_pred_loss = spu_pred_loss.dot(spu_loss_weight) / (sum(spu_pred_loss > pred_loss) + 1e-6)

which requires two key inputs:

  • spu_pred_loss: the sample-wise (unreduced) classification loss computed on the estimated spurious part $\widehat{G}_s$;
  • pred_loss: the sample-wise (unreduced) classification loss computed on the estimated invariant subgraph $\widehat{G}_c$.

Then we can calculate the weights spu_loss_weight in the hinge loss for each sample based on the sample-wise loss values, and apply the weights to spu_pred_loss.
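
As a usage sketch, with randomly generated logits standing in for the actual model outputs, the two inputs can be obtained by keeping the per-sample losses unreduced and then applying the weighting above:

import torch
import torch.nn.functional as F

# hypothetical stand-ins: in main.py these logits come from the causal and spurious heads
causal_logits = torch.randn(8, 3)                  # predictions from the invariant subgraph G_c
spu_logits = torch.randn(8, 3)                     # predictions from the spurious part G_s
labels = torch.randint(0, 3, (8,))

# reduction='none' keeps one loss value per graph, as required by the hinge weighting
pred_loss = F.cross_entropy(causal_logits, labels, reduction='none')
spu_pred_loss = F.cross_entropy(spu_logits, labels, reduction='none')

# apply the weighting from the snippet above: only samples where the spurious head's
# loss exceeds the invariant head's loss contribute to the final hinge term
spu_loss_weight = torch.zeros_like(spu_pred_loss)
spu_loss_weight[spu_pred_loss > pred_loss] = 1.0
hinge_loss = spu_pred_loss.dot(spu_loss_weight) / (spu_loss_weight.sum() + 1e-6)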

Instructions

Installation and data preparation

Our code is based on the following libraries:

torch==1.9.0
torch-geometric==1.7.2
scikit-image==0.19.1 

plus the DrugOOD benchmark repo.

The data used in the paper can be obtained following these instructions.

Reproduce results

We provide the hyperparameter tuning and evaluation details in the paper and its appendix. Below we give a brief introduction to the commands and their usage in our code. The corresponding running scripts can be found in the script folder.

To obtain the ERM results, simply run

python main.py --erm

with corresponding datasets and model specifications.

Running with CIGA:

Running with the baselines:

Since CNC additionally depends on an ERM reference model, we need to first train and save an ERM model, and then load it to generate the ERM predictions used for positive/negative pair sampling in CNC. Here is a simple example:

python main.py --erm --contrast 0 --save_model
python main.py --erm --contrast 1  -c_sam 'cnc'

Misc

As discussed in the paper, the current code is merely a prototypical implementation based on an interpretable GNN architecture (i.e., GAE); in fact, there could be more implementation choices:

You can also find more discussions on the limitations and future work in Appendix B of our paper.

That being said, CIGA is definitely not the ultimate solution, and it intrinsically has many limitations. Nevertheless, we hope the causal analysis and the solution it inspires could serve as an initial step towards more reliable graph learning algorithms that are able to generalize to various OOD graphs in the real world.

If you find our paper and repo useful, please cite our paper:

@InProceedings{chen2022ciga,
  title       = {Learning Causally Invariant Representations for Out-of-Distribution Generalization on Graphs},
  author      = {Yongqiang Chen and Yonggang Zhang and Yatao Bian and Han Yang and Kaili Ma and Binghui Xie and Tongliang Liu and Bo Han and James Cheng},
  booktitle   = {Advances in Neural Information Processing Systems},
  year        = {2022}
}

Ack: The readme is inspired by GSAT. 😄