Channel Distillation

PyTorch implementation of Channel Distillation: Channel-Wise Attention for Knowledge Distillation

Innovation

  1. Channel Distillation (CD)
  2. Guided Knowledge Distillation (GKD)
  3. Early Decay Teacher (EDT)
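
To make the idea behind CD concrete, here is a minimal, framework-free sketch. It assumes that channel attention is the global average of each channel's activations and that the loss is the mean squared gap between student and teacher attention, summed over the distilled stages. It also assumes student and teacher have the same number of channels at each stage. The function names are illustrative only; losses/cd_loss.py is the authoritative implementation and may differ in detail.

```python
def channel_attention(feature_map):
    """Global average pool: one scalar per channel of a [C][H][W] nested list."""
    return [
        sum(sum(row) for row in channel) / (len(channel) * len(channel[0]))
        for channel in feature_map
    ]


def cd_loss(student_maps, teacher_maps):
    """Mean squared gap in channel attention, summed over stages (assumed form)."""
    loss = 0.0
    for s_map, t_map in zip(student_maps, teacher_maps):
        s_att = channel_attention(s_map)
        t_att = channel_attention(t_map)
        loss += sum((s - t) ** 2 for s, t in zip(s_att, t_att)) / len(s_att)
    return loss


# One stage, two channels, 2x2 spatial size (made-up numbers).
student = [[[[1.0, 3.0], [1.0, 3.0]], [[0.0, 0.0], [0.0, 0.0]]]]
teacher = [[[[2.0, 2.0], [2.0, 2.0]], [[1.0, 1.0], [1.0, 1.0]]]]

if __name__ == "__main__":
    print(channel_attention(student[0]))
    print(cd_loss(student, teacher))
```

Note that the first channel contributes nothing to the loss even though the student and teacher activations differ spatially: only the per-channel average is compared.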

Note

In our code, kdv2 refers to GKD and lrdv2 refers to EDT.
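
As a rough illustration of what GKD does, the sketch below assumes that the distillation term is a KL divergence between softened teacher and student outputs, and that it is applied only to samples the teacher classifies correctly. The helper names, the averaging, and the absence of any temperature scaling of the loss are assumptions; losses/kd_loss.py is the authoritative implementation.

```python
import math


def softmax(logits, T=1.0):
    """Numerically stable softmax with temperature T."""
    m = max(logits)
    exps = [math.exp((x - m) / T) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def guided_kd_loss(student_logits, teacher_logits, labels, T=1.0):
    """Average KL(teacher || student) over samples the teacher gets right."""
    losses = []
    for s, t, y in zip(student_logits, teacher_logits, labels):
        teacher_pred = max(range(len(t)), key=lambda i: t[i])
        if teacher_pred != y:
            continue  # teacher is wrong on this sample, so it gives no guidance
        p_t, p_s = softmax(t, T), softmax(s, T)
        losses.append(
            sum(pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s) if pt > 0)
        )
    # If the teacher is wrong on every sample, there is nothing to distill.
    return sum(losses) / len(losses) if losses else 0.0
```

With T = 1, as in the configurations in this README, the softmax is the ordinary one.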

Structure of Repository

├── cifar_config.py  # Hyperparameters
├── cifar_train.py
├── data
│   └── directory_of_data.md
├── imagenet_config.py  # Hyperparameters
├── imagenet_train.py
├── losses
│   ├── cd_loss.py  # CD Loss
│   ├── ce_loss.py
│   ├── __init__.py
│   └── kd_loss.py  # GKD Loss
├── models
│   ├── channel_distillation.py  # Distillation Network
│   ├── __init__.py
│   └── resnet.py
├── pretrain
│   └── path_of_teacher_checkpoint.md
├── README.md
└── utils
    ├── average_meter.py
    ├── data_prefetcher.py
    ├── __init__.py
    ├── logutil.py
    ├── metric.py
    └── util.py  # Early Decay Teacher

Requirements

python >= 3.7
torch >= 1.4.0
torchvision >= 0.5.0

Experiments

ImageNet

Prepare Dataset

Images should be arranged as follows:

./data/ILSVRC2012/train/dog/xxx.png
./data/ILSVRC2012/train/cat/xxy.png
./data/ILSVRC2012/val/dog/xxx.png
./data/ILSVRC2012/val/cat/xxy.png
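
Before launching a long run, the layout above can be sanity-checked with a few lines of standard-library Python. This is a hypothetical helper, not part of the repository. It assumes that train and val are expected to contain the same class folders, as in the example paths.

```python
from pathlib import Path


def class_dirs(split_dir):
    """Return the sorted class-folder names directly under a split directory."""
    return sorted(p.name for p in Path(split_dir).iterdir() if p.is_dir())


def check_layout(root):
    """True if root has train/ and val/ with the same, non-empty set of class folders."""
    root = Path(root)
    train, val = root / "train", root / "val"
    if not train.is_dir() or not val.is_dir():
        return False
    train_classes = class_dirs(train)
    return bool(train_classes) and train_classes == class_dirs(val)


if __name__ == "__main__":
    print(check_layout("./data/ILSVRC2012"))
```

The check only looks at folder names; it does not open or count the images.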

Training

Note

Teacher checkpoint will be downloaded automatically.

Run the following command to launch the experiment.

CUDA_VISIBLE_DEVICES=0 python3 ./imagenet_train.py

To run other experiments, just modify the following losses in imagenet_config.py:

# CD
loss_list = [
    {"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv1"},
    {"loss_name": "CDLoss", "loss_rate": 6, "factor": 1, "loss_type": "fd_family", "loss_rate_decay": "lrdv1"},
]
# CD + GKD
loss_list = [
    {"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv1"},
    {"loss_name": "KDLossv2", "T": 1, "loss_rate": 1, "factor": 1, "loss_type": "kdv2_family", "loss_rate_decay": "lrdv1"},
    {"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9, "loss_type": "fd_family", "loss_rate_decay": "lrdv1"},
]
# CD + GKD + EDT
loss_list = [
    {"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv2"},
    {"loss_name": "KDLossv2", "T": 1, "loss_rate": 1, "factor": 1, "loss_type": "kdv2_family", "loss_rate_decay": "lrdv2"},
    {"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9, "loss_type": "fd_family", "loss_rate_decay": "lrdv2"},
]
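
A rough sketch of how such a loss_list could be consumed, assuming loss_rate acts as a multiplicative weight on each loss term and the training objective is the weighted sum. The helper total_loss and the loss values are made up for illustration. How factor and loss_rate_decay adjust loss_rate during training is defined in the repository code (see utils/util.py) and is not modeled here.

```python
def total_loss(loss_list, loss_values):
    """Weighted sum of per-loss values, keyed by loss_name (hypothetical helper)."""
    total = 0.0
    for entry in loss_list:
        total += entry["loss_rate"] * loss_values[entry["loss_name"]]
    return total


loss_list = [
    {"loss_name": "CELoss", "loss_rate": 1, "factor": 1,
     "loss_type": "ce_family", "loss_rate_decay": "lrdv1"},
    {"loss_name": "CDLoss", "loss_rate": 6, "factor": 1,
     "loss_type": "fd_family", "loss_rate_decay": "lrdv1"},
]

# Made-up per-batch loss values, for illustration only.
loss_values = {"CELoss": 2.0, "CDLoss": 0.5}

if __name__ == "__main__":
    print(total_loss(loss_list, loss_values))  # 1 * 2.0 + 6 * 0.5 = 5.0
```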

Result

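| Method | Model | Top-1 error (%) | Top-5 error (%) |
| --- | --- | --- | --- |
| teacher | ResNet34 | 26.73 | 8.74 |
| student | ResNet18 | 30.43 | 10.76 |
| KD | ResNet34-ResNet18 | 29.50 | 9.52 |
| CD (ours) | ResNet34-ResNet18 | 28.53 | 9.56 |
| CD+GKD (ours) | ResNet34-ResNet18 | 28.26 | 9.41 |
| CD+GKD+EDT (ours) | ResNet34-ResNet18 | 27.61 | 9.2 |

| Method | Model | Top-1 error (%) | Top-5 error (%) |
| --- | --- | --- | --- |
| teacher | ResNet34 | 26.73 | 8.74 |
| student | ResNet18 | 30.43 | 10.76 |
| KD | ResNet34-ResNet18 | 29.50 | 9.52 |
| FitNets | ResNet34-ResNet18 | 29.34 | 10.77 |
| AT | ResNet34-ResNet18 | 29.30 | 10.00 |
| RKD | ResNet34-ResNet18 | 28.46 | 9.74 |
| CD+GKD+EDT (ours) | ResNet34-ResNet18 | 27.61 | 9.2 |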
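
For reference, top-k error is the percentage of samples whose true label is not among the k highest-scoring classes. The sketch below is a plain-Python illustration of that definition with made-up scores. It is not the repository's utils/metric.py.

```python
def topk_error(scores, labels, k=1):
    """Percentage of samples whose true label is not among the k highest scores."""
    wrong = 0
    for row, label in zip(scores, labels):
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        if label not in topk:
            wrong += 1
    return 100.0 * wrong / len(labels)


# Four samples, three classes (made-up scores).
scores = [
    [0.10, 0.70, 0.20],  # highest score: class 1
    [0.50, 0.30, 0.20],  # highest score: class 0
    [0.20, 0.30, 0.50],  # highest score: class 2
    [0.40, 0.35, 0.25],  # highest score: class 0
]
labels = [1, 1, 2, 2]

if __name__ == "__main__":
    print(topk_error(scores, labels, k=1))  # 50.0
    print(topk_error(scores, labels, k=2))  # 25.0
```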

CIFAR100

Prepare Dataset

CIFAR100 dataset will be downloaded automatically.

Training

Note

Download the teacher checkpoint from here
Then, put the checkpoint in the pretrain directory

Run the following command to launch the experiment.

CUDA_VISIBLE_DEVICES=0 python3 ./cifar_train.py

To run other experiments, just modify the following losses in cifar_config.py:

# CD
loss_list = [
    {"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv1"},
    {"loss_name": "CDLoss", "loss_rate": 6, "factor": 1, "loss_type": "fd_family", "loss_rate_decay": "lrdv1"},
]
# CD + GKD
loss_list = [
    {"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv1"},
    {"loss_name": "KDLossv2", "T": 1, "loss_rate": 0.1, "factor": 1, "loss_type": "kdv2_family", "loss_rate_decay": "lrdv1"},
    {"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9, "loss_type": "fd_family", "loss_rate_decay": "lrdv1"},
]
# CD + GKD + EDT
loss_list = [
    {"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv2"},
    {"loss_name": "KDLossv2", "T": 1, "loss_rate": 0.1, "factor": 1, "loss_type": "kdv2_family", "loss_rate_decay": "lrdv2"},
    {"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9, "loss_type": "fd_family", "loss_rate_decay": "lrdv2"},
]

Result

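| Method | Model | Top-1 error (%) | Top-5 error (%) |
| --- | --- | --- | --- |
| teacher | ResNet152 | 19.09 | 4.45 |
| student | ResNet50 | 22.02 | 5.74 |
| KD | ResNet152-ResNet50 | 20.36 | 4.94 |
| CD (ours) | ResNet152-ResNet50 | 20.08 | 4.78 |
| CD+GKD (ours) | ResNet152-ResNet50 | 19.49 | 4.85 |
| CD+GKD+EDT (ours) | ResNet152-ResNet50 | 18.63 | 4.29 |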