Post-Training Sparsity-Aware Quantization
This repository is the official implementation of Post-Training Sparsity-Aware Quantization (SPARQ). If you use this code, please cite:
@article{shomron2021sparq,
title={Post-Training Sparsity-Aware Quantization},
author={Shomron, Gil and Gabbay, Freddy and Kurzum, Samer and Weiser, Uri},
journal={arXiv preprint arXiv:2105.11010},
year={2021}
}
Requirements
Install PyTorch. Specifically, we use version 1.5.1 with CUDA 10.1.
pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
Pick PyTorch 1.5.1 with the appropriate CUDA version from the official PyTorch website.
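To sanity-check the installation, a generic one-liner (not part of this repository) prints the installed PyTorch and CUDA versions, which should report 1.5.1 and 10.1:
python -c "import torch; print(torch.__version__, torch.version.cuda)"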
Then, install the other packages and our custom CUDA package:
pip install -r requirements.txt
cd cu_gemm_2x48
python setup.py install
The ImageNet path, as well as the seeds used to achieve the paper's results, are configured in Config.py.
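For reference, a hypothetical sketch of what the relevant Config.py entries might look like; the variable names here are assumptions, so consult the file itself:
# Hypothetical field names -- check Config.py in this repository.
IMAGENET_PATH = '/path/to/imagenet'  # root directory containing train/ and val/
SEED = 42                            # placeholder; the paper's seeds are set here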
Throughout this work, we used Ubuntu 18.04, Python 3.6.9, and an NVIDIA TITAN V GPU.
8-bit Model Quantization
SPARQ operates on top of 8-bit models. To quantize the models, execute the following command:
python ./main.py -a resnet18_imagenet --action QUANTIZE --x_bits 8 --w_bits 8
We support the following models: resnet18_imagenet, resnet34_imagenet, resnet50_imagenet, resnet101_imagenet, googlenet_imagenet, inception_imagenet, densenet_imagenet.
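For example, the same command quantizes ResNet-50 by swapping the architecture argument:
python ./main.py -a resnet50_imagenet --action QUANTIZE --x_bits 8 --w_bits 8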
SPARQ Evaluation
To evaluate the quantized models, execute the following:
python ./main.py -a resnet18_imagenet --action INFERENCE --chkp [PATH^] --x_bits 8 --w_bits 8 --eval --round_mode RAW --shift_opt 5 --bit_group 4
where round_mode is either RAW or ROUND, shift_opt 5, 3, and 2 correspond to opt5, opt3, and opt2, respectively, and bit_group sets the window bit-width to 4, 3, or 2.
[^] PATH points to the quantized model and is relative to the data/results folder.
[^^] shift_opt 5 paired with bit_group 3 and 2 corresponds to opt6 and opt7, respectively.
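For example, an opt6 evaluation uses the same command as above with a 3-bit window:
python ./main.py -a resnet18_imagenet --action INFERENCE --chkp [PATH] --x_bits 8 --w_bits 8 --eval --round_mode RAW --shift_opt 5 --bit_group 3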
[^^^] In the paper, we used --batch_size 32 with InceptionV3.
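To build intuition for these options, below is a minimal, illustrative Python sketch (not the repository's implementation) of trimming an 8-bit value to a narrow bit window anchored at its most significant set bit, in the spirit of bit_group and round_mode; the restricted set of window placements governed by shift_opt is simplified away here:
def trim_to_window(x, window=4, rnd=False):
    # Keep `window` bits starting at the most significant set bit of x.
    assert 0 <= x < 256, "illustration assumes an unsigned 8-bit value"
    if x == 0:
        return 0
    msb = x.bit_length() - 1             # position of the leading 1 bit
    shift = max(0, msb - (window - 1))   # align the window to the MSB
    kept = x >> shift                    # RAW: truncate the lower bits
    if rnd and shift > 0:                # ROUND: round to nearest instead
        kept = min(kept + ((x >> (shift - 1)) & 1), (1 << window) - 1)
    return kept << shift                 # place the window back in position

print(trim_to_window(0b10111110, rnd=False))  # 176 (0b10110000)
print(trim_to_window(0b10111110, rnd=True))   # 192 (0b11000000)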
Pre-trained 2:4 Pruned Models
Models were trained using this PyTorch script and NVIDIA's ASP library. We trained for 90 epochs with a learning rate starting at 0.1 and divided by 10 at epochs 30 and 60. Weight decay and momentum were set to 0.0001 and 0.9, respectively. The pruned models can be downloaded from the links in the table below.
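As a hedged sketch of that recipe, following ASP's documented prune_trained_model entry point (the training loop and data pipeline are elided; this is not the exact script we used):
import torch, torchvision
from apex.contrib.sparsity import ASP

model = torchvision.models.resnet18(pretrained=True).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=0.0001)
# Compute and attach 2:4 sparsity masks to the prunable layers.
ASP.prune_trained_model(model, optimizer)
# ...then retrain with the 90-epoch schedule described above...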
| Model | Top-1 | Download |
|---|---|---|
| ResNet-18 | 69.77% | Link |
| ResNet-50 | 76.16% | Link |
| ResNet-101 | 77.38% | Link |
To evaluate SPARQ on top of the pruned models, quantize the models as before and add the --stc flag to the evaluation command line.
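For example, with a pruned checkpoint and the same flags as the evaluation command above:
python ./main.py -a resnet18_imagenet --action INFERENCE --chkp [PATH] --x_bits 8 --w_bits 8 --eval --round_mode RAW --shift_opt 5 --bit_group 4 --stc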