Home

Awesome

Data-Free Knowledge Distillation For Deep Neural Networks

<div align="center"> <img alt="Production pipeline image" src="imgs/production_pipeline.png" /> </div>

Abstract

Recent advances in model compression have provided procedures for compressing large neural networks to a fraction of their original size while retaining most if not all of their accuracy. However, all of these approaches rely on access to the original training set, which might not always be possible if the network to be compressed was trained on a very large non-public dataset. In this work, we present a method for data-free knowledge distillation, which is able to compress deep models to a fraction of their size leveraging only some extra metadata to be provided with a pretrained model release. We also explore different kinds of metadata that can be used with our method, and discuss tradeoffs involved in using each of them.

Paper

The paper is currently under review for AISTATS 2018 and NIPS 2017 LLD Workshop. You can feel free to read the arxiv or the poster for this work, or contact me with any questions.

Raphael Gontijo LopesStefano Fenu

Overview

Our method for knowledge distillation has a few different steps: training, computing layer statistics on the dataset used for training, reconstructing (or optimizing) a new dataset based solely on the trained model and the activation statistics, and finally distilling the pre-trained "teacher" model into the smaller "student" network. Each of these steps constitute a "procedure", which are implemented in the procedures/ module. Each procedure implements a run function, which does everything from loading models to training.

When optimizing a dataset reconstruction, there's also the choice of different optimization objectives (top layer, all layers, spectral all layers, spectral layer pairs, all discussed in the paper). These are implemented in procedures/_optimization_objectives.py, and take care of creating the optimization and loss operations, as well as of sampling from the saved activation statistics and creating a feed_dict that loads all necessary placeholders.

Every dataset goes under datasets/, and needs to implement the same interface as datasets/mnist.py. Namely, the dataset class needs to have an io_size property that specifies the input size and the label output size. It also needs two iterator methods: train_epoch_in_batches and test_epoch_in_batches.

Note: Credit for the attribute and data files of CelebA dataset is given to this repo.

We provide four models in models/: two fully connected and two convolutional. The fully connected models are hinton-1200 and hinton-800, as described in the original knowledge distillation paper. The convolutional models are LeNet-5, and a modified version of it which has half the number of convolutional filters per layer. Each model implemented to be a teacher network needs to implement all three functions in the interface: create_model, load_model, and load_and_freeze_model. If a model is meant to be a student network, like lenet_half and hinton800, then it need only implement create_model.

Every artifact created will be saved under summaries/, the default --summary_folder. This includes tf summaries, checkpoints, optimized datasets, log files with information about the experiment run, activation statistics, etc.

On the newly added VGG11, 16 and 19 models, there is an option to initialize the layers with ImageNet pre-trained layers. You can get those *.npy files here.

Requirements

This code requires that you have tensorflow 1.0 installed, along with numpy and scikit-image 0.13.0 on python 3.6+.

The visualization scripts (used to debug optimized/reconstructed datasets) also require opencv 3.2.0 and matplotlib.

Usage

Train and Save a Model

First, we need to have the model trained on the original dataset. This step can be skipped if you already have a pre-trained model that you can easily load through the same interface as the ones in models/.

The procedure flag specifies what to do with the model and dataset. In this case, train it from scratch.

python main.py --run_name=experiment --model=hinton1200 --dataset=mnist \
    --procedure=train

Compute and Save Statistics for that Model

We use the original dataset to compute layer statistics for the model. These are the "metadata" mentioned in the paper, which we save so we can reconstruct a dataset representative of the original one.

The model_meta and model_checkpoint flags are required because the compute_stats procedure loads a pre-trained model. If you are planning on optimizing a dataset with a spectral optimization objective, you need to compute stats with the flag compute_graphwise_stats=True. The reason why this is not done by default is because graphwise statistics are computationally expensive.

python main.py --run_name=experiment --model=hinton1200 --dataset=mnist \
    --procedure=compute_stats \
    --model_meta=summaries/experiment/train/checkpoint/hinton1200-8000.meta \
    --model_checkpoint=summaries/experiment/train/checkpoint/hinton1200-8000

Optimize a Dataset Using the Saved Model and the Statistics

This is where the real magic happens. We use the saved metadata and the pre-trained model (but not the original dataset) to reconstruct/optimize a new dataset that maximally reconstruct samples from the activation statistics. These samples and the corresponding objective loss can take different forms (top_layer, all_layers, all_layers_dropout, spectral_all_layers, spectral_layer_pairs), which are discussed in the paper. Note that all_layers_dropout is meant for teacher models that are trained with dropout. Currently, we only provide hinton1200 that does. Also note that spectral optimization objectives require that the compute_graphwise_stats be set when running compute_stats

The pre-trained model is loaded, and a new graph is constructed using its saved weights, but as tf.constant. This ensures that the only thing being back-propagated to is the input tf.Variable, which is initialized to random noise.

The optimization_objective flag is needed to determine what loss to use (see paper for details, coming soon on arxiv). The dataset flag is only needed to determine io_size, so if you're using a pre-trained model+statistics that you don't have the original data for, you can mock the dataset class and simply provide the self.io_size attribute. Using all of this, a new dataset will be reconstructed and saved.

python main.py --run_name=experiment --model=hinton1200 --dataset=mnist \
    --procedure=optimize_dataset \
    --model_meta=summaries/experiment/train/checkpoint/hinton1200-8000.meta \
    --model_checkpoint=summaries/experiment/train/checkpoint/hinton1200-8000 \
    --optimization_objective=top_layer --lr=0.07
    # or all_layers, spectral_all_layers, spectral_layer_pairs

Distilling a Model Using One of the Reconstructed Datasets

