BooVAE: Boosting Approach for Continual Learning of VAEs

PyTorch implementation of the paper:

BooVAE: Boosting Approach for Continual Learning of VAEs
Anna Kuzina*, Evgenii Egorov*, Evgeny Burnaev

* - equal contribution

Abstract

The variational autoencoder (VAE) is a deep generative model for unsupervised learning that encodes observations into a meaningful latent space. A VAE is prone to catastrophic forgetting when tasks arrive sequentially and only the data for the current task is available. We address this problem of continual learning for VAEs. It is known that the choice of the prior distribution over the latent space is crucial for VAEs in the non-continual setting. We argue that it can also help to avoid catastrophic forgetting. We learn an approximation of the aggregated posterior as a prior for each task. This approximation is parametrised as an additive mixture of distributions induced by an encoder evaluated at trainable pseudo-inputs. We use a greedy boosting-like approach with entropy regularisation to learn the components. This method encourages component diversity, which is essential as we aim at memorising the current task with the fewest components possible. Based on the learnable prior, we introduce an end-to-end approach for continual learning of VAEs and provide empirical studies on commonly used benchmarks (MNIST, Fashion MNIST, NotMNIST) and the CelebA dataset. For each dataset, the proposed method avoids catastrophic forgetting in a fully automatic way.
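As a rough illustration of the prior described in the abstract, the aggregated posterior and its mixture approximation can be written as below. The symbols (N data points x_n, K pseudo-inputs u_k, mixture weights alpha_k, encoder q_phi) are notation assumed here for readability and may differ from the paper.

```latex
% Aggregated posterior over the N training points of a task
\hat{q}(z) = \frac{1}{N} \sum_{n=1}^{N} q_\phi(z \mid x_n)

% Mixture approximation used as the prior, built from K trainable pseudo-inputs u_k
p(z) = \sum_{k=1}^{K} \alpha_k \, q_\phi(z \mid u_k),
\qquad \alpha_k \ge 0, \quad \sum_{k=1}^{K} \alpha_k = 1
```

In this reading, the greedy boosting-like procedure adds components one at a time, so K grows only as far as needed to cover the current task.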

Experiments

Environment setup

The exact specification of our environment is provided in the file environment.yml. The environment can be created via

conda env create -f environment.yml

The command above will create an environment boovae with all the required dependencies.

Experiments for the paper

All hyperparameters can be found in the file config.py. To start training, use run_experiment.py. For example, to train a VAE with a standard normal prior:

python run_experiment.py --config.prior='standard' --config.dataset_name='mnist' --config.incremental=True --config.max_tasks=10
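When launching several runs, it can be convenient to build the command above from a dictionary of overrides. The following is a minimal sketch, not part of the repository: the helper name build_command is hypothetical, and only the flags shown in the example command are used. It assembles and prints the command without executing it.

```python
import shlex


def build_command(overrides, script="run_experiment.py"):
    """Build an argv list of the form: python <script> --config.<key>=<value> ...

    `build_command` is a hypothetical helper for illustration; it only formats
    the overrides and does not check them against config.py.
    """
    argv = ["python", script]
    for key, value in overrides.items():
        argv.append(f"--config.{key}={value}")
    return argv


# Overrides copied from the example command in this README.
overrides = {
    "prior": "standard",
    "dataset_name": "mnist",
    "incremental": True,
    "max_tasks": 10,
}

argv = build_command(overrides)
command = shlex.join(argv)  # shell-safe string form of the argv list
print(command)
```

The resulting argv list can be passed to subprocess.run(argv, check=True) from the repository root once the boovae environment is active.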

Citation

If you find our paper or code useful, feel free to cite it:

@article{kuzina2021boovae,
  title={BooVAE: Boosting Approach for Continual Learning of VAE},
  author={Kuzina, Anna and Egorov, Evgenii and Burnaev, Evgeny},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  year={2021}
}