
Optimal Gradient Checkpoint Search for Arbitrary Computation Graphs

This is the official implementation of the paper:

Jianwei Feng and Dong Huang, Optimal Gradient Checkpoint Search for Arbitrary Computation Graphs, CVPR, 2021 (Oral). ArXiv version

Citation:

@inproceedings{fenghuang2021,
  title={Optimal Gradient Checkpoint Search for Arbitrary Computation Graphs},
  author={Jianwei Feng and Dong Huang},
  booktitle={CVPR},
  year={2021}
}

Regular Training vs. Gradient Checkpointing (GCP) Training: (a) Regular training stores all tensors during the forward pass and uses these tensors to compute gradients during the backward pass. (b) GCP training stores only a subset of tensors during the first forward pass and conducts extra local re-forwards to recompute tensors and gradients during the backward pass. Our approach automatically searches for the optimal set of Gradient Checkpoints (GCs) for memory cut-off, so that on the same physical GPU memory (e.g., 4x RTX 2080Ti GPUs), GCP training can accommodate models that would otherwise require 2+ times the GPU memory.
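
For orientation, the snippet below is a minimal sketch of the general gradient checkpointing idea using PyTorch's built-in torch.utils.checkpoint, with the checkpoint boundaries chosen by hand at every block. It is not this repo's method: the point of this work is to search for the optimal set of checkpoints automatically on an arbitrary computation graph.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)

class HandCheckpointedNet(nn.Module):
    """Naive hand-placed checkpoints: one checkpoint per block."""
    def __init__(self, dim=256, depth=8):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))

    def forward(self, x):
        for blk in self.blocks:
            # Activations inside blk are not stored during forward;
            # they are recomputed by a local re-forward during backward.
            x = checkpoint(blk, x)
        return x

model = HandCheckpointedNet()
x = torch.randn(32, 256, requires_grad=True)
model(x).sum().backward()
```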

(Figures: scheme_compare, table_compare)

Reducing Training Memory by Optimal Gradient Checkpointing

| Model Name | Input Size | Regular Training Memory (MB) | OGC Training Memory (MB) | Memory Cut-off | Regular Training Time (s) | OGC Training Time (s) | Time Overhead |
|---|---|---|---|---|---|---|---|
| Alexnet | (1024, 3, 224, 224) | 4955 | 3287 | 34% | 0.388 | 0.519 | 34% |
| VGG 11 | (64, 3, 224, 224) | 3577 | 2781 | 22% | 0.266 | 0.356 | 34% |
| VGG 13 | (64, 3, 224, 224) | 5136 | 3565 | 31% | 0.418 | 0.558 | 33% |
| VGG 16 | (64, 3, 224, 224) | 5136 | 3565 | 31% | 0.503 | 0.666 | 32% |
| VGG 19 | (64, 3, 224, 224) | 5189 | 3565 | 31% | 0.581 | 0.774 | 33% |
| Resnet 18 | (256, 3, 224, 224) | 5635 | 3677 | 35% | 0.422 | 0.548 | 30% |
| Resnet 34 | (128, 3, 224, 224) | 4079 | 1838 | 55% | 0.364 | 0.493 | 35% |
| Resnet 50 | (64, 3, 224, 224) | 5323 | 1973 | 63% | 0.394 | 0.516 | 31% |
| Resnet 101 | (32, 3, 224, 224) | 3934 | 1024 | 74% | 0.356 | 0.482 | 35% |
| Resnet 152 | (16, 3, 224, 224) | 2767 | 526 | 81% | 0.241 | 0.331 | 37% |
| Densenet 121 | (32, 3, 224, 224) | 4027 | 898 | 78% | 0.218 | 0.292 | 34% |
| Densenet 161 | (16, 3, 224, 224) | 3751 | 666 | 82% | 0.252 | 0.341 | 36% |
| Densenet 169 | (32, 3, 224, 224) | 4862 | 897 | 82% | 0.270 | 0.357 | 32% |
| Densenet 201 | (16, 3, 224, 224) | 3146 | 474 | 85% | 0.200 | 0.306 | 53% |
| Inception v3 | (32, 3, 300, 300) | 3074 | 881 | 71% | 0.291 | 0.374 | 29% |
| NASNet Cifar10 | (64, 3, 32, 32) | 5832 | 1129 | 81% | 0.408 | 0.535 | 31% |
| AmoebaNet Cifar10 | (64, 3, 32, 32) | 4944 | 1058 | 79% | 0.331 | 0.450 | 36% |
| DARTS Cifar10 | (64, 3, 32, 32) | 5627 | 1115 | 80% | 0.318 | 0.494 | 55% |

The memory numbers in the table are calculated after removing stationary memory costs such as model weights and the PyTorch CUDA context.
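
For reference, the sketch below shows one way such a measurement could be taken with PyTorch's CUDA memory statistics, subtracting the memory that is already resident (weights, buffers) before the forward/backward pass. The exact protocol used for the table may differ; treat this as an approximation, not the benchmarking code of this repo.

```python
import torch

def peak_activation_memory_mb(model, inputs, device="cuda:0"):
    """Rough peak training memory in MB, excluding stationary costs
    (model weights/buffers already resident before the pass).
    Assumes a CUDA device is available."""
    model = model.to(device)
    inputs = inputs.to(device)
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    baseline = torch.cuda.memory_allocated(device)   # weights + buffers
    model(inputs).sum().backward()                    # forward + backward
    peak = torch.cuda.max_memory_allocated(device)
    return (peak - baseline) / (1024 ** 2)
```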

Updates

Next step: more tests and bug fixes in the automatic computation graph parser.

2021.06.17: we implemented an automatic computation graph parser using torch.jit! Users can now input an arbitrary PyTorch model (nn.Module) and run the optimal gradient checkpointing solver without manually parsing the computation graph.

Installation

Requirements:

Step-by-step

Install PyTorch 1.5 from https://pytorch.org/

Install other dependencies.

pip install -r requirements.txt

Usage

Benchmark Time and Memory Performance of Optimal Gradient Checkpointing

Run Optimal Gradient Checkpointing on Resnet101 and benchmark the training memory and time cost:

python benchmark.py --arch resnet101 --device cuda:0

Train Model with Optimal Gradient Checkpointing

An example of training on the CIFAR 10 dataset can be seen in train_cifar10.ipynb

The script will do the following:

  1. Define a pytorch model (nn.Module) and its inputs
  2. Parse the computation graph of the PyTorch model with torch.jit, run the optimal gradient checkpointing algorithm, and return a new model whose forward function executes gradient checkpointing training with the optimal gradient checkpoints (a rough sketch of this flow follows the list)
  3. Run gradient checkpointing training and evaluation on CIFAR 10 for 2 epochs
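
The sketch below outlines that flow under assumed names: get_ogc_model is a hypothetical wrapper standing in for whatever the notebook actually calls, and only torch.jit.trace and the standard training calls are confirmed PyTorch APIs. See train_cifar10.ipynb for the real entry points.

```python
import torch
import torchvision

# 1. Define a PyTorch model and an example input.
model = torchvision.models.resnet18(num_classes=10)
dummy_input = torch.randn(64, 3, 32, 32)

# 2. torch.jit.trace provides the computation graph that the parser works from.
traced = torch.jit.trace(model, dummy_input)

# The solver would then take the parsed graph and return a wrapped model whose
# forward pass runs gradient checkpointing with the searched-for checkpoints.
# `get_ogc_model` is a hypothetical name, not the repo's confirmed API:
# ogc_model = get_ogc_model(model, dummy_input)

# 3. Training then proceeds as usual on the wrapped model:
# loss = ogc_model(dummy_input).sum()
# loss.backward()
```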

When defining the network, we highly recommend using modules and operations from torch.nn instead of torch.nn.functional; for example, use torch.nn.ReLU instead of torch.nn.functional.relu.

So far, our computation graph parser handles operations in torch.nn well but may be buggy with torch.nn.functional. We will continue to improve it.
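
A small illustration of the difference (both blocks are valid PyTorch; the second form is simply harder for the current parser to handle):

```python
import torch.nn as nn
import torch.nn.functional as F

class PreferredBlock(nn.Module):
    """Every operation is an explicit nn.Module, so it appears as a
    distinct node for the graph parser."""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

class RiskyBlock(nn.Module):
    """Functionally equivalent, but the functional relu call may not be
    parsed cleanly by the current computation graph parser."""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        return F.relu(self.conv(x))
```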

Implement parse_graph function for Custom Network

You can also implement your own parse_graph function to create the computation graph for your network.

An example of a custom model and parse_graph function can be seen in manual_parse_graph.ipynb
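
Purely as an illustration of the idea, a hand-written parse_graph might enumerate the tensor-producing operations of the model as nodes and the tensor flow between them as edges, as in the sketch below. The (nodes, edges) representation shown here is an assumption for illustration only; follow manual_parse_graph.ipynb for the actual data structures this repo expects.

```python
def parse_graph(model, inputs):
    """Hypothetical hand-written parser for a tiny two-layer MLP:
    input -> fc1 -> relu -> fc2. The (nodes, edges) format is
    illustrative, not the repo's confirmed interface."""
    nodes = ["input", "fc1", "relu", "fc2"]                        # tensor-producing ops
    edges = [("input", "fc1"), ("fc1", "relu"), ("relu", "fc2")]   # tensor flow
    return nodes, edges
```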