# Learning useful representations for shifting tasks and distributions

Official PyTorch implementation of the ICML 2023 paper *Learning useful representations for shifting tasks and distributions*.
<p align="center" width="500"> <image src='figures/story.png'/> </p>Requirements
- python==3.7
- torch>=1.13.1
- torchvision>=0.14.1
- pyyaml==6.0
- classy-vision==0.6.0
- gdown>=5.2.0
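
A quick sanity check of the environment (a minimal sketch; it only verifies that the packages above are installed and prints their versions):

```python
# Minimal environment check: print the installed version of each pinned package.
# pkg_resources is used because the repository targets Python 3.7.
import pkg_resources

for dist in ["torch", "torchvision", "pyyaml", "classy-vision", "gdown"]:
    print(dist, pkg_resources.get_distribution(dist).version)
```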
## Datasets
We consider the ImageNet and Inaturalist18 datasets. Download and extract them to `data/imagenet` and `data/inaturalist18`. The resulting folder structure should be:
```
📦 RRL
 ┣ 📂data
 ┃ ┣ 📂imagenet
 ┃ ┗ 📂inaturalist18
```
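
For reference, a minimal sketch of reading this layout with `torchvision` (the `train` split subfolder is an assumption about how the ImageNet archive is extracted; adjust the path to your setup):

```python
# Minimal sketch: load the extracted ImageNet folder with torchvision.
# The "train" subfolder name is an assumption about how the archive was unpacked.
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
])
train_set = datasets.ImageFolder("data/imagenet/train", transform=transform)
print(len(train_set), "images,", len(train_set.classes), "classes")
```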
## Supervised transfer learning (ResNet)

### Download (ImageNet1k) pretrained checkpoints:
You can get the pretrained checkpoints either:
- automatically, by running `python tools/download.py`, or
- manually, by following download_checkpoint.md, or
- by training them from scratch according to download_checkpoint.md.
The resulting folder structure should be:
```
📦 RRL
 ┣ 📂checkpoints
 ┃ ┣ 📂supervised_pretrain
 ┃ ┃ ┣ 📂resnet50
 ┃ ┃ ┃ ┣ 📜 checkpoint_run0.pth.tar
 ┃ ┃ ┃ ┃ ...
 ┃ ┃ ┃ ┗ 📜 checkpoint_run9.pth.tar
 ┃ ┃ ┣ 📜 2resnet50_imagenet1k_supervised.pth.tar
 ┃ ┃ ┣ 📜 4resnet50_imagenet1k_supervised.pth.tar
 ┃ ┃ ┣ 📜 resnet50w2_imagenet1k_supervised.pth.tar
 ┃ ┃ ┣ 📜 resnet50w4_imagenet1k_supervised.pth.tar
 ┃ ┃ ┗ 📜 resnet50_imagenet1k_supervised_distill5.pth.tar
```
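
Once downloaded, a checkpoint can be inspected with a short snippet like the one below (a minimal sketch; the dictionary layout inside the `.pth.tar` files is an assumption, so the snippet only prints whatever keys are present):

```python
# Minimal sketch: inspect a downloaded checkpoint.
# The top-level structure of the file is an assumption; printing the keys
# shows where the actual state_dict lives.
import torch

ckpt = torch.load(
    "checkpoints/supervised_pretrain/resnet50/checkpoint_run0.pth.tar",
    map_location="cpu",
)
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys()))
```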
### Transfer by Linear Probing, Fine-Tuning, and Two-stage Fine-Tuning:
Transfer the representations learned on ImageNet1k to Cifar10, Cifar100, and Inaturalist18 by:
- Linear probing: concatenate the pretrained representations and learn a large linear classifier on top (see the sketch after this list).
- (Normal) fine-tuning: concatenate the pretrained representations, then fine-tune all weights.
- (Two-stage) fine-tuning: fine-tune each pretrained representation on the target task separately, then concatenate the fine-tuned representations and apply linear probing.
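
For intuition, the snippet below is a minimal sketch of linear probing on concatenated representations; it is not the repository's training code, and two randomly initialized `torchvision` ResNet-50 backbones stand in for the pretrained checkpoints above:

```python
# Minimal sketch of linear probing on concatenated representations.
# Two torchvision ResNet-50 backbones stand in for the pretrained checkpoints;
# only the linear classifier on top of the concatenated features is trainable.
import torch
import torch.nn as nn
from torchvision import models

backbones = [models.resnet50(weights=None), models.resnet50(weights=None)]
for net in backbones:
    net.fc = nn.Identity()          # keep the 2048-d penultimate features
    net.eval()                      # frozen feature extractors
    for p in net.parameters():
        p.requires_grad = False

num_classes = 100                   # e.g. Cifar100
probe = nn.Linear(2048 * len(backbones), num_classes)

x = torch.randn(8, 3, 224, 224)     # a dummy batch of images
with torch.no_grad():
    feats = torch.cat([net(x) for net in backbones], dim=1)
logits = probe(feats)               # only `probe` receives gradients
print(logits.shape)                 # torch.Size([8, 100])
```

Normal fine-tuning differs only in that the backbone parameters stay trainable; two-stage fine-tuning first adapts each backbone to the target task separately before this concatenation step.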
The following table provides scripts for these transfer learning experiments:
| method  | architecture                  | target task      | linear probing | fine-tuning | two-stage fine-tuning |
|---------|-------------------------------|------------------|----------------|-------------|-----------------------|
| ERM     | resnet50                      | Cifar10/Cifar100 | scripts        | scripts     | -                     |
| ERM     | resnet50w2/w4, 2x/4x resnet50 | Cifar10/Cifar100 | scripts        | scripts     | -                     |
| CAT     | -                             | Cifar10/Cifar100 | scripts        | scripts     | scripts               |
| Distill | resnet50                      | Cifar10/Cifar100 | scripts        | scripts     | -                     |
| ERM     | resnet50                      | Inaturalist18    | scripts        | scripts     | -                     |
| ERM     | resnet50w2/w4, 2x/4x resnet50 | Inaturalist18    | scripts        | scripts     | -                     |
| CAT     | -                             | Inaturalist18    | scripts        | scripts     | scripts               |
| Distill | resnet50                      | Inaturalist18    | scripts        | scripts     | -                     |
The following figure (focus on the solid curves) shows the transfer learning performance of different representations (ERM / CAT / Distill) and transfer methods (linear probing / fine-tuning / two-stage fine-tuning).
<p align="center"> <image src='figures/imagenet_sl_tf_v3.png'/> </p> <p align="center"> <em> Fig1: Supervised transfer learning from ImageNet to Inat18, Cifar100, and Cifar10. The top row shows the superior linear probing performance of the CATn networks (blue, βcatβ). The bottom row shows the performance of fine-tuned CATn, which is poor with normal fine-tuning (gray, β[init]catβ) and excellent for two-stage fine tuning (blue, β[2ft]catβ). DISTILLn (pink, βdistillβ) representation is obtained by distilling CATn into one ResNet50. </em> </p>Supervised transfer learning (ViT)
Download the (ImageNet21k) pretrained & (ImageNet1k) fine-tuned ViT checkpoints according to download_checkpoint.md.
The resulting folder structure looks like:
```
📦 RRL
 ┣ 📂checkpoints
 ┃ ┣ 📂supervised_pretrain
 ┃ ┃ ┣ 📂vit
 ┃ ┃ ┃ ┣ 📜 vitaugreg/imagenet21k/ViT-B_16.npz
 ┃ ┃ ┃ ┣ 📜 vitaugreg/imagenet21k/ViT-L_16.npz
 ┃ ┃ ┃ ┣ 📜 vit/imagenet21k/ViT-B_16.npz
 ┃ ┃ ┃ ┗ 📜 vit/imagenet21k/ViT-L_16.npz
 ┃ ┃ ┣ 📜 vitaugreg/imagenet21k/imagenet2012/ViT-B_16.npz
 ┃ ┃ ┣ 📜 vitaugreg/imagenet21k/imagenet2012/ViT-L_16.npz
 ┃ ┃ ┣ 📜 vit/imagenet21k/imagenet2012/ViT-B_16.npz
 ┃ ┃ ┗ 📜 vit/imagenet21k/imagenet2012/ViT-L_16.npz
```
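
A minimal sketch for peeking inside one of these `.npz` checkpoints with NumPy (the path follows the tree above; the parameter names stored in the archive are not assumed, they are simply listed):

```python
# Minimal sketch: list the arrays stored in a ViT .npz checkpoint.
import numpy as np

ckpt = np.load("checkpoints/supervised_pretrain/vit/vitaugreg/imagenet21k/ViT-B_16.npz")
for name in list(ckpt.files)[:10]:   # first few parameter names and shapes
    print(name, ckpt[name].shape)
```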
Using the same experimental protocol as in Tab1, we obtain the following transfer learning curves with Vision Transformers:
<p align="center"> <image src='figures/vit_tf_v3.png' width="500"/> </p> <p align="center"> <em> Fig2: </em> </p>self-supervised transfer learning
Download the SWAV and SEER checkpoints according to download_checkpoint.md.
The resulting folder structure looks like:
```
📦 RRL
 ┣ 📂checkpoints
 ┃ ┣ 📂self_supervised_pretrain
 ┃ ┃ ┣ 📜 swav_400ep_pretrain.pth.tar
 ┃ ┃ ┣ 📜 swav_RN50w2_400ep_pretrain.pth.tar
 ┃ ┃ ┣ 📜 swav_RN50w4_400ep_pretrain.pth.tar
 ┃ ┃ ┣ 📜 swav_RN50w5_400ep_pretrain.pth.tar
 ┃ ┃ ┣ 📜 swav_400ep_pretrain_seed5.pth.tar
 ┃ ┃ ┣ 📜 swav_400ep_pretrain_seed6.pth.tar
 ┃ ┃ ┣ 📜 swav_400ep_pretrain_seed7.pth.tar
 ┃ ┃ ┣ 📜 swav_400ep_pretrain_seed8.pth.tar
 ┃ ┃ ┣ 📜 seer_regnet32gf.pth
 ┃ ┃ ┣ 📜 seer_regnet64gf.pth
 ┃ ┃ ┣ 📜 seer_regnet128gf.pth
 ┃ ┃ ┣ 📜 seer_regnet256gf.pth
 ┃ ┃ ┣ 📜 seer_regnet32gf_finetuned.pth
 ┃ ┃ ┣ 📜 seer_regnet64gf_finetuned.pth
 ┃ ┃ ┣ 📜 seer_regnet128gf_finetuned.pth
 ┃ ┃ ┗ 📜 seer_regnet256gf_finetuned.pth
```
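
A minimal sketch for loading one of the SWAV checkpoints into a `torchvision` ResNet-50 backbone (the `module.` key prefix and the extra projection-head keys are assumptions about the public SWAV release; `strict=False` simply reports anything that does not match):

```python
# Minimal sketch: load a SWAV checkpoint into a torchvision ResNet-50 backbone.
# The "module." prefix and projection-head keys are assumptions about the
# public SWAV release; strict=False reports any mismatched keys.
import torch
from torchvision import models

state = torch.load(
    "checkpoints/self_supervised_pretrain/swav_400ep_pretrain.pth.tar",
    map_location="cpu",
)
state = {k.replace("module.", "", 1): v for k, v in state.items()}

backbone = models.resnet50(weights=None)
missing, unexpected = backbone.load_state_dict(state, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)
```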
Using the same experimental protocol as in Tab1, we obtain the following self-supervised transfer learning curves:
<p align="center"> <image src='figures/ssl_tf.png'/> </p> <p align="center"> <em> Fig2: Self-supervised transfer learning with SWAV trained on unlabeled ImageNet(1K) (top row) and with SEER on Instagram1B (bottom row). The constructed rich representation, CATn, yields the best linear probing performance (βcatβ and βcatsubβ) for supervised ImageNet, INAT18, CIFAR100, and CIFAR10 target tasks. The two-stage fine-tuning (β[2ft]catβ) matches equivalently sized baseline models (β[init]wideβ and β[init]wide&deepβ), but with much easier training. The sub-networks of CAT5 (and CAT2) in SWAV hold the same architecture </em> </p> <!-- ### Transfer by Linear Probing, Fine-Tuning, and Two-stage Fine-Tuning (SWAV pretrained ImageNet1k): -->Meta-learning & few-shots learning and Out-of-distribution generalization
<p align="center"> <image src='figures/meta_learning_full_v4.png' width="500"/> </p> <p align="center"> <em> Fig3: Few-shot learning performance on MINIIMAGENET and CUB. Four common few-shot learning algorithms are shown in red (results from Chen et al. (2019)(https://arxiv.org/abs/1904.04232)). Two supervised transfer methods, with either a linear classifier (BASELINE) or cosine- based classifier (BASELINE++) are shown in blue. The DISTILL and CAT results, with a cosine-base classifier, are respectively shown in orange and gray. The CAT5-S and DISTILL5-S results were obtained using five snapshots taken during a single training episode with a relatively high step size. The dark blue line shows the best individual snapshot. Standard deviations over five repeats are reported. </em> </p> <p align="center"> <image src='figures/ood_general.png' width="500"/> </p> <p align="center"> <em> Fig4: Test accuracy on the CAMELYON17 dataset with DENSENET121. We compare various initialization (ERM, CATn, DISTILLn, and Bonsai(https://arxiv.org/pdf/2203.15516.pdf)) for two algorithms VREX and ERM using either the IID or OOD hyperparameter tuning method. The standard deviations over 5 runs are reported. </em> </p>Citation
If you find this code useful for your research, please consider citing our work:
```bibtex
@inproceedings{zhang2023learning,
  title={Learning useful representations for shifting tasks and distributions},
  author={Zhang, Jianyu and Bottou, L{\'e}on},
  booktitle={International Conference on Machine Learning},
  pages={40830--40850},
  year={2023},
  organization={PMLR}
}
```