Soft Masking for Cost-Constrained Channel Pruning

<div align="center"> <img src="./SMCP_teaser.JPG" height="300"> </div> <p align="center"> Figure 1: Top-1 accuracy tradeoff curves for pruning ResNet50 on the ImageNet classification dataset under a latency cost constraint. The baseline is from the PyTorch model hub. Accuracy against FPS (left) and FLOPs (right) shows the benefit of our method, particularly at high pruning ratios. For FPS, top-right is better; for FLOPs, top-left is better. FPS measured on an NVIDIA TITAN V GPU. See the paper for more details. </p>

Project page | Paper

Soft Masking for Cost-Constrained Channel Pruning.<br> Ryan Humble, Maying Shen, Jorge Albericio Latorre, Eric Darve, and Jose M. Alvarez.<br> ECCV 2022.

Official PyTorch code repository for the "Soft Masking for Cost-Constrained Channel Pruning" paper presented at ECCV 2022 (contact josea@nvidia.com for further inquiries).

Abstract

Structured channel pruning has been shown to significantly accelerate inference time for convolutional neural networks (CNNs) on modern hardware, with a relatively minor loss of network accuracy. Recent works permanently zero these channels during training, which we observe to significantly hamper final accuracy, particularly as the fraction of the network being pruned increases. We propose Soft Masking for cost-constrained Channel Pruning (SMCP) to allow pruned channels to adaptively return to the network while simultaneously pruning towards a target cost constraint. By adding a soft mask re-parameterization of the weights and channel pruning from the perspective of removing input channels, we allow gradient updates to previously pruned channels and the opportunity for the channels to later return to the network. We then formulate input channel pruning as a global resource allocation problem. Our method outperforms prior works on both the ImageNet classification and PASCAL VOC detection datasets.

Training Notes

NHWC Memory Layout

The code sets the memory layout to NHWC (PyTorch's channels_last, as described here). This comes with performance benefits, as described in the NVIDIA DL performance documentation.
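To make the layout difference concrete, here is a small dependency-free sketch (not repo code) of how element strides differ between NCHW and NHWC for the same contiguous tensor shape; in PyTorch the switch itself is just `x = x.to(memory_format=torch.channels_last)`:

```python
# Illustrative sketch: strides of a contiguous (N, C, H, W) tensor under the
# default NCHW layout vs. NHWC (channels_last), where C becomes innermost.

def contiguous_strides(shape, order):
    """Element strides for dimensions laid out contiguously in `order`
    (innermost dimension last), returned in the original (N, C, H, W) order."""
    strides = [0] * len(shape)
    running = 1
    for dim in reversed(order):
        strides[dim] = running
        running *= shape[dim]
    return tuple(strides)

shape = (1, 3, 4, 4)                             # (N, C, H, W)
nchw = contiguous_strides(shape, (0, 1, 2, 3))   # channels-first
nhwc = contiguous_strides(shape, (0, 2, 3, 1))   # channels-last: C innermost

print(nchw)  # (48, 16, 4, 1)
print(nhwc)  # (48, 1, 12, 3) -- matches torch's channels_last strides
```

With channels last, all channels of a given pixel sit next to each other in memory, which is what the NVIDIA Tensor Core kernels referenced above exploit.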

Soft channel

We adopt an input channel pruning approach, as described in the paper. Importance scores and masks are always computed along input channels. The cost, however, can be measured more flexibly via the channel-doublesided-weight argument: 1 (the default) measures cost with output channels held fixed, 0 measures it with input channels held fixed (as in HALP), and values in between blend the two.
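One plausible reading of the in-between values is a convex combination of the two cost measurements. The sketch below illustrates that interpretation; the function name and numbers are ours, not the repo's:

```python
# Hypothetical sketch of how a channel-doublesided-weight w could blend the
# two cost measurements: w = 1.0 -> cost with output channels fixed,
# w = 0.0 -> cost with input channels fixed (HALP-style),
# intermediate w -> convex combination of the two.

def blended_channel_cost(cost_out_fixed, cost_in_fixed, w):
    assert 0.0 <= w <= 1.0, "weight must lie in [0, 1]"
    return w * cost_out_fixed + (1.0 - w) * cost_in_fixed

# e.g. per-channel latency-cost estimates in arbitrary units
print(blended_channel_cost(10.0, 6.0, 1.0))  # 10.0 (output channels fixed)
print(blended_channel_cost(10.0, 6.0, 0.0))  # 6.0  (input channels fixed)
print(blended_channel_cost(10.0, 6.0, 0.5))  # 8.0  (even blend)
```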

Soft channel pruning only supports a limited set of architectures. We automatically detect the channel structure of the network (which layers must be pruned together, which layers can be layer-pruned, etc.); this detection logic is only known to work for standard ResNet architectures, MobileNetV1, and SSD512-RN50. Main limitations:

Getting pruned model and measuring latency

Once training is complete, the slimmed model can be obtained using the method in model_clean.py (which uses channel_slimmer.py internally). This removes the pruned channels and saves the network in its entirety (rather than as a state dict; see this for more details). The code does not support saving/loading just the slimmed state dict.
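Conceptually, slimming physically drops the masked-out input channels from each weight tensor. Here is a minimal dependency-free sketch of that idea (ours, not the repo's channel_slimmer logic), with a weight stored as a nested list of shape [out_channels][in_channels]:

```python
# Illustrative sketch: keep only the input channels whose mask entry is True,
# shrinking the weight from [out][in] to [out][sum(mask)].

def slim_input_channels(weight, keep_mask):
    """Drop input channels where keep_mask is False."""
    assert all(len(row) == len(keep_mask) for row in weight)
    return [[w for w, keep in zip(row, keep_mask) if keep] for row in weight]

weight = [[1, 2, 3],
          [4, 5, 6]]                     # 2 output x 3 input channels
slimmed = slim_input_channels(weight, [True, False, True])
print(slimmed)  # [[1, 3], [4, 6]]
```

The real slimmer must also propagate this across connected layers (the preceding layer loses the matching output channels), which is why the channel-structure detection above matters.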

For measuring latency, we can simply load the cleaned model back up and time the forward pass as usual.
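A typical timing loop is warmup followed by an averaged wall-clock measurement. The helper below is a hedged sketch (not repo code); on GPU you would additionally synchronize around the timers, e.g. with torch.cuda.synchronize():

```python
# Illustrative latency measurement: warm up, then average wall-clock time
# over repeated forward calls. `forward` is any zero-argument callable,
# e.g. lambda: model(inputs) with gradients disabled.
import time

def measure_latency_ms(forward, warmup=5, iters=20):
    for _ in range(warmup):            # warmup: let caches/clocks settle
        forward()
    start = time.perf_counter()
    for _ in range(iters):
        forward()
    return (time.perf_counter() - start) / iters * 1000.0

# usage with a stand-in workload in place of a model forward pass
print(f"{measure_latency_ms(lambda: sum(range(10000))):.3f} ms")
```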

Training setup

This repository uses PyTorch Lightning to handle most of the training intricacies, including (but not limited to):

PyTorch Lightning exposes a convenient callback mechanism for integrating custom behavior. We implement a DynamicPruning callback class that connects our pruning code (which does not depend on PyTorch Lightning) to the PyTorch Lightning training loop.
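The pattern is that the trainer invokes named hooks on registered callbacks at fixed points in the loop. The sketch below illustrates that pattern without depending on Lightning; the class and hook body are simplified stand-ins, not the actual DynamicPruning implementation:

```python
# Illustrative callback pattern: a trainer loop calls hook methods on each
# registered callback. A pruning callback could, for example, refresh its
# masks every `interval` steps from inside such a hook.

class Callback:
    def on_train_batch_end(self, step):
        pass

class DynamicPruningSketch(Callback):
    """Stand-in callback: re-solve the pruning mask every `interval` steps."""
    def __init__(self, interval=2):
        self.interval = interval
        self.mask_updates = 0

    def on_train_batch_end(self, step):
        if step % self.interval == 0:
            self.mask_updates += 1   # would recompute importances/masks here

def run_training(callbacks, num_steps):
    for step in range(num_steps):
        # ... forward/backward/optimizer step would go here ...
        for cb in callbacks:
            cb.on_train_batch_end(step)

pruning = DynamicPruningSketch(interval=2)
run_training([pruning], num_steps=6)
print(pruning.mask_updates)  # 3 (fires at steps 0, 2, 4)
```

Keeping the pruning logic behind hooks like this is what lets it stay independent of Lightning itself.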

Experiments

Image Classification on ImageNet/CIFAR10

Code located in folder Classification

Run ResNet50 on ImageNet without pruning:

 python -m smcp.classification.image_classifier --dataset Imagenet --data-root=/some/path --gpus=1 --fp16

With dynamic input channel pruning:

 ... --prune --channel-type=Global --channel-ratio=0.3

See the full set of command line arguments here.

Object Detection on Pascal VOC

Code located in folder Object Detection

Run SSD512-RN50 on PascalVOC without pruning:

 python -m smcp.detection.object_detection --dataset PascalVOC --data-root=/some/path --gpus=1 --fp16

With dynamic input channel pruning:

 ... --prune --channel-type=Global --channel-ratio=0.3

Code layout and description (as of 7/19/22)

Classification

Detection

Sparse operations

License

Please check the LICENSE file. SMCP may be used non-commercially. For business inquiries, please contact researchinquiries@nvidia.com.

Citation

@inproceedings{Humble2022pruning,
  title={Soft Masking for Cost-Constrained Channel Pruning},
  author={Humble, Ryan and Shen, Maying and Albericio-Latorre, Jorge and Darve, Eric and Alvarez, Jose M},
  booktitle={European Conference on Computer Vision (ECCV)},
  year={2022}
}