# Continual Learning via Sequential Function-Space Variational Inference (S-FSVI)
This repository contains the official implementation for

**Continual Learning via Sequential Function-Space Variational Inference**; Tim G. J. Rudner, Freddie Bickford Smith, Qixuan Feng, Yee Whye Teh, Yarin Gal. **ICML 2022**.
Abstract: Sequential Bayesian inference over predictive functions is a natural framework for continual learning from streams of data. However, applying it to neural networks has proved challenging in practice. Addressing the drawbacks of existing techniques, we propose an optimization objective derived by formulating continual learning as sequential function-space variational inference. In contrast to existing methods that regularize neural network parameters directly, this objective allows parameters to vary widely during training, enabling better adaptation to new tasks. Compared to objectives that directly regularize neural network predictions, the proposed objective allows for more flexible variational distributions and more effective regularization. We demonstrate that, across a range of task sequences, neural networks trained via sequential function-space variational inference achieve better predictive accuracy than networks trained with related methods while depending less on maintaining a set of representative points from previous tasks.
<p align="center"> — <a href="https://timrudner.com/sfsvi"><b>View Paper</b></a> — </p>

In particular, this codebase includes:
- An implementation of the sequential function-space variational objective [1];
- Notebooks that reproduce the results in the paper;
- A general, easy-to-extend continual learning training and evaluation protocol;
- A set of framework-agnostic dataloader methods for widely used continual learning tasks.
[1] The implementation is based on the approximation proposed in <a href="https://timrudner.com/fsvi">Tractable Function-Space Variational Inference in Bayesian Neural Networks</a> (Rudner et al., 2022).
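As a simplified sketch (notation abridged; see the paper for the exact objective and the linearized approximation used to make it tractable), the variational distribution over functions for task $t$ is obtained by maximizing an objective of the form

$$
\mathcal{F}_t(q_t) = \mathbb{E}_{q_t(f)}\big[\log p(\mathcal{D}_t \mid f)\big] - \mathbb{D}_{\mathrm{KL}}\big(q_t(f(\mathbf{X}_{\mathcal{C}})) \,\|\, q_{t-1}(f(\mathbf{X}_{\mathcal{C}}))\big),
$$

where the approximate posterior $q_{t-1}$ from the previous task acts as the prior over functions and the KL divergence is evaluated at a finite set of context points $\mathbf{X}_{\mathcal{C}}$, which may include coreset points from earlier tasks.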
<br> <br> <p align="center"> <img src="images/schematic.png" alt="Figure 1" width="80%"/><br> <b>Figure 1.</b> Schematic of sequential function-space variational inference. </p>

## Installation
To install requirements:
$ conda env update -f environment.yml
$ conda activate fsvi
This environment includes all necessary dependencies.
To create an `fsvi` executable for running experiments, run `pip install -e .`.
## Reproducing results
### Split MNIST, Permuted MNIST, and Split FashionMNIST
MH and SH denote the multi-head and single-head evaluation settings, respectively.

Method | Split MNIST (MH) | Split FashionMNIST (MH) | Permuted MNIST (SH) | Split MNIST (SH) |
---|---|---|---|---|
S-FSVI (ours) | 99.54% ± 0.04 | 99.05% ± 0.03 | 95.76% ± 0.02 | 92.87% ± 0.14 |
S-FSVI (larger networks) | 99.76% ± 0.00 | 98.50% ± 0.11 | 97.50% ± 0.01 | 93.38% ± 0.10 |
S-FSVI (no coreset) | 99.62% ± 0.01 | 99.17% ± 0.06 | 84.06% ± 0.46 | 20.15% ± 0.52 |
S-FSVI (minimal coreset [2]) | NA [3] | NA [3] | 89.59% ± 0.30 | 51.44% ± 1.22 |
[2] "Minimal coresets" are constructed by randomly selecting one data point per class for a given task.
[3] Since S-FSVI already performs well without a coreset, the minimal coreset option is not useful.
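For illustration only, selecting such a minimal coreset amounts to something like the following sketch (the helper below is hypothetical and not part of the codebase):

```python
import numpy as np

def minimal_coreset(x, y, seed=0):
    """Select one randomly chosen example per class (a 'minimal coreset')."""
    rng = np.random.default_rng(seed)
    indices = [rng.choice(np.flatnonzero(y == c)) for c in np.unique(y)]
    return x[indices], y[indices]
```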
### Split CIFAR
Method | Split CIFAR (MH) |
---|---|
S-FSVI [4] | 77.57% ± 0.84 |
### Sequential Omniglot
Method | Sequential Omniglot (MH) |
---|---|
S-FSVI [4] | 83.29% ± 1.2 |
[4] To speed up training and reduce the memory requirements, only the variance parameters in the final layer of the network are learned variationally and the linearization is computed on the final layer only.
### 2D Visualization
An accompanying notebook demonstrates continual learning via S-FSVI on a sequence of five binary-classification tasks in a 2D input space.
<p align="center"> <img src="images/toy2D.png" alt="Figure 2" width="90%"/><br> <b>Figure 2.</b> Predictive distributions of a model trained via S-FSVI on tasks 1-5. </p>

## Adding new methods or tasks
- To implement a new method, create a file `method_cl_methodname.py` in `/benchmarking`. For reference, see `/benchmarking/method_cl_template.py` and `/benchmarking/method_cl_fsvi.py`.
- To implement a new dataloader, add a new method to `benchmarking/data_loaders` (see the sketch below).
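For illustration, a framework-agnostic dataloader for a Split-MNIST-style task sequence might look roughly like the sketch below; the function name, signature, and return structure are hypothetical, so match the existing loaders in `benchmarking/data_loaders` for the interface actually expected by the training and evaluation protocol.

```python
import numpy as np

def make_split_tasks(x, y, label_pairs=((0, 1), (2, 3), (4, 5), (6, 7), (8, 9))):
    """Turn a labelled dataset into a Split-MNIST-style sequence of binary tasks.

    Each task keeps only the examples whose labels fall in one pair and
    remaps the labels to {0, 1}. Returns a list of (x_task, y_task) arrays.
    """
    tasks = []
    for a, b in label_pairs:
        mask = np.isin(y, (a, b))
        tasks.append((x[mask], (y[mask] == b).astype(np.int32)))
    return tasks
```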
## Citation
@InProceedings{rudner2022continual,
    author = {Tim G. J. Rudner and Freddie Bickford Smith and Qixuan Feng and Yee Whye Teh and Yarin Gal},
    title = {{C}ontinual {L}earning via {S}equential {F}unction-{S}pace {V}ariational {I}nference},
    booktitle = {Proceedings of the 39th International Conference on Machine Learning},
    year = {2022},
    series = {Proceedings of Machine Learning Research},
    publisher = {PMLR},
}
Please cite our paper if you use this code in your own work.