# Knowledge Distillation via the Target-aware Transformer (CVPR 2022)
Codebase of our TaT on ImageNet. Refer to TaT-seg for the experiments on semantic segmentation.
## Overview
The entry point is `examples/image_classification.py`. TaT is implemented as `AttnEmbed`, and the loss function `MaskedFM` is decoupled from the model.
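As a rough illustration of what a decoupled masked feature-matching loss looks like (the function name, masking scheme, and shapes below are assumptions for the sketch, not the repo's actual `MaskedFM`):

```python
import torch

def masked_feature_matching(f_s, f_t, keep_ratio=0.5):
    """Illustrative masked feature matching (not the repo's exact MaskedFM):
    match student and teacher features only at a random subset of spatial
    positions, which is the rough idea behind masking the matching loss."""
    b, c, h, w = f_s.shape
    # Random binary mask over spatial positions, shared across channels.
    mask = (torch.rand(b, 1, h, w, device=f_s.device) < keep_ratio).float()
    diff = (f_s - f_t) ** 2 * mask
    # Normalize by the number of kept positions (and channels).
    return diff.sum() / (mask.sum() * c + 1e-8)

# Toy student/teacher features with matching shapes.
f_s = torch.randn(2, 16, 8, 8)
f_t = torch.randn(2, 16, 8, 8)
loss = masked_feature_matching(f_s, f_t)
print(loss.item())
```

Because the loss only consumes feature tensors, it stays independent of the model definition, which is what "decoupled" means here.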
## Note
- This codebase does not currently support resuming training. However, it does allow you to load a pre-trained model for specific purposes, e.g., distilling from a contrastive-learning model.
- The classification model is wrapped together with the learnable KD parameters. Be careful about which model parameters you save.
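To make the second point concrete, here is a minimal sketch of the pitfall (the `KDWrapper` class and its `proj` layer are hypothetical, not the repo's actual wrapper): the wrapper's `state_dict()` contains the KD parameters, so save the sub-module you actually want to deploy.

```python
import torch
import torch.nn as nn

class KDWrapper(nn.Module):
    """Hypothetical wrapper: the classifier plus learnable KD parameters."""
    def __init__(self, student):
        super().__init__()
        self.student = student
        self.proj = nn.Linear(16, 16)  # learnable KD-only parameter

    def forward(self, x):
        return self.student(x)

student = nn.Linear(8, 16)
wrapped = KDWrapper(student)

# wrapped.state_dict() includes "proj.*" (KD parameters) as well;
# saving only the classifier keeps the checkpoint clean for deployment.
torch.save(wrapped.student.state_dict(), "student.pth")
```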
## Customization
If you would like to customize your own model, place all of its learnable parameters here, and set up the computation of the loss function here.
We use forward hooks to extract the intermediate representations. Simply modify the yaml file to access the model layers of interest. This example notebook will give you a better idea of the usage. You may also refer to our config.
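A minimal sketch of the forward-hook mechanism PyTorch provides (the toy model below is an assumption; in the repo the hooked layers are selected via the yaml config):

```python
import torch
import torch.nn as nn

# Toy student model; in practice the layer names come from your yaml config.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1),
)

features = {}

def save_output(name):
    # Forward hook: called with (module, inputs, output) after each forward pass.
    def hook(module, inputs, output):
        features[name] = output
    return hook

# Attach the hook to the layer of interest (here, the second conv).
handle = model[2].register_forward_hook(save_output("feat"))

x = torch.randn(1, 3, 32, 32)
_ = model(x)
print(features["feat"].shape)  # the intermediate representation is captured

handle.remove()  # detach the hook when done
```

This is how intermediate features can be pulled out for distillation without modifying the model's `forward`.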
## Examples
### Requirements
- Python 3.7
- PyTorch 1.5
- einops
- ml-collections
### Before getting started
Please modify the ImageNet path in the config.
We use 8 GPUs with 256 images per GPU (an effective batch size of 2048).
### Training
```sh
sh ./train_local.sh
```
### Testing
```sh
sh ./test_local.sh
```
## Issues / Contact
Feel free to create an issue if you have a question, or just email me ( sihao.lin@student.rmit.edu.au ).
## Acknowledgement
This repo is built upon torchdistill. Thanks to Yoshitomo.