
Learning Invariances with Laplace Approximations (LILA)

A convenient gradient-based method for selecting data augmentation during the training of a deep neural network, without requiring validation data. Code accompanying the paper:

"Invariance Learning in Deep Neural Networks with Differentiable Laplace Approximations"
Alexander Immer*, Tycho F.A. van der Ouderaa*, Vincent Fortuin, Gunnar Rätsch, Mark van der Wilk.
In NeurIPS 2022.

*: equal contribution
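For orientation, here is a sketch of the idea in generic notation. The symbols are ours, not taken from the code, and the paper gives the exact formulation. The invariant model averages a network f_θ over S transformations T_{η,s} sampled from an augmentation distribution with parameters η (we read `--method avgfunc` this way). The augmentation parameters are then chosen by maximising a Laplace approximation of the log marginal likelihood, which only needs the training data D:

```latex
\begin{aligned}
f_{\theta,\eta}(x) &= \frac{1}{S}\sum_{s=1}^{S} f_\theta\big(T_{\eta,s}(x)\big), \\
\log p(\mathcal{D}\mid\eta) &\approx \log p(\mathcal{D}\mid\theta_*,\eta) + \log p(\theta_*)
  + \frac{P}{2}\log 2\pi - \frac{1}{2}\log\det \mathbf{H}, \\
\mathbf{H} &= -\nabla_\theta^2\big[\log p(\mathcal{D}\mid\theta,\eta) + \log p(\theta)\big]\big|_{\theta=\theta_*},
\end{aligned}
```

where θ* are the MAP parameters and P is the number of parameters. In practice H is replaced by a generalised Gauss-Newton (GGN) approximation, which is what options such as `--curvature_type ggn` and `--approx ggn_kron` in the commands below appear to select.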

LILA on Transformed MNIST Datasets

LILA Illustration

Figure: https://github.com/tychovdo/lila/blob/main/figs/paper_figure1.png

Setup

Python 3.8 is required.

pip install -r requirements.txt

Create a directory for results in the root of the project: mkdir results

Install the custom branches of Laplace and ASDL (asdfghjkl):

pip install git+https://github.com/kazukiosawa/asdfghjkl.git@dev-alex
pip install git+https://github.com/AlexImmer/Laplace.git@lila
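A minimal sketch for checking the setup before launching a long run. It assumes the two custom branches install under their usual import names, `laplace` and `asdfghjkl` (an assumption; adjust the tuple if your install differs), and it only looks up the modules without importing them.

```python
import importlib.util
import sys

# Assumed import names of the two custom packages installed above.
REQUIRED_MODULES = ("laplace", "asdfghjkl")


def check_environment(required=REQUIRED_MODULES):
    """Report whether Python is 3.8 and which required modules are installed."""
    report = {
        # The setup instructions state that Python 3.8 is required.
        "python_3_8": sys.version_info[:2] == (3, 8),
    }
    for name in required:
        # find_spec returns None when a top-level module cannot be found.
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for key, ok in check_environment().items():
        print(f"{key:12s} {'OK' if ok else 'MISSING'}")
```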

Example runs

Run the Illustrative Example and Plot the Predictive

python classification_illustration.py --method avgfunc --approximation_structure kron --curvature_type ggn --n_epochs 500 --n_obs 200 --rotation_max 120 --sigma_noise 0.06 --n_samples_aug 100 --rotation_init 0 --optimize_aug --plot --posterior_predictive --lr_aug 0.005 --lr_aug_min 0.00001
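To give an intuition for what averaging over augmentations does in this 2D rotation example, here is a pure-Python sketch. It is not the repository's implementation: we assume `--method avgfunc` averages the model's outputs over `--n_samples_aug` sampled rotations, and we treat the rotation range `eta` as the learnable augmentation parameter (initialised in the command above by `--rotation_init`). The scoring function `logit` is hypothetical. Writing the angle as `eta * u` with `u ~ U(-1, 1)` is the reparameterisation that makes the average differentiable in `eta` once it is written in an autodiff framework.

```python
import math
import random


def rotate(point, angle):
    """Rotate a 2D point counter-clockwise by `angle` radians about the origin."""
    x, y = point
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y, s * x + c * y)


def avgfunc(f, point, eta, n_samples, rng):
    """Monte Carlo estimate of E_{u ~ U(-1, 1)}[f(rotate(point, eta * u))].

    The angle is expressed as eta * u (reparameterisation), so in an
    autodiff framework the estimate could be differentiated w.r.t. eta.
    """
    total = 0.0
    for _ in range(n_samples):
        u = rng.uniform(-1.0, 1.0)
        total += f(rotate(point, eta * u))
    return total / n_samples


def logit(point):
    """Hypothetical non-invariant score: depends only on the x-coordinate."""
    return point[0]


if __name__ == "__main__":
    rng = random.Random(0)
    for eta in (0.0, math.pi / 2, math.pi):
        # As eta grows to pi, points at the same radius get similar averaged
        # scores, i.e. the averaged model becomes (approximately) rotation invariant.
        scores = [avgfunc(logit, p, eta, 5000, rng) for p in [(1.0, 0.0), (0.0, 1.0)]]
        print(f"eta={eta:.2f} scores={scores}")
```

With `eta = 0` the average reduces to the original, non-invariant model; with `eta = pi` every rotation is covered and the averaged score no longer depends on the angle of the input.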

Example of ResNet on CIFAR-10

To run LILA:

python classification_image.py --dataset cifar10 --model resnet_8_8 --approx ggn_kron --n_epochs 200 --batch_size 250 --marglik_batch_size 125 --partial_batch_size 50 --lr 0.1 --n_epochs_burnin 10 --n_hypersteps 100 --n_hypersteps_prior 4 --lr_aug 0.05 --lr_aug_min 0.005 --use_jvp --method avgfunc --n_samples_aug 20 --optimize_aug --download
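The `--approx ggn_kron` option selects a Kronecker-factored GGN approximation. One reason this scales is that, if a layer's curvature is approximated as A ⊗ B and an isotropic prior precision delta is added, the log-determinant needed by the Laplace marginal likelihood follows from the eigenvalues of the small factors alone: the eigenvalues of A ⊗ B + delta I are lam_i * mu_j + delta. The pure-Python sketch below checks this identity for 2×2 symmetric factors against a brute-force determinant of the full matrix. The factor values are made up for illustration, and this is not the Laplace library's code.

```python
import math


def eig_sym2(m):
    """Eigenvalues of a symmetric 2x2 matrix [[a, b], [b, d]] (closed form)."""
    (a, b), (_, d) = m
    mean = (a + d) / 2.0
    radius = math.sqrt(((a - d) / 2.0) ** 2 + b * b)
    return (mean - radius, mean + radius)


def kron(a, b):
    """Kronecker product of two square matrices given as lists of lists."""
    n, m = len(a), len(b)
    return [[a[i // m][j // m] * b[i % m][j % m] for j in range(n * m)]
            for i in range(n * m)]


def logdet_brute(mat):
    """log|det(mat)| via Gaussian elimination with partial pivoting."""
    rows = [row[:] for row in mat]
    n = len(rows)
    total = 0.0
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(rows[r][col]))
        rows[col], rows[pivot] = rows[pivot], rows[col]
        total += math.log(abs(rows[col][col]))
        for r in range(col + 1, n):
            factor = rows[r][col] / rows[col][col]
            for c in range(col, n):
                rows[r][c] -= factor * rows[col][c]
    return total


def logdet_kron_plus_prior(a, b, delta):
    """log det(A ⊗ B + delta * I) using only the eigenvalues of A and B."""
    return sum(math.log(lam * mu + delta)
               for lam in eig_sym2(a) for mu in eig_sym2(b))


if __name__ == "__main__":
    A = [[2.0, 0.5], [0.5, 1.0]]    # made-up input-side factor
    B = [[1.5, -0.3], [-0.3, 0.8]]  # made-up output-side factor
    delta = 0.1                     # made-up prior precision
    full = kron(A, B)
    for i in range(len(full)):
        full[i][i] += delta
    print(logdet_kron_plus_prior(A, B, delta), logdet_brute(full))
```

For an m×m factor and an n×n factor, the factored route costs two small eigendecompositions instead of one decomposition of an mn×mn matrix, which is what makes Kronecker-factored Laplace approximations practical for networks such as the ResNet above.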

Example of MLP on translated MNIST

To run LILA:

python classification_image.py --method avgfunc --dataset translated_mnist --n_epochs 1000 --device cuda --n_samples_aug 31 --save --optimize_aug --approx ggn_kron --use_jvp --batch_size 1000 --download

To run Augerino:

python classification_image.py --method augerino --dataset translated_mnist --n_epochs 1000 --device cuda --n_samples_aug 31 --save --optimize_aug --approx ggn_kron --batch_size 1000 --seed 1 --download

To run non-invariant baseline:

python classification_image.py --method avgfunc --dataset translated_mnist --n_epochs 1000 --device cuda --n_samples_aug 1 --save --approx ggn_kron --batch_size 1000 --seed 1 --download
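The three translated-MNIST runs share most of their flags. The hypothetical helper below (not part of the repository) rebuilds them from a common base so the differences are explicit: LILA and Augerino differ in `--method` (and LILA adds `--use_jvp`), while the baseline uses a single augmentation sample and omits `--optimize_aug`. Only flags that appear in the commands above are used; the variant names are our own labels, and the flag order differs from the commands above, which should not matter assuming the script parses its options with argparse.

```python
import shlex

# Flags shared by the three translated-MNIST runs, copied from the commands above.
BASE = ["python", "classification_image.py", "--dataset", "translated_mnist",
        "--n_epochs", "1000", "--device", "cuda", "--save",
        "--approx", "ggn_kron", "--batch_size", "1000", "--download"]


def mnist_command(variant, seed=None):
    """Assemble the argument list for one of the runs shown above.

    `variant` is "lila", "augerino" or "baseline"; these names are labels
    chosen for illustration, not options of classification_image.py.
    """
    if variant == "lila":
        extra = ["--method", "avgfunc", "--n_samples_aug", "31",
                 "--optimize_aug", "--use_jvp"]
    elif variant == "augerino":
        extra = ["--method", "augerino", "--n_samples_aug", "31", "--optimize_aug"]
    elif variant == "baseline":
        # Non-invariant baseline: one augmentation sample, augmentation not optimised.
        extra = ["--method", "avgfunc", "--n_samples_aug", "1"]
    else:
        raise ValueError(f"unknown variant: {variant!r}")
    args = BASE + extra
    if seed is not None:
        args += ["--seed", str(seed)]
    return args


if __name__ == "__main__":
    for name, seed in (("lila", None), ("augerino", 1), ("baseline", 1)):
        # shlex.join (Python 3.8+) quotes the list as a copy-pasteable shell line.
        print(shlex.join(mnist_command(name, seed)))
```

The returned lists can also be passed directly to `subprocess.run` to launch the runs from Python.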

Reproducibility

All experiments in the paper can be replicated using the run scripts in experimentscripts/.

If you do run into issues, please get in touch so we can provide support.