Amortized Inference for Causal Structure Learning


This is the code repository for the paper Amortized Inference for Causal Structure Learning (Lorch et al., NeurIPS 2022). Amortized variational inference for causal discovery (AVICI) infers causal structure from data based on a simulator of the domain of interest. By training a neural network to infer structure from simulated data, AVICI can acquire realistic inductive biases from prior knowledge that is hard to cast as score functions or conditional independence tests.

To install the latest stable release, run:

pip install avici

The package allows training new models from scratch on custom data-generating processes as well as making predictions with the pretrained models we provide. The codebase is written in Python and JAX.

Quick Start: Pretrained Model

Using the avici package is as easy as running the following code snippet:

import avici
from avici import simulate_data

# g: [d, d] causal graph of `d` variables
# x: [n, d] data matrix containing `n` observations of the `d` variables
g, x, _ = simulate_data(d=50, n=200, domain="rff-gauss")

# load pretrained model
model = avici.load_pretrained(download="scm-v0")

# g_prob: [d, d] predicted edge probabilities of the causal graph
g_prob = model(x=x)

You can run this snippet directly as a working example in the following Google Colab notebook:

Open In Colab

The above code automatically downloads and initializes a pretrained model checkpoint (~60 MB) for the specified domain and predicts the causal structure underlying the simulated data.
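Since the model outputs edge probabilities rather than a hard graph, a common next step is to threshold `g_prob` into a binary adjacency matrix and compare it against the ground truth. The helpers below are a minimal numpy sketch of this post-processing; the 0.5 threshold and the structural Hamming distance (SHD) metric are our illustration, not part of the avici API:

```python
import numpy as np

def threshold_graph(g_prob, tau=0.5):
    """Binarize predicted edge probabilities into an adjacency matrix."""
    return (np.asarray(g_prob) > tau).astype(int)

def shd(g_true, g_pred):
    """Structural Hamming distance: edge insertions, deletions, and flips.

    A reversed edge (i->j predicted as j->i) is counted once, not twice.
    """
    diff = np.abs(g_true - g_pred)
    return int(np.sum(diff) - np.sum(diff * diff.T) / 2)

# toy example with d = 3 variables instead of the g_prob from the snippet above
g_true = np.array([[0, 1, 0],
                   [0, 0, 1],
                   [0, 0, 0]])
g_prob = np.array([[0.1, 0.9, 0.2],
                   [0.7, 0.1, 0.8],
                   [0.0, 0.1, 0.0]])
g_pred = threshold_graph(g_prob)
```

The same two functions apply unchanged to the `[d, d]` probabilities returned by the pretrained model.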

We currently provide the following model checkpoints, which can be selected via the download argument:

We recommend the latest scm-v0 for working with arbitrary real-valued data. This model was trained on SCM data simulated from a large variety of graph models with up to 100 nodes, both linear and nonlinear causal mechanisms, and homogeneous and heterogeneous additive noise from Gaussian, Laplace, and Cauchy distributions.

The models neurips-linear, neurips-rff, and neurips-grn studied in our original paper were purposely trained on narrower training distributions to assess the out-of-distribution capability of AVICI. Unless your prior domain knowledge is strong, this may make the neurips-* models less suitable for benchmarking or as general-purpose, out-of-the-box tools in your application. The training distribution of scm-v0 essentially combines those of neurips-linear and neurips-rff as well as their out-of-distribution settings in Lorch et al. (2022).

For details on the exact training distributions of these models, please refer to the model cards on HuggingFace. Appendix A of Lorch et al. (2022) also defines the training distributions of the neurips-* models. The YAML domain config file for each model is available in avici/config/train/.

Calling the model obtained from avici.load_pretrained predicts the [d, d] matrix of probabilities for each possible edge in the causal graph and accepts the following arguments:
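For interventional data, the model call also accepts a binary intervention mask alongside x (an `interv` argument of shape [n, d] indicating which variable, if any, was intervened upon in each observation). Building such a mask is plain numpy; the particular interventions below are our own toy example:

```python
import numpy as np

n, d = 200, 50

# interv[i, j] = 1 iff variable j was intervened upon in observation i
interv = np.zeros((n, d), dtype=int)
interv[:100, 3] = 1   # first 100 samples: intervention on variable 3
interv[100:, 7] = 1   # remaining samples: intervention on variable 7

# with x of shape [n, d] and a model from avici.load_pretrained:
# g_prob = model(x=x, interv=interv)
```

Observations without any intervention simply have an all-zero row in the mask.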

When sampling synthetic data via avici.simulate_data, the following domain specifiers (dataset distributions) are currently provided: lin-gauss, lin-gauss-heterosked, lin-laplace-cauchy, rff-gauss, rff-gauss-heterosked, rff-laplace-cauchy, and gene-ecoli. Custom config files can be specified, too. All of these domains are defined inside avici.config.examples.

Quick Start: Custom Data-Generating Processes

In the example-custom folder, we provide an extended README and a corresponding implementation that walk through a detailed example of how to train an AVICI model for a custom data-generating process.

In short, the following three components are needed for training a full model:

  1. func.py: (Optional) Python file defining custom data-generating processes

    If you would like to train on data-generating processes not already provided by avici.synthetic, implement them in this file as subclasses of GraphModel and MechanismModel.

  2. domain.yaml: YAML file defining the training data distribution

    This configuration file specifies the full distribution over datasets used for training. Several graph models and data-generating mechanisms are available out-of-the-box, so providing additional modules via func.py is optional. This file can also be used to simulate data in avici.simulate_data.

  3. train.py: Python training script

    Fully-fledged training script for multi-device training (if available) based on the above configurations.
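To give a flavor of step 1, the following standalone sketch mimics what a custom graph model might look like. The class name, `__call__` signature, and rng handling here are our assumptions for illustration only; the actual base class GraphModel lives in avici.synthetic, and the example-custom folder documents the real interface:

```python
import numpy as np

class ScaleFreeIllustration:
    """Toy stand-in for a GraphModel subclass (hypothetical interface).

    Samples a DAG over `d` variables by attaching each new node to up to
    `m` earlier nodes, chosen preferentially by current out-degree.
    """

    def __init__(self, m=2):
        self.m = m

    def __call__(self, rng, d):
        g = np.zeros((d, d), dtype=int)
        for j in range(1, d):
            k = min(self.m, j)
            # attachment weights proportional to 1 + out-degree of earlier nodes
            w = 1.0 + g[:j].sum(axis=1)
            parents = rng.choice(j, size=k, replace=False, p=w / w.sum())
            g[parents, j] = 1  # edges point from earlier to later nodes -> acyclic
        return g

rng = np.random.default_rng(0)
g = ScaleFreeIllustration(m=2)(rng, d=10)
```

Because edges only ever point from lower-indexed to higher-indexed nodes, the sampled adjacency matrix is strictly upper triangular and therefore always a DAG.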

The checkpoints created by the training script can be loaded directly with the avici.load_pretrained function from above:

import avici
model = avici.load_pretrained(checkpoint_dir="path/to/checkpoint", expects_counts=False)

Custom Installation and Branches (Apple Silicon)

When using avici for your research and applications, we recommend using the main branch and installing the latest stable release from PyPI via pip, as explained above.

For custom installations, we recommend using conda and generating a new environment via

conda env create --file environment.yaml

You then need to install the avici package with

pip install -e .

Note to Apple Silicon/M1 chip users:

On Apple Silicon (M1) MacBooks, install the package by first setting up a conda environment from our environment.yaml config, then running pip install -r requirements.txt, and finally running pip install -e . as above. Installing avici directly via PyPI may install incompatible versions or builds of the package requirements, which can cause unexpected, low-level errors.

Reproducibility branch

In addition to main, this repository also contains the branch full, which provides comprehensive code for reproducing the experimental results in Lorch et al. (2022). The purpose of full is reproducibility; the branch is no longer updated and may contain outdated notation and documentation.

Reference

@article{lorch2022amortized,
  title={Amortized Inference for Causal Structure Learning},
  author={Lorch, Lars and Sussex, Scott and Rothfuss, Jonas and Krause, Andreas and Sch{\"o}lkopf, Bernhard},
  journal={Advances in Neural Information Processing Systems},
  volume={35},
  year={2022}
}