Learning useful representations for shifting tasks and distributions

Official PyTorch implementation of the paper

Jianyu Zhang, Léon Bottou

<p align="center"> <img src='figures/story.png' width="500"/> </p>

Requirements

Datasets

We consider the following datasets:

Download and extract the ImageNet and Inaturalist18 datasets to data/imagenet and data/inaturalist18. The resulting folder structure should be:

📦 RRL
 ┣ 📂data
 ┃ ┣ 📂imagenet
 ┃ ┗ 📂inaturalist18
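
As a quick sanity check of this layout, the sketch below loads both validation splits with torchvision's generic `ImageFolder` reader. This is illustrative only, not one of the repo's scripts, and it assumes each dataset is arranged as one sub-folder per class under `train/` and `val/` (the ImageNet convention; for Inaturalist18 this layout is an assumption, since that dataset is often shipped with JSON annotation files instead).

```python
# Minimal sanity check of the data/ layout (illustrative, not a repo script).
import torchvision.datasets as datasets
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Assumes data/imagenet/val and data/inaturalist18/val each contain
# one sub-folder per class (hypothetical layout for Inaturalist18).
imagenet_val = datasets.ImageFolder("data/imagenet/val", transform=transform)
inat18_val = datasets.ImageFolder("data/inaturalist18/val", transform=transform)
print(len(imagenet_val), "ImageNet val images,", len(inat18_val), "Inaturalist18 val images")
```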

Supervised transfer learning (ResNet)

Download (ImageNet1k) pretrained checkpoints:

You can get pretrained checkpoints either:

The resulting folder structure should be:

📦 RRL
 ┣ 📂checkpoints
 ┃ ┣ 📂supervised_pretrain
 ┃ ┃ ┣ 📂resnet50
 ┃ ┃ ┃ ┣📜 checkpoint_run0.pth.tar
 ┃ ┃ ┃ ┃ ...
 ┃ ┃ ┃ ┗📜 checkpoint_run9.pth.tar
 ┃ ┃ ┣📜 2resnet50_imagenet1k_supervised.pth.tar
 ┃ ┃ ┣📜 4resnet50_imagenet1k_supervised.pth.tar
 ┃ ┃ ┣📜 resnet50w2_imagenet1k_supervised.pth.tar
 ┃ ┃ ┣📜 resnet50w4_imagenet1k_supervised.pth.tar
 ┃ ┃ ┗📜 resnet50_imagenet1k_supervised_distill5.pth.tar
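
As a rough guide to what these files contain, the sketch below loads one of the single-run ResNet50 checkpoints into a torchvision model. It assumes the file stores a standard `state_dict` (possibly wrapped under a `state_dict` key and carrying a DataParallel `module.` prefix); the exact keys in the released checkpoints may differ.

```python
# Hedged sketch: load checkpoint_run0.pth.tar into a torchvision ResNet-50.
import torch
from torchvision.models import resnet50

ckpt = torch.load(
    "checkpoints/supervised_pretrain/resnet50/checkpoint_run0.pth.tar",
    map_location="cpu",
)
# The checkpoint may be a raw state_dict or a dict wrapping one (assumption).
state_dict = ckpt.get("state_dict", ckpt)
# Strip an eventual DataParallel prefix.
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}

model = resnet50(num_classes=1000)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)
```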

Transfer by Linear Probing, Fine-Tuning, and Two-stage Fine-Tuning:

Transfer the representations learned on ImageNet1k to Cifar10, Cifar100, and Inaturalist18. The following table provides the scripts for these transfer learning experiments:

| method  | architecture                  | target task      | linear probing | fine-tuning | two-stage fine-tuning |
|---------|-------------------------------|------------------|----------------|-------------|-----------------------|
| ERM     | resnet50                      | Cifar10/Cifar100 | scripts        | scripts     | -                     |
| ERM     | resnet50w2/w4, 2x/4x resnet50 | Cifar10/Cifar100 | scripts        | scripts     | -                     |
| CAT     | -                             | Cifar10/Cifar100 | scripts        | scripts     | scripts               |
| Distill | resnet50                      | Cifar10/Cifar100 | scripts        | scripts     | -                     |
| ERM     | resnet50                      | Inaturalist18    | scripts        | scripts     | -                     |
| ERM     | resnet50w2/w4, 2x/4x resnet50 | Inaturalist18    | scripts        | scripts     | -                     |
| CAT     | -                             | Inaturalist18    | scripts        | scripts     | scripts               |
| Distill | resnet50                      | Inaturalist18    | scripts        | scripts     | -                     |
<p align="center"> <em> Tab1: Transfer learning experiment scripts. </em> </p>
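
For readers unfamiliar with the three transfer modes in Tab1, the sketch below illustrates how they differ in which parameters are trained: linear probing freezes the backbone and trains only a new classification head, fine-tuning updates all parameters, and two-stage fine-tuning first runs the linear-probing stage and then fine-tunes the whole network starting from the probed head. This is an illustrative outline, not the repo's scripts, and it assumes a torchvision-style ResNet whose classifier attribute is `fc`; the learning rates are placeholders.

```python
# Illustrative sketch of linear probing / fine-tuning / two-stage fine-tuning.
import torch.nn as nn
import torch.optim as optim

def configure(model, mode, lr):
    """mode is 'linear_probing' or 'fine_tuning' (assumed torchvision-style ResNet)."""
    head_only = (mode == "linear_probing")
    for name, p in model.named_parameters():
        # In linear probing only the classification head `fc` is trainable.
        p.requires_grad = (not head_only) or name.startswith("fc")
    params = [p for p in model.parameters() if p.requires_grad]
    return optim.SGD(params, lr=lr, momentum=0.9, weight_decay=1e-4)

# model.fc = nn.Linear(model.fc.in_features, num_target_classes)  # new head for the target task
# opt = configure(model, "linear_probing", lr=0.01)   # linear probing (also stage 1 of two-stage FT)
# ... train ...
# opt = configure(model, "fine_tuning", lr=0.001)     # fine-tuning (stage 2 starts from the probed head)
```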

The following figure (focus on the solid curves) shows the transfer learning performance of the different representations (ERM / CAT / Distill) and transfer methods (linear probing / fine-tuning / two-stage fine-tuning).

<p align="center"> <img src='figures/imagenet_sl_tf_v3.png'/> </p> <p align="center"> <em> Fig1: Supervised transfer learning from ImageNet to Inat18, Cifar100, and Cifar10. The top row shows the superior linear probing performance of the CATn networks (blue, “cat”). The bottom row shows the performance of fine-tuned CATn, which is poor with normal fine-tuning (gray, “[init]cat”) and excellent with two-stage fine-tuning (blue, “[2ft]cat”). The DISTILLn representation (pink, “distill”) is obtained by distilling CATn into a single ResNet50. </em> </p>
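
To make the CATn naming in Fig1 concrete, the sketch below shows the underlying idea: concatenate the penultimate features of n independently trained ResNet50s and train a single classifier on top (DISTILLn then distills this concatenated representation back into one ResNet50). The class and variable names here are illustrative, not the repo's API.

```python
# Hedged sketch of a CATn-style concatenated representation.
import torch
import torch.nn as nn
from torchvision.models import resnet50

class ConcatFeatures(nn.Module):
    """Concatenate penultimate features of several backbones (illustrative)."""
    def __init__(self, backbones, num_classes):
        super().__init__()
        for b in backbones:
            b.fc = nn.Identity()          # keep the 2048-d penultimate features
        self.backbones = nn.ModuleList(backbones)
        self.fc = nn.Linear(2048 * len(backbones), num_classes)

    def forward(self, x):
        feats = torch.cat([b(x) for b in self.backbones], dim=1)
        return self.fc(feats)

# cat5 = ConcatFeatures([resnet50() for _ in range(5)], num_classes=100)  # e.g. CAT5 for Cifar100
```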

Supervised transfer learning (ViT)

Download the (ImageNet21k) pretrained and (ImageNet1k) fine-tuned ViT checkpoints according to download_checkpoint.md.

The resulting folder structure looks like:

📦 RRL
 ┣ 📂checkpoints
 ┃ ┣ 📂supervised_pretrain
 ┃ ┃ ┣ 📂vit
 ┃ ┃ ┃ ┣📜 vitaugreg/imagenet21k/ViT-B_16.npz
 ┃ ┃ ┃ ┣📜 vitaugreg/imagenet21k/ViT-L_16.npz
 ┃ ┃ ┃ ┣📜 vit/imagenet21k/ViT-B_16.npz
 ┃ ┃ ┃ ┗📜 vit/imagenet21k/ViT-L_16.npz
 ┃ ┃ ┣📜 vitaugreg/imagenet21k/imagenet2012/ViT-L_16.npz
 ┃ ┃ ┗📜 vit/imagenet21k/imagenet2012/ViT-L_16.npz
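
The ViT checkpoints are `.npz` archives, so they can be inspected with NumPy before any conversion to PyTorch. The path below and the expectation that the arrays follow the original JAX/Flax parameter naming are assumptions based on the layout above.

```python
# Quick, illustrative inspection of a ViT checkpoint (.npz archive).
import numpy as np

path = "checkpoints/supervised_pretrain/vit/vit/imagenet21k/ViT-B_16.npz"  # per the layout above
ckpt = np.load(path)
for name in sorted(ckpt.files)[:10]:   # first few parameter names and shapes
    print(name, ckpt[name].shape)
```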

With the same experiment protocol as in Tab1, we obtain the following transfer learning curves with Vision Transformers:

<p align="center"> <img src='figures/vit_tf_v3.png' width="500"/> </p> <p align="center"> <em> Fig2: Supervised transfer learning with Vision Transformers, following the same protocol as Tab1. </em> </p>

Self-supervised transfer learning

Download the SWAV and SEER checkpoints according to download_checkpoint.md.

The resulting folder structure looks like:

📦 RRL
 ┣ 📂checkpoints
 ┃ ┣ 📂self_supervised_pretrain
 ┃ ┃ ┣📜 swav_400ep_pretrain.pth.tar
 ┃ ┃ ┣📜 swav_RN50w2_400ep_pretrain.pth.tar
 ┃ ┃ ┣📜 swav_RN50w4_400ep_pretrain.pth.tar
 ┃ ┃ ┣📜 swav_RN50w5_400ep_pretrain.pth.tar
 ┃ ┃ ┣📜 swav_400ep_pretrain_seed5.pth.tar
 ┃ ┃ ┣📜 swav_400ep_pretrain_seed6.pth.tar
 ┃ ┃ ┣📜 swav_400ep_pretrain_seed7.pth.tar
 ┃ ┃ ┣📜 swav_400ep_pretrain_seed8.pth.tar
 ┃ ┃ ┣📜 seer_regnet32gf.pth
 ┃ ┃ ┣📜 seer_regnet64gf.pth
 ┃ ┃ ┣📜 seer_regnet128gf.pth
 ┃ ┃ ┣📜 seer_regnet256gf.pth
 ┃ ┃ ┣📜 seer_regnet32gf_finetuned.pth
 ┃ ┃ ┣📜 seer_regnet64gf_finetuned.pth
 ┃ ┃ ┣📜 seer_regnet128gf_finetuned.pth
 ┃ ┃ ┗📜 seer_regnet256gf_finetuned.pth
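
As with the supervised checkpoints, the sketch below shows one plausible way to load a SWAV backbone into a torchvision ResNet50 as a feature extractor. It assumes the file stores a DataParallel-prefixed `state_dict` that also contains SwAV-specific projection-head and prototype weights, which a plain ResNet50 simply ignores; the exact keys may differ.

```python
# Hedged sketch: load a SWAV checkpoint as a feature extractor.
import torch
from torchvision.models import resnet50

ckpt = torch.load(
    "checkpoints/self_supervised_pretrain/swav_400ep_pretrain.pth.tar",
    map_location="cpu",
)
state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}

model = resnet50()
model.fc = torch.nn.Identity()   # keep only the backbone features
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("ignored (projection-head / prototype) keys:", unexpected)
```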

With the same experiment protocol as in Tab1, we obtain the following self-supervised transfer learning curves:

<p align="center"> <img src='figures/ssl_tf.png'/> </p> <p align="center"> <em> Fig3: Self-supervised transfer learning with SWAV trained on unlabeled ImageNet(1K) (top row) and with SEER on Instagram1B (bottom row). The constructed rich representation, CATn, yields the best linear probing performance (“cat” and “catsub”) for the supervised ImageNet, INAT18, CIFAR100, and CIFAR10 target tasks. Two-stage fine-tuning (“[2ft]cat”) matches equivalently sized baseline models (“[init]wide” and “[init]wide&deep”), but with much easier training. The sub-networks of CAT5 (and CAT2) in SWAV share the same architecture. </em> </p>

Meta-learning, few-shot learning, and out-of-distribution generalization

<p align="center"> <img src='figures/meta_learning_full_v4.png' width="500"/> </p> <p align="center"> <em> Fig4: Few-shot learning performance on MINIIMAGENET and CUB. Four common few-shot learning algorithms are shown in red (results from <a href="https://arxiv.org/abs/1904.04232">Chen et al. (2019)</a>). Two supervised transfer methods, with either a linear classifier (BASELINE) or a cosine-based classifier (BASELINE++), are shown in blue. The DISTILL and CAT results, with a cosine-based classifier, are shown in orange and gray, respectively. The CAT5-S and DISTILL5-S results were obtained using five snapshots taken during a single training episode with a relatively high step size. The dark blue line shows the best individual snapshot. Standard deviations over five repeats are reported. </em> </p>

<p align="center"> <img src='figures/ood_general.png' width="500"/> </p> <p align="center"> <em> Fig5: Test accuracy on the CAMELYON17 dataset with DENSENET121. We compare various initializations (ERM, CATn, DISTILLn, and <a href="https://arxiv.org/pdf/2203.15516.pdf">Bonsai</a>) for two algorithms, VREX and ERM, using either the IID or OOD hyperparameter tuning method. Standard deviations over 5 runs are reported. </em> </p>
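
The cosine-based classifier mentioned in Fig4 (used by BASELINE++ and the CAT/DISTILL few-shot heads) scores each class by the scaled cosine similarity between the L2-normalized feature and an L2-normalized class weight vector. Below is a minimal sketch of such a head; the class name and the scale value are placeholders, not the repo's implementation.

```python
# Minimal sketch of a cosine-based classification head (illustrative).
import torch
import torch.nn as nn
import torch.nn.functional as F

class CosineClassifier(nn.Module):
    def __init__(self, feat_dim, num_classes, scale=10.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.scale = scale   # temperature-like scaling; a tuning choice

    def forward(self, features):
        features = F.normalize(features, dim=-1)     # L2-normalize features
        weight = F.normalize(self.weight, dim=-1)    # L2-normalize class vectors
        return self.scale * features @ weight.t()    # scaled cosine similarities
```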

Citation

If you find this code useful for your research, please consider citing our work:

@inproceedings{zhang2023learning,
  title={Learning useful representations for shifting tasks and distributions},
  author={Zhang, Jianyu and Bottou, L{\'e}on},
  booktitle={International Conference on Machine Learning},
  pages={40830--40850},
  year={2023},
  organization={PMLR}
}