Prune Your Model Before Distill It

Welcome

This is a PyTorch implementation of the paper "Prune Your Model Before Distill It".

Table of Contents

1 How to run

2 ResNet Optimization

<a name=how_to_run></a>1 How to run

1.1 Full framework

python main.py --pre_train --pruning --kd

This runs the full pipeline in sequence: pre-training, teacher pruning, and knowledge distillation from the pruned teacher to the student.

1.2 Train Model

The following command will train the model.

python main.py --pre_train

1.3 Prune the Teacher

python main.py --pruning
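
The pruning criterion and schedule actually used by the repository are defined in its code. Purely as an illustration of what a pruning step can look like, the sketch below applies global magnitude pruning to all convolutional weights with `torch.nn.utils.prune`; the 30% sparsity level, the L1 criterion, and the `resnet18` teacher are assumptions, not the paper's setting.

    import torch.nn as nn
    import torch.nn.utils.prune as prune
    from torchvision.models import resnet18

    # Hypothetical Tiny-ImageNet teacher; the repository builds its own models.
    model = resnet18(num_classes=200)

    # Collect (module, parameter_name) pairs for every convolution in the network.
    params_to_prune = [
        (m, "weight") for m in model.modules() if isinstance(m, nn.Conv2d)
    ]

    # Zero out the 30% of conv weights with the smallest magnitude, measured globally.
    prune.global_unstructured(
        params_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=0.3,
    )

    # Make the pruning permanent (drop the masks and the re-parametrization).
    for module, name in params_to_prune:
        prune.remove(module, name)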

1.4 Distill the Teacher into the Student

python main.py --kd
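
The `--kd` stage distills the (pruned) teacher into the student. As a point of reference, the sketch below shows the standard softened-logit distillation loss (Hinton et al.); the temperature, loss weighting, and any additional feature-level terms used by the repository may differ.

    import torch.nn.functional as F

    def kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.9):
        # Soft targets: KL divergence between softened teacher and student outputs.
        # T and alpha are illustrative values, not the repository's settings.
        soft = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction="batchmean",
        ) * (T * T)
        # Hard targets: ordinary cross-entropy against the ground-truth labels.
        hard = F.cross_entropy(student_logits, labels)
        return alpha * soft + (1.0 - alpha) * hard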

1.5 Available models

    cifar100_models = {
        'vgg11',
        'vgg19',
        'vgg19-rwd-cl1',
        'vgg19-rwd-cl2',
        'vgg19-rwd-st36',
        'vgg19-rwd-st59',
        'vgg19-rwd-st79',
        'vgg19dbl',
        'vgg19dbl-rwd-st36',
        'vgg19dbl-rwd-st59',
        'vgg19dbl-rwd-st79',
        'vgg-custom'
    }
    tiny_imagenet_models = {
        'vgg16',
        'resnet18',
        'resnet18-rwd-st36',
        'resnet18-rwd-st59',
        'resnet18-rwd-st79',
        'resnet18dbl',
        'resnet18dbl-rwd-st36',
        'resnet18dbl-rwd-st59',
        'resnet18dbl-rwd-st79',
        'resnet50',
        'resnet50-rwd-st36',
        'resnet50-rwd-st59',
        'resnet50-rwd-st79',        
        'mobilenet-v2'
    }

<a name=resnet_opt></a>2 ResNet Optimization

If the max-pooling layer is used after conv1, a larger batch size and learning rate should be used, and the accuracy gain obtained by using the pruned teacher may decrease slightly.
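
For context, the sketch below contrasts the two stem variants the note refers to: the standard ImageNet-style 7x7 stride-2 conv1 followed by max-pooling, and a small-input variant with a 3x3 stride-1 conv1 and the max-pooling layer removed, which is common for 64x64 Tiny-ImageNet models. Which variant the repository uses, and the matching batch size and learning rate, are set in its code; this snippet is only illustrative.

    import torch.nn as nn
    from torchvision.models import resnet18

    def resnet18_with_maxpool(num_classes=200):
        # Standard ImageNet-style stem: 7x7 stride-2 conv1 followed by 3x3 max-pooling.
        # This is the case where the note above recommends a larger batch size and
        # learning rate.
        return resnet18(num_classes=num_classes)

    def resnet18_small_input(num_classes=200):
        # Small-input stem: 3x3 stride-1 conv1 with the max-pooling layer removed,
        # so early feature maps keep more spatial resolution on 64x64 inputs.
        model = resnet18(num_classes=num_classes)
        model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        model.maxpool = nn.Identity()
        return model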