PyTorch-SSO (alpha release)
Scalable Second-Order methods in PyTorch.
- Open-source library for second-order optimization and Bayesian inference.
- An earlier iteration of this library (chainerkfac) holds the world record for large-batch training of ResNet-50 on ImageNet by Kronecker-Factored Approximate Curvature (K-FAC), scaling to batch sizes of 131K.
- This library is the basis for natural-gradient variational inference (Bayesian inference) on ImageNet.
  - Kazuki Osawa et al., “Practical Deep Learning with Bayesian Principles”, NeurIPS 2019.
  - [paper (preprint)]
Scalable Second-Order Optimization
Optimizers
PyTorch-SSO provides the following optimizers.
- Second-Order Optimization
  - `torchsso.optim.SecondOrderOptimizer` [source]
  - Updates the parameters with the gradients preconditioned by the curvature of the loss function (`torch.nn.functional.cross_entropy`) for each `param_group`.
- Variational Inference (VI)
  - `torchsso.optim.VIOptimizer` [source]
  - Updates the posterior distribution (mean, covariance) of the parameters by using the curvature for each `param_group`.
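Both optimizers revolve around curvature-preconditioned gradients. As a rough illustration of what "preconditioned by the curvature" means, the NumPy sketch below applies one damped update θ ← θ − lr·(C + λI)⁻¹g with a dense curvature matrix C. The function name `preconditioned_step` and its `damping` parameter are hypothetical and not part of the torchsso API; the library itself works with layer-wise block-diagonal curvature rather than one dense matrix.

```python
import numpy as np


def preconditioned_step(theta, grad, curvature, lr=0.1, damping=1e-3):
    """Hypothetical helper: return theta - lr * (C + damping * I)^{-1} grad."""
    n = theta.shape[0]
    # Damping keeps the linear system well conditioned when C is (near) singular.
    damped = curvature + damping * np.eye(n)
    # Solve the system instead of forming an explicit inverse.
    direction = np.linalg.solve(damped, grad)
    return theta - lr * direction


if __name__ == "__main__":
    # Quadratic loss 0.5 * theta^T H theta - b^T theta, whose gradient is H theta - b.
    H = np.array([[3.0, 1.0], [1.0, 2.0]])
    b = np.array([1.0, -1.0])
    theta0 = np.zeros(2)
    theta1 = preconditioned_step(theta0, H @ theta0 - b, H, lr=1.0, damping=0.0)
    print(theta1, H @ theta1)  # H @ theta1 recovers b
```

On a quadratic loss with C equal to the true Hessian, lr=1 and no damping, a single step lands on the minimizer, which is the usual motivation for second-order methods; damping trades some of that accuracy for numerical stability.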
Curvatures
You can choose which information matrix is used as the curvature from the following:
- Hessian [WIP]
- Fisher information matrix
- Covariance matrix (empirical Fisher)
Refer to “Information matrices and generalization” by Valentin Thomas et al. (2019) for the definitions and properties of these information matrices.
Refer to Section 6 of “Optimization Methods for Large-Scale Machine Learning” by Léon Bottou et al. (2018) for a clear explanation of second-order optimization that uses these matrices as the curvature.
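To make the difference between the Fisher information matrix and the empirical Fisher concrete, here is a NumPy sketch for a single example under softmax cross-entropy, taken with respect to the logits only (an assumption that keeps it small; PyTorch-SSO computes these quantities with respect to layer parameters). The Fisher averages ggᵀ over labels sampled from the model's own predictive distribution, while the empirical Fisher plugs in the observed label. All helper names are illustrative.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())  # shift for numerical stability
    return e / e.sum()


def grad_wrt_logits(z, y):
    # Gradient of cross_entropy(z, y) w.r.t. the logits: softmax(z) - onehot(y)
    g = softmax(z).copy()
    g[y] -= 1.0
    return g


def fisher_logits(z):
    # Fisher: E_{y ~ p}[g g^T], labels drawn from the model's predictive distribution
    p = softmax(z)
    total = np.zeros((len(z), len(z)))
    for k in range(len(z)):
        g = grad_wrt_logits(z, k)
        total += p[k] * np.outer(g, g)
    return total


def empirical_fisher_logits(z, y):
    # Empirical Fisher: the same outer product, but with the observed label y
    g = grad_wrt_logits(z, y)
    return np.outer(g, g)
```

With respect to the logits the Fisher reduces to diag(p) − ppᵀ, which is also the Hessian of cross-entropy in the logits; the empirical Fisher is a rank-one matrix that depends on the observed label, so the two generally differ.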
Approximation Methods
PyTorch-SSO calculates the curvature as a layer-wise block-diagonal matrix.
You can specify the approximation method for the curvature of each layer from the following:
- Full (No approximation)
- Diagonal approximation
- Kronecker-Factored Approximate Curvature (K-FAC)
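For a `torch.nn.Linear` layer with input a and output gradient g, the per-example weight gradient is g aᵀ. The NumPy sketch below (bias omitted, helper names hypothetical) builds the three approximations for one layer's block: the full block, its diagonal, and the K-FAC product A ⊗ G with A = E[a aᵀ] and G = E[g gᵀ]. It also shows why K-FAC is cheap to invert: (A ⊗ G)⁻¹ vec(∇W) = vec(G⁻¹ ∇W A⁻¹), so only the two small factors need inverting. Adding damping to each factor separately, as done here, is one common choice and an assumption, not necessarily what PyTorch-SSO does.

```python
import numpy as np


def vec(m):
    # Column-major vectorization, so that vec(g a^T) == kron(a, g)
    return m.flatten(order="F")


def curvature_approximations(acts, grads):
    """acts: (N, in) layer inputs; grads: (N, out) gradients w.r.t. layer outputs."""
    n = acts.shape[0]
    per_sample = [vec(np.outer(g, a)) for a, g in zip(acts, grads)]
    full = sum(np.outer(v, v) for v in per_sample) / n  # (in*out, in*out) block
    diagonal = np.diag(np.diag(full))                   # keep only the diagonal
    A = acts.T @ acts / n                               # (in, in) input factor
    G = grads.T @ grads / n                             # (out, out) gradient factor
    kfac = np.kron(A, G)                                # Kronecker-factored block
    return full, diagonal, A, G, kfac


def kfac_precondition(grad_w, A, G, damping=1e-3):
    # (A_d kron G_d)^{-1} vec(dW) == vec(G_d^{-1} dW A_d^{-1}) for symmetric A_d
    A_inv = np.linalg.inv(A + damping * np.eye(A.shape[0]))
    G_inv = np.linalg.inv(G + damping * np.eye(G.shape[0]))
    return G_inv @ grad_w @ A_inv
```

With a single example the full block and the K-FAC product coincide exactly; with more examples, K-FAC replaces the expectation of a Kronecker product by the Kronecker product of expectations.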
PyTorch-SSO currently supports the following layers (Modules) in PyTorch:
Layer (Module) | Full | Diagonal | K-FAC |
---|---|---|---|
torch.nn.Linear | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
torch.nn.Conv2d | - | :heavy_check_mark: | :heavy_check_mark: |
torch.nn.BatchNorm1d/2d | - | :heavy_check_mark: | - |
To apply PyTorch-SSO:
- Set `requires_grad` to `True` for the parameters of each Module.
- The network you define cannot contain any modules other than those listed above.
  - E.g., you need to use `torch.nn.functional.relu`/`max_pool2d` instead of `torch.nn.ReLU`/`MaxPool2d` to define a ConvNet.
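As an illustration of these rules, here is a small MNIST-sized ConvNet that registers only `Conv2d`, `BatchNorm2d`, and `Linear` modules and applies ReLU and max pooling functionally. The helpers `unsupported_modules` and `all_params_trainable` are hypothetical conveniences for checking a model against the support table before handing it to a torchsso optimizer; they are not part of the library.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Module types listed as supported in the table above
SUPPORTED = (nn.Linear, nn.Conv2d, nn.BatchNorm1d, nn.BatchNorm2d)


class ConvNet(nn.Module):
    """Illustrative ConvNet for 1x28x28 inputs using only supported modules."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc = nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x):
        # Activation and pooling via torch.nn.functional, not nn.ReLU / nn.MaxPool2d
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2)  # 28x28 -> 14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)             # 14x14 -> 7x7
        x = torch.flatten(x, 1)
        return self.fc(x)


def unsupported_modules(model):
    """Names of leaf modules whose type is not in SUPPORTED (hypothetical helper)."""
    return [name for name, m in model.named_modules()
            if len(list(m.children())) == 0 and not isinstance(m, SUPPORTED)]


def all_params_trainable(model):
    """True if every parameter has requires_grad set (hypothetical helper)."""
    return all(p.requires_grad for p in model.parameters())
```

A model such as `nn.Sequential(nn.Linear(4, 4), nn.ReLU())` would be flagged, because the `nn.ReLU` leaf is not in the table.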
Distributed Training
PyTorch-SSO supports data parallelism and, for VI, Monte Carlo (MC) sample parallelism for distributed training across multiple processes (GPUs).
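A minimal sketch of the data-parallel side, assuming the K-FAC statistics of a Linear layer, A = E[a aᵀ] and G = E[g gᵀ], are averaged across workers (a common scheme in distributed K-FAC; PyTorch-SSO's exact communication pattern may differ). A plain Python average stands in for the MPI/NCCL all-reduce, and all function names are hypothetical. MC sample parallelism for VI is not sketched.

```python
import numpy as np


def local_factors(acts, grads):
    # Per-worker K-FAC statistics for one Linear layer
    n = acts.shape[0]
    return acts.T @ acts / n, grads.T @ grads / n


def allreduce_mean(values):
    # Stand-in for an MPI/NCCL all-reduce: every worker ends up with the average
    return sum(values) / len(values)


def data_parallel_factors(acts, grads, num_workers):
    # Split the global batch into shards, one per simulated worker
    a_shards = np.array_split(acts, num_workers)
    g_shards = np.array_split(grads, num_workers)
    factors = [local_factors(a, g) for a, g in zip(a_shards, g_shards)]
    A = allreduce_mean([f[0] for f in factors])
    G = allreduce_mean([f[1] for f in factors])
    return A, G
```

With equal-size shards the averaged factors match the factors of the full global batch, so every worker preconditions with the same curvature a single-process run would use; with unequal shards a size-weighted average would be needed.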
Installation
To build PyTorch-SSO, run the following (in a Python 3 environment):
git clone git@github.com:cybertronai/pytorch-sso.git
cd pytorch-sso
python setup.py install
To use the library:
import torchsso
Additional requirements
PyTorch-SSO depends on CuPy for fast GPU computation and ChainerMN for communication. To use GPUs, install the following requirements before installing PyTorch-SSO.
Running environment | Requirements |
---|---|
single GPU | CuPy |
multiple GPUs | CuPy with NCCL, mpi4py |
Refer to the CuPy installation guide and the ChainerMN installation guide for details.
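Before installing, a quick way to see which of these optional packages are importable is a standard-library check like the one below. The mapping from running environment to import names (`cupy`, `mpi4py`) is an assumption based on the table; the check cannot tell whether CuPy was built with NCCL, and `chainermn` can be appended to the list if you want to check it too.

```python
import importlib.util

# Assumed import names for the requirements in the table above
REQUIREMENTS = {
    "single GPU": ["cupy"],
    "multiple GPUs": ["cupy", "mpi4py"],
}


def missing_requirements(environment, requirements=REQUIREMENTS):
    """Return the import names that cannot be found for the given environment."""
    return [name for name in requirements[environment]
            if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    for env in REQUIREMENTS:
        print(env, "missing:", missing_requirements(env))
```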
Examples
- Image classification with a single process (MNIST, CIFAR-10)
- Image classification with multiple processes (CIFAR-10/100, ImageNet)
Authors
Kazuki Osawa (@kazukiosawa) and Yaroslav Bulatov (@yaroslavvb)