Matching Guided Distillation

Project Webpage | Paper | Zhihu Blog [知乎]

Updates

Introduction

This implementation is based on the official PyTorch ImageNet training code, which supports two training modes: DataParallel (DP) and DistributedDataParallel (DDP). MGD for object detection is also re-implemented as an external project in Detectron2.

[Figure: overview of MGD]

Note: T: teacher feature tensors; S: student feature tensors; d_p: distance function for distillation; C_i: the i-th channel.
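To make the matching idea concrete, here is a toy sketch of a channel-matching step as we read it: every teacher channel is assigned to one student channel, every student channel receives the same number of teacher channels, and the total distance is minimized. The distance matrix `dist` and the helper `balanced_match` are hypothetical, and the brute-force search only scales to a handful of channels; it is not the solver used in this repository (which depends on OR-Tools).

```python
from itertools import permutations


def balanced_match(dist, n_student):
    """Toy balanced channel matching by exhaustive search (hypothetical helper).

    dist[i][j]: assumed distance between teacher channel i and student channel j.
    Each student channel receives len(dist) // n_student teacher channels.
    Returns (assignment, cost) where assignment[i] is the student channel
    matched to teacher channel i.
    """
    n_teacher = len(dist)
    assert n_teacher % n_student == 0, "teacher channels must split evenly"
    k = n_teacher // n_student
    # Replicate every student channel k times so the balanced problem becomes
    # a plain one-to-one assignment between n_teacher teachers and n_teacher slots.
    slots = [j for j in range(n_student) for _ in range(k)]
    best_cost, best = float("inf"), None
    for perm in sorted(set(permutations(slots))):  # sorted for determinism
        cost = sum(dist[i][perm[i]] for i in range(n_teacher))
        if cost < best_cost:
            best_cost, best = cost, perm
    return list(best), best_cost


# 4 teacher channels, 2 student channels: each student channel gets 2 teachers.
dist = [[0, 5],
        [1, 6],
        [7, 0],
        [8, 1]]
print(balanced_match(dist, n_student=2))  # ([0, 0, 1, 1], 2)
```

Replicating the student channels turns the equal-capacity matching into a square assignment problem, which is the form an off-the-shelf linear-assignment solver accepts.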

BibTeX

```bibtex
@inproceedings{eccv20mgd,
    title     = {Matching Guided Distillation},
    author    = {Yue, Kaiyu and Deng, Jiangfan and Zhou, Feng},
    booktitle = {European Conference on Computer Vision (ECCV)},
    year      = {2020}
}
```

Software Version Used for Paper

Quick & Easy Start

As an example, we use ResNet-50 (teacher) to distill ResNet-18 (student), as shown in the figure below.

<div align="center"> <img src=".github/demo.png" width="333"> </div>

Note: models are from torchvision.

0. Install Dependencies

Install OR-Tools with `pip install ortools`.

1. Expose Intermediate Features

This function returns the intermediate features together with the final output logits. To implement it, copy the body of the model's original forward pass and return whichever intermediate tensors you want to use for distillation. Reference.

```python
import torch.nn.functional as F

def extract_feature(self, x, preReLU=False):
    ...

    feat3 = self.layer3(x)  # expose the layer3 output

    x = self.layer4(feat3)

    ...

    if not preReLU:
        feat3 = F.relu(feat3)

    return [feat3], x
```
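For a self-contained picture of what `extract_feature` must return, here is a sketch on a hypothetical toy model (`TinyNet`, not part of this repository). Its stage output is kept before the ReLU, so `preReLU=True` exposes the raw BN output (which contains negative values) while `preReLU=False` exposes the rectified version; the logits are the same either way.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical toy model whose stage output is left pre-ReLU."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
        self.stage = nn.Sequential(
            nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(16),  # no ReLU here: it is applied in extract_feature
        )
        self.fc = nn.Linear(16, num_classes)

    def extract_feature(self, x, preReLU=False):
        x = F.relu(self.stem(x))
        feat = self.stage(x)                  # tensor exposed for distillation
        x = F.relu(feat)                      # the network always uses the ReLU'd value
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        logits = self.fc(x)
        if not preReLU:
            feat = F.relu(feat)
        return [feat], logits

    def forward(self, x):
        return self.extract_feature(x)[1]     # normal forward returns logits only


net = TinyNet()
x = torch.randn(2, 3, 8, 8)
(pre,), logits_pre = net.extract_feature(x, preReLU=True)
(post,), logits_post = net.extract_feature(x, preReLU=False)
print(pre.shape, logits_pre.shape)  # torch.Size([2, 16, 8, 8]) torch.Size([2, 10])
```

Routing `forward` through `extract_feature` keeps a single copy of the forward logic, so the exposed features cannot drift out of sync with the model's real computation.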

2. Expose BN

This function returns the last BN layer before each distillation position. Reference.

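```python
# Bottleneck and BasicBlock are the block classes of the ResNet definition in use
def get_bn_before_relu(self):
    if isinstance(self.layer1[0], Bottleneck):
        bn3 = self.layer3[-1].bn3
    elif isinstance(self.layer1[0], BasicBlock):
        bn3 = self.layer3[-1].bn2
    else:
        # a bare `raise` outside an except block would itself fail with RuntimeError
        raise TypeError('unknown ResNet block type')

    return [bn3]
```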
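Hard-coding `bn3`/`bn2` ties the helper to two block types. A hypothetical, more generic alternative (our assumption, not the repository's code) is to take the last `nn.BatchNorm2d` registered inside the module that produces the distilled feature, since `modules()` yields submodules in registration order. In torchvision's `Bottleneck` and `BasicBlock` the downsample branch is registered after `bn3`/`bn2`, so this should only be applied to blocks without a downsample branch, such as the last block of a stage.

```python
import torch.nn as nn


def last_bn(module):
    """Return the last nn.BatchNorm2d registered inside `module` (hypothetical helper)."""
    bns = [m for m in module.modules() if isinstance(m, nn.BatchNorm2d)]
    if not bns:
        raise ValueError(f"no BatchNorm2d found in {type(module).__name__}")
    return bns[-1]


# Toy block mimicking a residual block's layout: conv -> BN -> ReLU -> conv -> BN.
block = nn.Sequential(
    nn.Conv2d(16, 16, 3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 32, 3, padding=1, bias=False),
    nn.BatchNorm2d(32),
)
print(last_bn(block).num_features)  # 32
```

In a ResNet, `get_bn_before_relu` could then return `[last_bn(self.layer3[-1])]`.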

3. Indicate Channel Number

This function tells MGD the number of channels of each exposed intermediate feature map. Reference.

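```python
def get_channel_num(self):
    # channels of each exposed feature map: 1024 for ResNet-50's layer3 output
    # (ResNet-18's layer3 outputs 256 channels)
    return [1024]
```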
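The returned list must match each exposed feature, so the teacher (ResNet-50, 1024 channels at `layer3`) and the student (ResNet-18, 256 channels) return different values. Because the BN before each distillation position normalizes exactly those channels, one option (an assumption for illustration) is to derive the numbers from `get_bn_before_relu` instead of hard-coding them. The sketch uses a hypothetical toy model, `ToyNet`.

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical model exposing one feature map after `self.stage`."""

    def __init__(self):
        super().__init__()
        self.stage = nn.Sequential(
            nn.Conv2d(3, 24, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(24),
        )

    def get_bn_before_relu(self):
        return [self.stage[1]]

    def get_channel_num(self):
        # one entry per exposed feature, read from the BN that precedes it
        return [bn.num_features for bn in self.get_bn_before_relu()]


net = ToyNet()
feat = net.stage(torch.randn(1, 3, 4, 4))
print(net.get_channel_num(), feat.shape[1])  # [24] 24
```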

4. Build Model

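```python
import mgd.builder

# ResNet definitions based on torchvision, extended with the three methods above
t_net = resnet50()  # teacher model
s_net = resnet18()  # student model

d_net = mgd.builder.MGDistiller(
    t_net,
    s_net,
    ignore_inds=[],
    reducer='amp',  # channel reducer; 'amp': absolute max pooling
    sync_bn=False,
    with_kd=True,
    preReLU=True,
    distributed=False,  # DP mode: False | DDP mode: True
    det=False,  # set to True when working within Detectron2
)
```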
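With `reducer='amp'`, the teacher channels matched to the same student channel are merged into one. The sketch below shows absolute max pooling as we understand it from the paper: at every position, keep the value with the largest magnitude among the matched teacher channels, sign included. The function `amp_reduce` and the `groups` format (one list of teacher-channel indices per student channel) are hypothetical and are not the interface of `mgd.builder`.

```python
import torch


def amp_reduce(t_feat, groups):
    """Absolute max pooling over matched teacher channels (hypothetical helper).

    t_feat: teacher features of shape (N, C_t, H, W).
    groups: list of length C_s; groups[j] lists the teacher channels
            matched to student channel j.
    Returns a tensor of shape (N, C_s, H, W).
    """
    reduced = []
    for idx in groups:
        g = t_feat[:, idx]                          # (N, k, H, W)
        pick = g.abs().argmax(dim=1, keepdim=True)  # channel with max |value|
        reduced.append(g.gather(1, pick))           # gather keeps the sign
    return torch.cat(reduced, dim=1)


# 4 teacher channels -> 2 student channels, 1x1 spatial map for readability.
t = torch.tensor([1.0, -3.0, 2.0, 0.5]).view(1, 4, 1, 1)
out = amp_reduce(t, groups=[[0, 1], [2, 3]])
print(out.flatten().tolist())  # [-3.0, 2.0]
```

If every student channel were matched to the same number of teacher channels, the ResNet-50 to ResNet-18 example would use 256 groups of 4 teacher channels each (1024 / 256).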

5. Add MGD Steps to the Training Procedure

Reference.

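```python
# initialize MGD parameters once before training starts
mgd_update(train_loader, d_net)

# training loop
for epoch in range(total_epochs):

    # UPDATE_FREQ is user-defined: re-run the update every UPDATE_FREQ epochs
    if (epoch + 1) % UPDATE_FREQ == 0:
        mgd_update(train_loader, d_net)

    # ... regular training and validation for this epoch ...
```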
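To see when the update actually runs, the loop can be simulated with a stub (`schedule` and its arguments are hypothetical). The condition uses `epoch + 1` with a 0-indexed `epoch`, so with `UPDATE_FREQ = 3` the updates happen at the start of epochs 2, 5, 8, ..., in addition to the initial call.

```python
def schedule(total_epochs, update_freq):
    """Return when mgd_update would be called; -1 marks the call before training."""
    calls = [-1]                        # initial mgd_update(train_loader, d_net)
    for epoch in range(total_epochs):
        if (epoch + 1) % update_freq == 0:
            calls.append(epoch)         # update before training this epoch
        # ... train one epoch here ...
    return calls


print(schedule(total_epochs=10, update_freq=3))  # [-1, 2, 5, 8]
```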

MGD In Tasks

Classification | Object Detection | Unsupervised Learning.

Acknowledgements

We learned from and reused parts of the code of the following projects, and we thank the authors for these excellent works:

License

MIT. See LICENSE for details.