Bayes-Newton

Bayes-Newton is a library for approximate inference in Gaussian processes (GPs) in JAX (with objax), built and maintained by Will Wilkinson.

Bayes-Newton provides a unifying view of approximate Bayesian inference, and allows many models (e.g. GPs, sparse GPs, Markov GPs, sparse Markov GPs) to be combined with the inference method of your choice (VI, EP, Laplace, Linearisation). For a full list of the implemented methods, scroll down to the bottom of this page.
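For example, the inference method is determined by the model class, so the same kernel, likelihood and data (x, y) can be trained with variational inference or expectation propagation simply by swapping the class. The following is a sketch; the class names follow the library's naming scheme, see the model class list below for the exact set available:

kern = bayesnewton.kernels.Matern52()
lik = bayesnewton.likelihoods.Gaussian()
# Markov (state-space) GP trained with variational inference
model_vi = bayesnewton.models.MarkovVariationalGP(kernel=kern, likelihood=lik, X=x, Y=y)
# the same Markov GP trained with expectation propagation instead
model_ep = bayesnewton.models.MarkovExpectationPropagationGP(kernel=kern, likelihood=lik, X=x, Y=y)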

The methodology is outlined in the following article:

Wilkinson, W. J., Särkkä, S., and Solin, A. (2023). Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees. Journal of Machine Learning Research, 24(83):1-50.

Installation

Latest (stable) release from PyPI

pip install bayesnewton

For development, you might want to use the latest source from GitHub: in a checkout of the develop branch of the BayesNewton GitHub repository, run

pip install -e .

Step-by-step: Getting started with the examples

To run the demos or experiments in this repository, or to build on top of it, you can follow these steps to create and activate a virtual environment:

python3 -m venv venv
source venv/bin/activate

Install all required dependencies for the examples:

python -m pip install -r requirements.txt
python -m pip install -e .

Running the tests additionally requires a specific version of GPflow (plus compatible TensorFlow and TensorFlow Probability versions) to test against:

python -m pip install pytest
python -m pip install tensorflow==2.10 tensorflow-probability==0.18.0 gpflow==2.5.2

Run the tests:

cd tests; pytest

Simple Example

Given some inputs x and some data y, you can construct a Bayes-Newton model as follows,

import bayesnewton
import objax

kern = bayesnewton.kernels.Matern52()
lik = bayesnewton.likelihoods.Gaussian()
model = bayesnewton.models.MarkovVariationalGP(kernel=kern, likelihood=lik, X=x, Y=y)
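Here x and y are your training inputs and observations (e.g. NumPy arrays). For a quick self-contained test, they could be generated first with something like the following (illustrative only, not from the original docs):

import numpy as np

N = 200
x = np.linspace(0, 10, N)                     # training inputs
y = np.sin(2 * x) + 0.3 * np.random.randn(N)  # noisy observations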

The training loop (inference and hyperparameter learning) is then set up using objax's built-in functionality:

lr_adam = 0.1
lr_newton = 1
inf_args = {}  # optional extra keyword arguments for the inference/energy calls (none needed here)
opt_hypers = objax.optimizer.Adam(model.vars())
energy = objax.GradValues(model.energy, model.vars())

@objax.Function.with_vars(model.vars() + opt_hypers.vars())
def train_op():
    model.inference(lr=lr_newton, **inf_args)  # perform inference and update variational params
    dE, E = energy(**inf_args)  # compute energy and its gradients w.r.t. hypers
    opt_hypers(lr_adam, dE)  # update the hyperparameters
    return E

As we are using JAX, we can JIT compile the training loop:

train_op = objax.Jit(train_op)

and then run the training loop,

iters = 20
for i in range(1, iters + 1):
    loss = train_op()
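After training, the posterior can be evaluated at test inputs via the model's prediction methods, as used in the demos. A minimal sketch, assuming the predict_y method (predictive mean and variance of the observations) and the synthetic data above:

x_test = np.linspace(0, 10, 500)       # test inputs
mean, var = model.predict_y(X=x_test)  # predictive mean and variance at the test inputs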

Full demos are available here.

Citing Bayes-Newton

@article{wilkinson2023bayes,
  title={{B}ayes--{N}ewton Methods for Approximate {B}ayesian Inference with {PSD} Guarantees},
  author={Wilkinson, William J and S{\"a}rkk{\"a}, Simo and Solin, Arno},
  journal={Journal of Machine Learning Research},
  volume={24},
  number={83},
  pages={1--50},
  year={2023}
}

Implemented Models

For a full list of all the models available, see the model class list.

Variational GPs

Expectation Propagation GPs

Laplace/Newton GPs

Linearisation GPs

Gauss-Newton GPs

(Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

Quasi-Newton GPs

(Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

GPs with PSD Constraints via Riemannian Gradients

Others

License

This software is provided under the Apache License 2.0. See the accompanying LICENSE file for details.