PyTorch-SSO (alpha release)

Scalable Second-Order methods in PyTorch.

Scalable Second-Order Optimization

Optimizers

PyTorch-SSO provides the following optimizers.

Curvatures

You can specify the type of information matrix to use as the curvature from the following.

Refer to "Information matrices and generalization" by Valentin Thomas et al. (2019) for the definitions and properties of these information matrices.
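As a minimal NumPy sketch (not PyTorch-SSO's implementation), three matrices commonly compared in this setting can be computed for binary logistic regression:

- the Hessian of the loss,
- the Fisher information matrix, with labels drawn from the model's own predictive distribution,
- the empirical Fisher, an uncentered covariance of per-example gradients computed with the observed labels.

The names `sigmoid` and `information_matrices` are hypothetical. For this particular model the Hessian and the Fisher coincide, while the empirical Fisher generally differs.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def information_matrices(w, X, y):
    """Return (Hessian, Fisher, empirical Fisher) of the mean logistic loss at w.

    X: inputs, shape (n, d); y: labels in {0, 1}, shape (n,); w: weights, shape (d,).
    """
    n, d = X.shape
    p = sigmoid(X @ w)  # model probability p(y = 1 | x) for each example

    # Hessian of the mean binary cross-entropy: (1/n) * sum_i p_i (1 - p_i) x_i x_i^T
    hessian = (X * (p * (1 - p))[:, None]).T @ X / n

    # Fisher: expected outer product of per-example gradients (p - y) x,
    # with y drawn from the model's own predictive distribution.
    fisher = np.zeros((d, d))
    for label in (0.0, 1.0):
        prob = p if label == 1.0 else 1.0 - p  # model probability of this label
        g = X * (p - label)[:, None]           # per-example gradients for this label
        fisher += (g * prob[:, None]).T @ g / n

    # Empirical Fisher: the same outer product, but using the observed labels y.
    g_obs = X * (p - y)[:, None]
    emp_fisher = g_obs.T @ g_obs / n

    return hessian, fisher, emp_fisher


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = (rng.random(100) < 0.5).astype(float)
H, F, C = information_matrices(rng.normal(size=3), X, y)
print(np.allclose(H, F))  # True for logistic regression
```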

Refer to Section 6 of "Optimization Methods for Large-Scale Machine Learning" by Léon Bottou et al. (2018) for a clear explanation of second-order optimization using these matrices as the curvature.
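The core of such methods is preconditioning the gradient with a curvature matrix C. A minimal NumPy sketch of one damped step, θ ← θ − η (C + λI)⁻¹ ∇L, follows. `second_order_step`, `lr`, and `damping` are hypothetical names, not PyTorch-SSO's API. On a quadratic loss with C equal to the Hessian, η = 1, and λ = 0, a single step reaches the minimizer (Newton's method).

```python
import numpy as np


def second_order_step(theta, grad, curvature, lr=1.0, damping=0.0):
    """Return theta - lr * (curvature + damping * I)^{-1} grad."""
    d = theta.shape[0]
    # Solve the damped linear system rather than forming an explicit inverse.
    direction = np.linalg.solve(curvature + damping * np.eye(d), grad)
    return theta - lr * direction


# Quadratic loss L(theta) = 0.5 * theta^T A theta - b^T theta, minimized at A^{-1} b.
A = np.array([[3.0, 1.0], [1.0, 2.0]])
b = np.array([1.0, -1.0])
theta = np.zeros(2)
theta = second_order_step(theta, A @ theta - b, curvature=A)
print(np.allclose(theta, np.linalg.solve(A, b)))  # True: one Newton step suffices
```

Increasing `damping` pushes the update toward a scaled gradient step. This helps when the curvature estimate is ill-conditioned.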

Approximation Methods

PyTorch-SSO calculates the curvature as a layer-wise block-diagonal matrix.

You can specify the approximation method for the curvature in each layer from the following.

  1. Full (No approximation)
  2. Diagonal approximation
  3. Kronecker-Factored Approximate Curvature (K-FAC)
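The three options can be illustrated with a NumPy sketch for the weight W (shape d_out × d_in) of a single linear layer. It uses the usual K-FAC construction and is not PyTorch-SSO's code.

Let a be the layer input and g the gradient with respect to the layer output. Each example's weight gradient is then g aᵀ.

- Full: the average of vec(g aᵀ) vec(g aᵀ)ᵀ.
- Diagonal: only the diagonal of the full matrix.
- K-FAC: approximates the full matrix as A ⊗ G, with A = E[a aᵀ] and G = E[g gᵀ]. The preconditioned gradient is then G⁻¹ ∇W A⁻¹, so the full matrix never has to be formed.

The function names are hypothetical. Adding damping to each factor separately is one common choice, assumed here.

```python
import numpy as np


def linear_layer_curvatures(a, g):
    """Full, diagonal and K-FAC curvature for a linear layer's weight W (d_out x d_in).

    a: layer inputs, shape (n, d_in); g: gradients w.r.t. layer outputs, shape (n, d_out).
    Per-example weight gradients are the outer products g_i a_i^T.
    """
    n = a.shape[0]
    # Column-major vec(g_i a_i^T) equals kron(a_i, g_i).
    per_example = np.stack([np.kron(a[i], g[i]) for i in range(n)])
    full = per_example.T @ per_example / n   # (d_in*d_out) x (d_in*d_out)
    diagonal = np.diag(full).copy()          # keep only the diagonal entries
    A = a.T @ a / n                          # input-side factor, d_in x d_in
    G = g.T @ g / n                          # output-side factor, d_out x d_out
    return full, diagonal, A, G


def kfac_precondition(grad_W, A, G, damping=1e-3):
    """Apply ((A + damping*I) kron (G + damping*I))^{-1} to vec(grad_W).

    Uses (A kron G)^{-1} vec(V) = vec(G^{-1} V A^{-1}) for symmetric A and G,
    so the Kronecker product itself is never formed.
    """
    d_out, d_in = grad_W.shape
    A_damped = A + damping * np.eye(d_in)
    G_damped = G + damping * np.eye(d_out)
    return np.linalg.solve(G_damped, grad_W) @ np.linalg.inv(A_damped)


rng = np.random.default_rng(0)
a, g = rng.normal(size=(32, 3)), rng.normal(size=(32, 2))
full, diagonal, A, G = linear_layer_curvatures(a, g)
grad_W = g.T @ a / 32                          # mean weight gradient, shape (2, 3)
step = kfac_precondition(grad_W, A, G)
print(full.shape, diagonal.shape, step.shape)  # (6, 6) (6,) (2, 3)
```

The full matrix grows with (d_in · d_out)². The K-FAC factors grow only with d_in² and d_out², which is what makes K-FAC practical for large layers.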

PyTorch-SSO currently supports the following layers (Modules) in PyTorch:

| Layer (Module)          | Full | Diagonal | K-FAC |
|-------------------------|------|----------|-------|
| torch.nn.Linear         | ✓    | ✓        | ✓     |
| torch.nn.Conv2d         | -    | ✓        | ✓     |
| torch.nn.BatchNorm1d/2d | -    | ✓        | -     |
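Before applying PyTorch-SSO, it can help to check a model's parameterized layers against the support table. The sketch below transcribes that table into a plain dictionary. `SUPPORTED` and `unsupported_layers` are hypothetical helpers, not part of PyTorch-SSO.

```python
# Support matrix transcribed from the table above (hypothetical helper, not PyTorch-SSO API).
SUPPORTED = {
    "Linear": {"full", "diagonal", "kfac"},
    "Conv2d": {"diagonal", "kfac"},
    "BatchNorm1d": {"diagonal"},
    "BatchNorm2d": {"diagonal"},
}


def unsupported_layers(layer_names, approximation):
    """Return the layer class names that lack the requested approximation."""
    return [name for name in layer_names
            if approximation not in SUPPORTED.get(name, set())]


# With PyTorch, the class names of layers that own parameters could be collected with:
#   [type(m).__name__ for m in model.modules() if list(m.parameters(recurse=False))]
print(unsupported_layers(["Conv2d", "BatchNorm2d", "Linear"], "kfac"))  # ['BatchNorm2d']
```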

To apply PyTorch-SSO,

Distributed Training

PyTorch-SSO supports data parallelism and Monte Carlo (MC) sample parallelism (for variational inference, VI) for distributed training across multiple processes (GPUs).
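Data parallelism can extend from gradients to curvature because a batch-averaged curvature such as the empirical Fisher decomposes over examples. Suppose each process computes it on an equal-sized shard of the batch. Averaging those local matrices (an all-reduce mean) then recovers the full-batch matrix.

The NumPy sketch below simulates the processes in a loop. It does not use ChainerMN or PyTorch-SSO's communication code, and `shard_fisher` is a hypothetical name.

```python
import numpy as np


def shard_fisher(per_example_grads):
    """Empirical Fisher (mean outer product of per-example gradients) on one shard."""
    return per_example_grads.T @ per_example_grads / per_example_grads.shape[0]


rng = np.random.default_rng(0)
grads = rng.normal(size=(64, 5))         # per-example gradients for the full batch
num_processes = 4
shards = np.split(grads, num_processes)  # equal-sized shard for each simulated process

# Each process computes its local curvature; an all-reduce mean combines them.
local = [shard_fisher(s) for s in shards]
reduced = sum(local) / num_processes

print(np.allclose(reduced, shard_fisher(grads)))  # True for equal-sized shards
```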

Installation

To build PyTorch-SSO, run the following (in a Python 3 environment):

git clone git@github.com:cybertronai/pytorch-sso.git
cd pytorch-sso
python setup.py install

To use the library, import it:

import torchsso

Additional requirements

PyTorch-SSO depends on CuPy for fast GPU computation and on ChainerMN for communication. To use GPUs, install the following requirements before installing PyTorch-SSO.

| Running environment | Requirements           |
|---------------------|------------------------|
| single GPU          | CuPy                   |
| multiple GPUs       | CuPy with NCCL, mpi4py |

Refer to the CuPy installation guide and the ChainerMN installation guide for details.
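A quick way to see which of these Python packages are still missing is to look them up with the standard library before installing PyTorch-SSO. This is a sketch under assumptions.

- The module names `cupy` and `mpi4py` are the import names of the requirements in the table. The pip package name for CuPy may differ.
- NCCL support inside CuPy is not checked.
- `REQUIREMENTS` and `missing_requirements` are hypothetical helpers.

```python
import importlib.util

# Import names for the requirements in the table above (assumed mapping;
# NCCL support inside CuPy is not checked here).
REQUIREMENTS = {
    "single GPU": ["cupy"],
    "multiple GPUs": ["cupy", "mpi4py"],
}


def missing_requirements(environment):
    """Return the required modules that cannot be found in the current Python environment."""
    return [name for name in REQUIREMENTS[environment]
            if importlib.util.find_spec(name) is None]


# Prints the modules that still need to be installed for a multi-GPU setup.
print(missing_requirements("multiple GPUs"))
```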

Examples

Authors

Kazuki Osawa (@kazukiosawa) and Yaroslav Bulatov (@yaroslavvb)