# Knowledge Distillation from A Stronger Teacher (DIST)
Official implementation of the paper "Knowledge Distillation from A Stronger Teacher" (DIST), NeurIPS 2022.
By Tao Huang, Shan You, Fei Wang, Chen Qian, Chang Xu.
:fire: DIST: a simple and effective KD method.
## Updates
- December 27, 2022: Update CIFAR-100 distillation code and logs.
- September 20, 2022: Release code for semantic segmentation task.
- September 15, 2022: DIST was accepted by NeurIPS 2022!
- May 30, 2022: Code for object detection is available.
- May 27, 2022: Code for ImageNet classification is available.
## Getting started
### Clone training code
```shell
git clone https://github.com/hunto/DIST_KD.git --recurse-submodules
cd DIST_KD
```
The loss function of DIST is in `classification/lib/models/losses/dist_kd.py`.
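For reference, below is a minimal sketch of the loss following the paper's formulation: student and teacher logits are softened into probabilities, then matched by Pearson correlation both inter-class (each sample's prediction vector) and intra-class (each class's scores across the batch). The class name `DISTLoss` and the keyword defaults are illustrative; `dist_kd.py` is the authoritative implementation.

```python
import torch
import torch.nn as nn


def pearson_correlation(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Pearson correlation along the last dim = cosine similarity of centered vectors.
    a = a - a.mean(dim=-1, keepdim=True)
    b = b - b.mean(dim=-1, keepdim=True)
    return (a * b).sum(dim=-1) / (a.norm(dim=-1) * b.norm(dim=-1) + eps)


def inter_class_relation(y_s: torch.Tensor, y_t: torch.Tensor) -> torch.Tensor:
    # Match each sample's prediction vector (rows of the B x C prediction matrix).
    return (1.0 - pearson_correlation(y_s, y_t)).mean()


def intra_class_relation(y_s: torch.Tensor, y_t: torch.Tensor) -> torch.Tensor:
    # Match each class's scores across the batch (columns), via the transpose.
    return inter_class_relation(y_s.transpose(0, 1), y_t.transpose(0, 1))


class DISTLoss(nn.Module):
    """Sketch of DIST: beta * inter-class + gamma * intra-class relation losses."""

    def __init__(self, beta: float = 1.0, gamma: float = 1.0, tau: float = 1.0):
        super().__init__()
        self.beta, self.gamma, self.tau = beta, gamma, tau

    def forward(self, z_s: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor:
        # z_s, z_t: student / teacher logits of shape (batch, num_classes).
        y_s = (z_s / self.tau).softmax(dim=1)
        y_t = (z_t / self.tau).softmax(dim=1)
        inter = inter_class_relation(y_s, y_t)
        intra = intra_class_relation(y_s, y_t)
        return self.tau ** 2 * (self.beta * inter + self.gamma * intra)
```

During training this term is added to the ordinary cross-entropy loss on the ground-truth labels.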
- classification: prepare your environment and datasets following the README.md in `classification`.
- object detection: see the COCO Detection section below.
- semantic segmentation: see the Cityscapes Segmentation section below.
## Reproducing our results
### ImageNet
```shell
cd classification
sh tools/dist_train.sh 8 ${CONFIG} ${MODEL} --teacher-model ${T_MODEL} --experiment ${EXP_NAME}
```
- Baseline settings (R34-R18 and R50-MBV1): `CONFIG=configs/strategies/distill/resnet_dist.yaml`

  |Student|Teacher|DIST|MODEL|T_MODEL|Log|Ckpt|
  |---|---|---|---|---|---|---|
  |ResNet-18 (69.76)|ResNet-34 (73.31)|72.07|`tv_resnet18`|`tv_resnet34`|log|ckpt|
  |MobileNet V1 (70.13)|ResNet-50 (76.16)|73.24|`mobilenet_v1`|`tv_resnet50`|log|ckpt|
- Stronger teachers (R18 and R34 students with various ResNet teachers):

  |Student|Teacher|KD (T=4)|DIST|
  |---|---|---|---|
  |ResNet-18 (69.76)|ResNet-34 (73.31)|71.21|72.07|
  |ResNet-18 (69.76)|ResNet-50 (76.13)|71.35|72.12|
  |ResNet-18 (69.76)|ResNet-101 (77.37)|71.09|72.08|
  |ResNet-18 (69.76)|ResNet-152 (78.31)|71.12|72.24|
  |ResNet-34 (73.31)|ResNet-50 (76.13)|74.73|75.06|
  |ResNet-34 (73.31)|ResNet-101 (77.37)|74.89|75.36|
  |ResNet-34 (73.31)|ResNet-152 (78.31)|74.87|75.42|
- Stronger training strategies: `CONFIG=configs/strategies/distill/dist_b2.yaml`

  ResNet-50-SB: stronger ResNet-50 trained by TIMM (ResNet strikes back).

  |Student|Teacher|KD (T=4)|DIST|MODEL|T_MODEL|Log|
  |---|---|---|---|---|---|---|
  |ResNet-18 (73.4)|ResNet-50-SB (80.1)|72.6|74.5|`tv_resnet18`|`timm_resnet50`|log|
  |ResNet-34 (76.8)|ResNet-50-SB (80.1)|77.2|77.8|`tv_resnet34`|`timm_resnet50`|log|
  |MobileNet V2 (73.6)|ResNet-50-SB (80.1)|71.7|74.4|`tv_mobilenet_v2`|`timm_resnet50`|log|
  |EfficientNet-B0 (78.0)|ResNet-50-SB (80.1)|77.4|78.6|`timm_tf_efficientnet_b0`|`timm_resnet50`|log|
  |ResNet-50 (78.5)|Swin-L (86.3)|80.0|80.2|`tv_resnet50`|`timm_swin_large_patch4_window7_224`|log ckpt|
  |Swin-T (81.3)|Swin-L (86.3)|81.5|82.3|-|-|log|

  Swin-L student: We implement our DIST on the official code of Swin-Transformer.
### CIFAR-100
Download and extract the teacher checkpoints to your disk, then specify the path of the corresponding checkpoint `.pth` file using `--teacher-ckpt`:

```shell
cd classification
sh tools/dist_train.sh 1 configs/strategies/distill/dist_cifar.yaml ${MODEL} --teacher-model ${T_MODEL} --experiment ${EXP_NAME} --teacher-ckpt ${CKPT}
```
NOTE: For `MobileNetV2`, `ShuffleNetV1`, and `ShuffleNetV2`, `lr` and `warmup-lr` should be `0.01`:

```shell
sh tools/dist_train.sh 1 configs/strategies/distill/dist_cifar.yaml ${MODEL} --teacher-model ${T_MODEL} --experiment ${EXP_NAME} --teacher-ckpt ${CKPT} --lr 0.01 --warmup-lr 0.01
```
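For context, `--teacher-ckpt` only restores the teacher's weights; during distillation the teacher runs frozen in eval mode while the student minimizes cross-entropy plus the DIST term. Below is a minimal sketch of that pattern, where `build_model`, the checkpoint-key handling, and all other names are illustrative assumptions rather than the repo's actual API:

```python
import torch


def load_frozen_teacher(build_model, ckpt_path: str, device: str = "cuda") -> torch.nn.Module:
    # Restore a pretrained teacher from a .pth file and freeze it for distillation.
    teacher = build_model()
    state = torch.load(ckpt_path, map_location="cpu")
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]  # some checkpoints wrap weights under a key
    teacher.load_state_dict(state)
    teacher.eval()  # freeze batch-norm / dropout behavior
    for p in teacher.parameters():
        p.requires_grad_(False)  # no gradients flow into the teacher
    return teacher.to(device)


def train_step(student, teacher, criterion_ce, criterion_dist, images, targets):
    # One distillation step: total loss = cross-entropy + DIST term.
    with torch.no_grad():
        z_t = teacher(images)
    z_s = student(images)
    return criterion_ce(z_s, targets) + criterion_dist(z_s, z_t)
```

Here `criterion_dist` would be the DIST loss sketched in the Getting started section above.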
|Student|Teacher|DIST|MODEL|T_MODEL|Log|
|---|---|---|---|---|---|
|WRN-40-1 (71.98)|WRN-40-2 (75.61)|74.43±0.24|`cifar_wrn_40_1`|`cifar_wrn_40_2`|log|
|ResNet-20 (69.06)|ResNet-56 (72.34)|71.75±0.30|`cifar_resnet20`|`cifar_resnet56`|log|
|ResNet-8x4 (72.50)|ResNet-32x4 (79.42)|76.31±0.19|`cifar_resnet8x4`|`cifar_resnet32x4`|log|
|MobileNetV2 (64.60)|ResNet-50 (79.34)|68.66±0.23|`cifar_mobile_half`|`cifar_ResNet50`|log|
|ShuffleNetV1 (70.50)|ResNet-32x4 (79.42)|76.34±0.18|`cifar_ShuffleV1`|`cifar_resnet32x4`|log|
|ShuffleNetV2 (71.82)|ResNet-32x4 (79.42)|77.35±0.25|`cifar_ShuffleV2`|`cifar_resnet32x4`|log|
### COCO Detection
The training code is in `MasKD/mmrazor`. An example to train `cascade_mask_rcnn_x101-fpn_r50`:

```shell
sh tools/mmdet/dist_train_mmdet.sh configs/distill/dist/dist_cascade_mask_rcnn_x101-fpn_x50_coco.py 8 work_dirs/dist_cmr_x101-fpn_x50
```
|Student|Teacher|DIST|DIST+mimic|Config|Log|
|---|---|---|---|---|---|
|Faster RCNN-R50 (38.4)|Cascade Mask RCNN-X101 (45.6)|40.4|41.8|[DIST] [DIST+Mimic]|[DIST] [DIST+Mimic]|
|RetinaNet-R50 (37.4)|RetinaNet-X101 (41.0)|39.8|40.1|[DIST] [DIST+Mimic]|[DIST] [DIST+Mimic]|
### Cityscapes Segmentation
Detailed instructions for reproducing our results are in the `segmentation` folder (see its README).
|Student|Teacher|DIST|Log|
|---|---|---|---|
|DeepLabV3-R18 (74.21)|DeepLabV3-R101 (78.07)|77.10|log|
|PSPNet-R18 (72.55)|DeepLabV3-R101 (78.07)|76.31|log|
## License
This project is released under the Apache 2.0 license.
## Citation
```bibtex
@article{huang2022knowledge,
  title={Knowledge Distillation from A Stronger Teacher},
  author={Huang, Tao and You, Shan and Wang, Fei and Qian, Chen and Xu, Chang},
  journal={arXiv preprint arXiv:2205.10536},
  year={2022}
}
```