You can then train a student network on the reconstructed dataset, and the temperature-scaled teacher model activations. This time, the dataset flag is the location where the reconstructed dataset was saved. Additionally, a student_model needs to be specified to be trained from scratch. If you want to evaluate the student's performance on the original test set (if you have access to it), you can specify it as the eval_dataset.

python main.py --run_name=experiment --model=hinton1200 \
    --dataset="summaries/experiment/data/data_optimized_top_layer_experiment_<clas>_<batch>.npy" \
    --procedure=distill \
    --model_meta=summaries/experiment/train/checkpoint/hinton1200-8000.meta \
    --model_checkpoint=summaries/experiment/train/checkpoint/hinton1200-8000 \
    --eval_dataset=mnist --student_model=hinton800 --epochs=30 --lr=0.00001

Distilling a Model Using Vanilla Knowledge Distillation

If you do have access to the original dataset, or you want to run Hinton's original Knowledge Distillation paper, you can just specify that dataset flag.

In order to run this baseline, you only need to have a pre-trained model, and the dataset it was originally trained with. This means you can skip the compute_stats and optimize_dataset steps.

python main.py --run_name=experiment --model=hinton1200 --dataset=mnist \
    --procedure=distill \
    --model_meta=summaries/experiment/train/checkpoint/hinton1200-8000.meta \
    --model_checkpoint=summaries/experiment/train/checkpoint/hinton1200-8000 \
    --eval_dataset=mnist --student_model=hinton800 --lr=0.0001

Tips and Tricks

When using the lenet models, it should be noted that the original paper specified that mnist was resized from 28x28 pixel images to 32x32. Thus, then using the convolutional models we provide, make sure to use the mnist_conv dataset, which automatically resize the input images. The rest of the usage should be exactly the same.

python main.py --run_name=experiment --model=lenet --dataset=mnist_conv \
    --procedure=train

python main.py --run_name=experiment --model=lenet --dataset=mnist_conv \
    --procedure=compute_stats \
    --model_meta=summaries/experiment/train/checkpoint/lenet-8000.meta \
    --model_checkpoint=summaries/experiment/train/checkpoint/lenet-8000

python main.py --run_name=experiment --model=lenet --dataset=mnist_conv \
    --procedure=optimize_dataset \
    --model_meta=summaries/experiment/train/checkpoint/lenet-8000.meta \
    --model_checkpoint=summaries/experiment/train/checkpoint/lenet-8000 \
    --optimization_objective=top_layer
    # or all_layers, spectral_all_layers, spectral_layer_pairs

python main.py --run_name=experiment --model=lenet \
    --dataset="summaries/experiment/data/data_optimized_top_layer_experiment_<clas>_<batch>.npy" \
    --procedure=distill \
    --model_meta=summaries/experiment/train/checkpoint/lenet-8000.meta \
    --model_checkpoint=summaries/experiment/train/checkpoint/lenet-8000 \
    --eval_dataset=mnist_conv --student_model=lenet_half --epochs=30

Visualization

The viz/ directory contains useful scripts you can run to visualize the saved statistics and optimized datasets.

Print top layer per class means and standard deviations

python viz/print_stats.py --run_name=experiment

Vizualize per-class and per-pixel means, as well as a randomly selected example of an optimized dataset.

Note that pixel_intensities_batch.py expects a single batch, so you have to add the specific batch you want to visualize. Since batches are saved separately, we no longer have a script to visualize all batches at once.

python viz/pixel_intensities_batch.py \
    --dataset=summaries/experiment/data/data_optimized_all_layers_dropout_experiment_0_0.npy
<div align="center"> <img alt="Per-class and per-pixel means and randoms image" src="imgs/means_and_random.png" width="90%"/> </div>

Compare per-class, per-output normal distribution for both student and teacher.

Note that this requires that you have run compute_stats on a distilled student model. You might also want to play with the script to make each subplot zoomed in the best area (or apply a softmax over each mean/stddev pair of the statistics).

python viz/stats_viz.py \
    --teacher_stats=summaries/experiment/stats/activation_stats_experiment.npy \
    --student_stats=summaries/train_distilled_student/stats/activation_stats_train_distilled_student.npy

See what a sample from the top layer statistics looks like.

python viz/get_stats_sample.py --run_name=experiment

License

MIT

Appendix

For a longer description of each optimization objective, please refeer to the paper, which will be published on arxiv soon. In the meantime, if this readme was not sufficient, here are some diagrams from the paper:

<table> <tr> <td colspan="2" align="center"> <img alt="Hinton's Knowledge Distillation Diagram" src="imgs/pure_distill.png" width="50%"/><br/> </td> </tr> <tr> <td colspan="2" align="center"> <i>Hinton's Knowledge Distillation</i> </td> </tr> <tr> <td align="center"> <img alt="Top Layer Input Reconstruction and Distillation Diagram" src="imgs/top_layer.png" ><br/> </td> <td align="center"> <img alt="All Layers Input Reconstruction and Distillation Diagram" src="imgs/all_layers.png" /><br/> </td> </tr> <tr> <td align="center"> <i>Top Layer Input Reconstruction and Distillation</i> </td> <td align="center"> <i>All Layers Input Reconstruction and Distillation</i> </td> </tr> <tr> <td align="center"> <img alt="Spectral All Layers Input Reconstruction and Distillation Diagram" src="imgs/spectral_all_layers.png" /><br/> </td> <td align="center"> <img alt="Spectral Layer Pairs Input Reconstruction and Distillation Diagram" src="imgs/spectral_layer_pairs.png" /><br/> </td> </tr> <tr> <td align="center"> <i>Spectral All Layers Input Reconstruction and Distillation</i> </td> <td align="center"> <i>Spectral Layer Pairs Input Reconstruction and Distillation</i> </td> </tr> </tr> </table>

TODO(sfenu3): accelerate spectral optimization objectives