Awesome
Continual Repeated Annealed Flow Transport Monte Carlo (CRAFT)
and Annealed Flow Transport Monte Carlo (AFT)
The release contains implementations of
- Continual Repeated Annealed Flow Transport Monte Carlo (CRAFT), Matthews et al (2022).
- Annealed Flow Transport Monte Carlo (AFT), Arbel et al (2021).
- Stochastic Normalizing Flows (SNF), Wu et al (2020).
- Sequential Monte Carlo samplers (SMC), Del Moral et al (2006).
- Variational inference with Normalizing Flows (VI), Rezende and Mohamed (2015).
- Particle Markov Chain Monte Carlo (PIMH), Andrieu et al (2010).
The implementation of AFT is based on Algorithm 2 of that paper. See https://arxiv.org/abs/2102.07501 for more details.
The implementation of SNFs differs from the original one in that it exploits the connection with Annealed Importance Sampling with added normalizing flows. The training dynamics are still the same.
The implementation of Particle Markov Chain Monte Carlo is specialized to the case of a final target of interest rather than a time series and assumes an independent proposal, which can be based on SMC, VI or CRAFT.
This is research code. If you are interested in comparing to our work and have questions about how to use it please do reach out.
Installation
The code uses Python 3. We recommend using pip install -e .
which makes an
editable install. A reliable way to do this is within a
virtual environment.
virtualenv -p python3.10 ~/venv/annealed_flow_transport
source ~/venv/annealed_flow_transport/bin/activate
pip install -e .
A GPU is highly recommended. To use one you will need to install JAX with CUDA support. For example:
pip install --upgrade jax jaxlib==0.1.68+cuda111 -f
https://storage.googleapis.com/jax-releases/jax_releases.html
The CUDA version will need to match your GPU drivers. See the JAX documentation for more discussion.
To run the unit tests use the following command:
python -m pytest
Usage
The entry point to the code is main.py
taking a config file as an argument.
As an example from the base directory the following command runs a simple
one dimensional toy example:
python main.py --config=configs/single_normal.py
This example anneals between two one dimensional normal distributions with the same scale and two different locations using CRAFT. The script should print a sequence of steps and return a log normalizing constant estimate.
The config files use the ConfigDict
from
ml_collections to specify all
details of the desired experiment. For example: the algorithm, the MCMC kernel,
and the base distribution and target distribution. More examples can be found in
the configs
directory.
If you specify a model snapshot destination in the config file by setting
config.save_checkpoint = True
and specifying config.params_filename
then you can
store a file to evaluate using PIMH. This is then called using:
python evaluate.py --config=configs/single_normal.py
We have not released code for writing results to disk. The function
train.run_experiments
called from main.py
returns a NamedDict
containing a
summary of results that could be caught and recorded if required.
License information
The code is licensed under the Apache 2.0 license, which can be found in full in
the LICENSE
file.
We have released a pickle model parameters file for the VAE example which is licensed under the Creative Commons Attribution 4.0 International (CC BY 4.0), Full text is found at https://creativecommons.org/licenses/by/4.0/legalcode.
Giving Credit
If you use this code in your work, please cite the corresponding paper i.e Arbel et al. 2021 for AFT and Matthews et al. 2022 for CRAFT.
If you use our baselines such as VI, SMC, SNF, PIMH, please cite the paper of the two below where we first used the method: For SMC and VI this is Arbel et al. 2021, whereas for PIMH and SNF this is Matthews et al. 2022.
@InProceedings{CRAFT2022,
title={Continual Repeated Annealed Flow Transport Monte Carlo},
author={Alexander G. D. G. Matthews and Michael Arbel and Danilo J. Rezende and Arnaud Doucet},
booktitle = {Proceedings of the 39th International Conference on Machine Learning},
series = {Proceedings of Machine Learning Research},
year={2022},
month = {Jul}
}
@InProceedings{AnnealedFlowTransport2021,
title={Annealed Flow Transport Monte Carlo},
author={Michael Arbel and Alexander G. D. G. Matthews and Arnaud Doucet},
booktitle = {Proceedings of the 38th International Conference on Machine Learning},
series = {Proceedings of Machine Learning Research},
year={2021},
month = {18--24 Jul}
}
Disclaimer
This is not an official Google product.