

pytorch-checkpoint

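Gradient checkpointing is a technique for reducing GPU memory cost. Instead of storing every intermediate activation for the backward pass, it keeps only some of them and recomputes the rest during backpropagation, trading extra computation for lower memory.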
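To make the memory/compute trade-off concrete, here is a toy pure-Python sketch. It is not this repo's code and does not use PyTorch: the 12 scalar "layers" and the segment size of 4 are illustrative assumptions, and derivatives are written by hand. The naive version stores every activation. The checkpointed version keeps only each segment's input and recomputes the segment's intermediate values during the backward pass.

```python
import math

# Hypothetical toy network: 12 scalar "layers", each a (function, derivative) pair.
LAYERS = [
    (math.sin, math.cos),
    (math.exp, math.exp),
    (math.tanh, lambda x: 1.0 - math.tanh(x) ** 2),
] * 4


def grad_naive(x, layers):
    """Forward pass that stores every activation, followed by a backward pass."""
    acts = [x]
    for f, _ in layers:
        acts.append(f(acts[-1]))
    grad = 1.0
    # Walk the layers backwards; the input of layer i is acts[i].
    for (_, df), a in zip(reversed(layers), reversed(acts[:-1])):
        grad *= df(a)
    return acts[-1], grad, len(acts)


def grad_checkpointed(x, layers, segment=4):
    """Keep only each segment's input; recompute the rest during backward."""
    segments = [layers[i:i + segment] for i in range(0, len(layers), segment)]
    checkpoints = []
    h = x
    for seg in segments:
        checkpoints.append(h)      # the only activation kept for this segment
        for f, _ in seg:
            h = f(h)               # intermediate values are not stored
    out, grad, recomputed, peak = h, 1.0, 0, len(checkpoints)
    for seg, h0 in zip(reversed(segments), reversed(checkpoints)):
        acts = [h0]
        for f, _ in seg[:-1]:      # recompute the inputs of this segment's layers
            acts.append(f(acts[-1]))
            recomputed += 1
        # Values held now: all checkpoints plus this segment's activation list.
        peak = max(peak, len(checkpoints) + len(acts))
        for (_, df), a in zip(reversed(seg), reversed(acts)):
            grad *= df(a)
    return out, grad, peak, recomputed


out_n, g_n, stored = grad_naive(0.5, LAYERS)
out_c, g_c, peak, recomputed = grad_checkpointed(0.5, LAYERS)
print(stored, peak, recomputed)  # 13 7 9
```

Both versions return the same gradient; the checkpointed one holds at most 7 values instead of 13 but recomputes 9 layer outputs, which is the same trade-off a real implementation makes with GPU tensors.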

Official implementation

The official PyTorch repo includes an implementation (torch.utils.checkpoint). However, it is very slow on multiple GPUs: with 2 GPUs and batch size 256, it takes 1.41s versus 0.27s without checkpointing.

This implementation

This repo contains a PyTorch implementation that also runs efficiently on multiple GPUs.

Main results

| Method | # GPUs | Batch size | Memory | Time |
|---|---|---|---|---|
| Naive | 2 | 256 | 5.25G | 0.27s |
| Official | 2 | 256 | 2.98G | 1.41s |
| This repo | 2 | 256 | 2.97G | 0.31s |

Documentation

The main functionality is in checkpoint.py:

import checkpoint
checkpoint.CheckpointFunction.apply(function, n, *args)

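Parameters:

- function: the function to checkpoint. It is called with the first n entries of args, i.e. function(*args[:n]).
- n: the number of leading entries of args that are inputs to function.
- *args: the n inputs, followed by the parameters used inside function that should receive gradients (see the DenseNet example).

Returns:

- The output of function(*args[:n]).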

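Note: When using checkpointing, we recommend cp_BatchNorm2d instead of torch.nn.BatchNorm2d. Because the checkpointed forward pass is run again during the backward pass, a standard BatchNorm2d would accumulate the running statistics of the same batch more than once.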
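The problem behind this note can be illustrated with a toy model. In training mode, BatchNorm updates its running mean as running = (1 - momentum) * running + momentum * batch_mean, and checkpointing executes the forward pass a second time during backward, so a naive layer applies that update twice per step. The sketch below is hypothetical: ToyBatchNorm, checkpointed_training_step, and the skip_update_on_recompute flag are our own names, and skipping the update while recomputing is shown as one possible remedy, not as how cp_BatchNorm2d is actually implemented.

```python
def update_running_mean(running_mean, batch, momentum=0.1):
    """One BatchNorm-style moving-average update of the running mean."""
    batch_mean = sum(batch) / len(batch)
    return (1.0 - momentum) * running_mean + momentum * batch_mean


class ToyBatchNorm:
    """Hypothetical 1-D layer that tracks a running mean and centers its input."""

    def __init__(self, momentum=0.1, skip_update_on_recompute=False):
        self.momentum = momentum
        self.skip_update_on_recompute = skip_update_on_recompute
        self.running_mean = 0.0
        self.num_batches_tracked = 0

    def forward(self, batch, recomputing=False):
        # Update statistics unless this is a recomputation and we were told to skip it.
        if not (recomputing and self.skip_update_on_recompute):
            self.running_mean = update_running_mean(
                self.running_mean, batch, self.momentum)
            self.num_batches_tracked += 1
        mean = sum(batch) / len(batch)
        return [v - mean for v in batch]  # normalization reduced to centering


def checkpointed_training_step(layer, batch):
    """Simulate checkpointing: forward once, then re-run forward during backward."""
    out = layer.forward(batch)               # original forward pass
    layer.forward(batch, recomputing=True)   # recomputation in the backward pass
    return out


naive = ToyBatchNorm()
checkpointed_training_step(naive, [1.0, 2.0, 3.0])
fixed = ToyBatchNorm(skip_update_on_recompute=True)
checkpointed_training_step(fixed, [1.0, 2.0, 3.0])
print(naive.num_batches_tracked, fixed.num_batches_tracked)  # 2 1
```

After one step on a batch with mean 2.0, the naive layer's running mean is about 0.38 instead of the 0.2 a single update would give.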

DenseNet example

We provide an example of applying our checkpointing to the memory-efficient DenseNet. It only requires changing a few lines of the original implementation, which uses PyTorch's official checkpointing.

# bn_function is a function containing conv1, norm1, relu1.
# naive no checkpointing: bottleneck_output = bn_function(*prev_features)
# official implementation: bottleneck_output = cp.checkpoint(bn_function, *prev_features)
args = prev_features + tuple(self.norm1.parameters()) + tuple(self.conv1.parameters())
# The parameters to optimize in the bn_function are tuple(self.norm1.parameters()) + tuple(self.conv1.parameters())
bottleneck_output = cp.CheckpointFunction.apply(bn_function, len(prev_features), *args)
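The argument convention used in the snippet above can be mimicked without PyTorch. The sketch below is a hypothetical stand-in (checkpoint_apply_mock, the toy weight, and this bn_function are our own names, not part of this repo) and performs no checkpointing. It assumes, as the snippet implies, that only the first n entries of args are handed to function, while the trailing parameters are passed presumably so that the real checkpointed Function can account for their gradients.

```python
def checkpoint_apply_mock(function, n, *args):
    """Hypothetical stand-in mirroring the (function, n, *args) convention.

    The first n entries of args are the inputs handed to `function`; the
    remaining entries (the parameters) are accepted but not passed on.
    """
    if not 0 <= n <= len(args):
        raise ValueError("n must be between 0 and len(args)")
    inputs = args[:n]
    return function(*inputs)


# Toy "layer": the weight is reached through a closure, like conv1/norm1 in bn_function.
weight = [2.0]


def bn_function(*prev_features):
    return weight[0] * sum(prev_features)


prev_features = (1.0, 2.0, 3.0)
args = prev_features + tuple(weight)  # inputs first, then parameters
out = checkpoint_apply_mock(bn_function, len(prev_features), *args)
print(out)  # 12.0, the same as bn_function(*prev_features)
```

If the parameters were forwarded to function as well, the toy result would be 14.0 instead, which is why n must be the number of inputs rather than len(args).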

Demo

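python-fire is not required for checkpointing itself, but it is required for the efficient DenseNet demo. The commands below train on CIFAR with two GPUs, first with this repo's checkpointing (cp_demo.py) and then with the original implementation (original_demo.py).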

pip install fire
CUDA_VISIBLE_DEVICES=0,1 python cp_demo.py --efficient True --data cifar --save model --batch_size 256
CUDA_VISIBLE_DEVICES=0,1 python original_demo.py --efficient True --data cifar --save model --batch_size 256

Environment

This code was tested with PyTorch 1.0.0.dev20181102.

Speed was measured on TITAN X (Pascal) GPUs.

Full results

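| Method | # GPUs | Batch size | Memory | Time |
|---|---|---|---|---|
| Naive | 1 | 256 | 9.93G | 0.42s |
| Naive | 2 | 4 | 0.65G | 0.10s |
| Naive | 2 | 256 | 5.25G | 0.27s |
| Naive | 2 | 512 | 9.93G | 0.50s |
| Official | 1 | 256 | 5.38G | 0.52s |
| Official | 1 | 512 | 10.1G | 1.00s |
| Official | 2 | 4 | 0.62G | 1.40s |
| Official | 2 | 256 | 2.98G | 1.41s |
| Official | 2 | 512 | 5.39G | 1.53s |
| This repo | 1 | 256 | 5.37G | 0.50s |
| This repo | 1 | 512 | 10.1G | 0.97s |
| This repo | 2 | 4 | 0.62G | 0.13s |
| This repo | 2 | 256 | 2.97G | 0.31s |
| This repo | 2 | 512 | 5.37G | 0.58s |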
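The overheads implied by the full results can be computed directly. The snippet below copies the table into a dictionary (the data layout and the helper name relative_to_naive are our own) and expresses memory and time relative to the naive run with the same GPU count and batch size. The table has no naive entry for 1 GPU with batch size 512, so that configuration is skipped.

```python
# (method, #GPUs, batch size) -> (memory in G, time in s), copied from the table above.
RESULTS = {
    ("Naive", 1, 256): (9.93, 0.42),
    ("Naive", 2, 4): (0.65, 0.10),
    ("Naive", 2, 256): (5.25, 0.27),
    ("Naive", 2, 512): (9.93, 0.50),
    ("Official", 1, 256): (5.38, 0.52),
    ("Official", 1, 512): (10.1, 1.00),
    ("Official", 2, 4): (0.62, 1.40),
    ("Official", 2, 256): (2.98, 1.41),
    ("Official", 2, 512): (5.39, 1.53),
    ("This repo", 1, 256): (5.37, 0.50),
    ("This repo", 1, 512): (10.1, 0.97),
    ("This repo", 2, 4): (0.62, 0.13),
    ("This repo", 2, 256): (2.97, 0.31),
    ("This repo", 2, 512): (5.37, 0.58),
}


def relative_to_naive(method, gpus, batch):
    """Return (memory ratio, time ratio) versus the naive run with the same setup."""
    mem, secs = RESULTS[(method, gpus, batch)]
    base_mem, base_secs = RESULTS[("Naive", gpus, batch)]
    return mem / base_mem, secs / base_secs


for (method, gpus, batch) in RESULTS:
    # Only compare configurations that have a naive baseline.
    if method != "Naive" and ("Naive", gpus, batch) in RESULTS:
        mem_ratio, time_ratio = relative_to_naive(method, gpus, batch)
        print(f"{method:9s} {gpus} GPU(s), batch {batch:3d}: "
              f"memory x{mem_ratio:.2f}, time x{time_ratio:.2f}")
```

For 2 GPUs and batch size 256, both checkpointing variants use about 0.57x the naive memory, while the time overhead is about 1.15x for this repo versus about 5.22x for the official implementation.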

Credits

Part of our code in checkpoint.py and cp_BatchNorm2d.py is from https://github.com/pytorch/pytorch

The efficient densenet demo is taken from https://github.com/gpleiss/efficient_densenet_pytorch