Mixed Barlow Twins for Self-Supervised Representation Learning

Guarding Barlow Twins Against Overfitting with Mixed Samples<br>


Wele Gedara Chaminda Bandara (Johns Hopkins University), Celso M. De Melo (U.S. Army Research Laboratory), and Vishal M. Patel (Johns Hopkins University) <br>

1 Overview of Mixed Barlow Twins

TL;DR

<img src="figs/mix-bt.svg" width="1024">

$C^{MA} = (Z^M)^TZ^A$

$C^{MB} = (Z^M)^TZ^B$

$C^{MA}_{gt} = \lambda (Z^A)^TZ^A + (1-\lambda)\mathtt{Shuffle}^*(Z^B)^TZ^A$

$C^{MB}_{gt} = \lambda (Z^A)^TZ^B + (1-\lambda)\mathtt{Shuffle}^*(Z^B)^TZ^B$
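Here $Z^A$ and $Z^B$ are the projector outputs of the two augmented views, $Z^M$ is the projector output of the mixed images, $\mathtt{Shuffle}^*$ is a batch-wise shuffling operation, and $\lambda$ is the mixing coefficient. The regularizer pushes the measured cross-correlation matrices $C^{MA}$ and $C^{MB}$ toward the targets $C^{MA}_{gt}$ and $C^{MB}_{gt}$ assembled from the unmixed embeddings, and is added to the standard Barlow Twins loss with weight $\lambda_{reg}$ (the value reported in the checkpoint tables below). Below is a minimal PyTorch sketch of this idea; the function and variable names, the omitted embedding normalization, and the toy usage are illustrative assumptions rather than the repository's exact implementation:

```python
import torch

def mixed_bt_regularizer(z_a, z_b, z_m, lam, idx):
    """Sketch of the mixup cross-correlation regularizer defined by the equations above.

    z_a, z_b : (N, D) projector outputs of the two augmented views
    z_m      : (N, D) projector output of the mixed images, assumed to come from
               lam * x_a + (1 - lam) * x_b[idx]
    lam      : mixing coefficient (e.g. drawn from a Beta distribution)
    idx      : random batch permutation playing the role of Shuffle*
    Embedding normalization is omitted here for brevity.
    """
    z_b_shuffled = z_b[idx]

    # measured cross-correlations between the mixed embedding and each view
    c_ma = z_m.T @ z_a
    c_mb = z_m.T @ z_b

    # "ground-truth" cross-correlations assembled from the unmixed embeddings
    c_ma_gt = lam * z_a.T @ z_a + (1 - lam) * z_b_shuffled.T @ z_a
    c_mb_gt = lam * z_a.T @ z_b + (1 - lam) * z_b_shuffled.T @ z_b

    # penalize deviation of the measured matrices from their targets
    return ((c_ma - c_ma_gt) ** 2).sum() + ((c_mb - c_mb_gt) ** 2).sum()


# toy usage: in practice z_a, z_b, z_m come from the backbone + projector
n, d = 256, 1024
z_a, z_b = torch.randn(n, d), torch.randn(n, d)
idx = torch.randperm(n)
lam = 0.65
z_m = lam * z_a + (1 - lam) * z_b[idx]   # stand-in for the embedding of the mixed images
reg_loss = mixed_bt_regularizer(z_a, z_b, z_m, lam, idx)
```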

2 Usage

2.1 Requirements

Before using this repository, make sure you have the following prerequisites installed:

You can install PyTorch with the following command (on Linux):

conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia
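After installation, you can quickly verify that PyTorch was built with CUDA support, for example:

```python
import torch

print(torch.__version__)           # installed PyTorch version
print(torch.cuda.is_available())   # should print True on a machine with a CUDA GPU
```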

2.2 Installation

To get started, clone this repository:

git clone https://github.com/wgcban/mix-bt.git

Next, create the conda environment named ssl-aug by executing the following command:

conda env create -f environment.yml

All train/val/test statistics are automatically uploaded to wandb; please refer to the wandb-quick-start documentation if you are not familiar with using wandb.

2.3 Supported Pre-training Datasets

This repository supports the following pre-training datasets:

CIFAR-10, CIFAR-100, and STL-10 datasets are directly available in PyTorch.

To use TinyImageNet, please follow the preprocessing instructions provided in the TinyImageNet-Script. Download these datasets and place them in the data directory.
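For reference, the torchvision datasets above can be fetched automatically into the same location; a small sketch, assuming the `data` root matches the data directory mentioned above:

```python
from torchvision import datasets

# downloads land under ./data, matching the data directory used by the scripts
cifar10 = datasets.CIFAR10(root="data", train=True, download=True)
cifar100 = datasets.CIFAR100(root="data", train=True, download=True)
stl10 = datasets.STL10(root="data", split="train+unlabeled", download=True)
```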

2.4 Supported Transfer Learning Datasets

Download the transfer learning datasets and place them under their respective paths, such as 'data/DTD'. The supported transfer learning datasets include:

2.5 Supported SSL Methods

This repository supports the following Self-Supervised Learning (SSL) methods:

2.6 Pre-Training with Mixed Barlow Twins

To start pre-training and obtain k-NN evaluation results for Mixed Barlow Twins on CIFAR-10, CIFAR-100, TinyImageNet, and STL-10 with ResNet-18/50 backbones, please run:

sh scripts-pretrain-resnet18/[dataset].sh
sh scripts-pretrain-resnet50/[dataset].sh

To start the pre-training on ImageNet with ResNet-50 backbone, please run:

sh scripts-pretrain-resnet50/imagenet.sh
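The k-NN numbers reported during pre-training follow the usual self-supervised protocol: each test image is classified by a weighted vote over its nearest neighbours in a feature bank built from the frozen encoder. A rough, self-contained sketch of that idea (function names and hyper-parameters such as `k` and the temperature are illustrative, not the repository's exact code):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def knn_accuracy(backbone, train_loader, test_loader, k=200, t=0.1, num_classes=10):
    """Weighted k-NN evaluation on frozen features (illustrative sketch)."""
    backbone.eval()
    feats, labels = [], []
    for x, y in train_loader:
        feats.append(F.normalize(backbone(x), dim=1))
        labels.append(y)
    bank = torch.cat(feats)            # (N, D) L2-normalized feature bank
    bank_labels = torch.cat(labels)    # (N,)

    correct, total = 0, 0
    for x, y in test_loader:
        q = F.normalize(backbone(x), dim=1)         # (B, D) query features
        sim = q @ bank.T                            # cosine similarity to the bank
        w, idx = sim.topk(k, dim=1)                 # k nearest neighbours per query
        w = (w / t).exp()                           # temperature-scaled weights
        votes = torch.zeros(x.size(0), num_classes)
        votes.scatter_add_(1, bank_labels[idx], w)  # accumulate weighted class votes
        correct += (votes.argmax(dim=1) == y).sum().item()
        total += y.numel()
    return 100.0 * correct / total
```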

2.7 Linear Evaluation of Pre-trained Models

Before running linear evaluation, ensure that you specify the model_path argument correctly in the corresponding .sh file.

To obtain linear evaluation results on CIFAR-10, CIFAR-100, TinyImageNet, STL-10 with ResNet-18/50 backbones, please run:

sh scripts-linear-resnet18/[dataset].sh
sh scripts-linear-resnet50/[dataset].sh

To obtain linear evaluation results on ImageNet with ResNet-50 backbone, please run:

sh scripts-linear-resnet50/imagenet_sup.sh
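For context, linear evaluation freezes the pre-trained encoder and trains only a linear classifier on top of its features. The sketch below illustrates the idea; the checkpoint key layout, learning rate, and class count are assumptions rather than the repository's exact settings:

```python
import torch
import torch.nn as nn
from torchvision.models import resnet50

backbone = resnet50()
backbone.fc = nn.Identity()                     # expose the 2048-d features
# state = torch.load("checkpoints/<pretrained>.pth", map_location="cpu")
# backbone.load_state_dict(state, strict=False) # key names depend on the checkpoint format
for p in backbone.parameters():
    p.requires_grad = False                     # the encoder stays frozen
backbone.eval()

head = nn.Linear(2048, 1000)                    # 1000 ImageNet classes (assumed)
optimizer = torch.optim.SGD(head.parameters(), lr=0.3, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def training_step(images, labels):
    with torch.no_grad():
        features = backbone(images)             # frozen features
    loss = criterion(head(features), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```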

2.8 Transfer Learning of Pre-trained Models

To perform transfer learning from pre-trained models on CIFAR-10, CIFAR-100, and STL-10 to fine-grained classification datasets, execute the following command, making sure to specify the model_path argument correctly:

sh scripts-transfer-resnet18/[dataset]-to-x.sh

3 Pre-Trained Checkpoints

Download the pre-trained models from GitHub (Releases v1.0.0) and store them in checkpoints/. This repository provides pre-trained checkpoints for both ResNet-18 and ResNet-50 architectures.

3.1 ResNet-18 [CIFAR-10, CIFAR-100, TinyImageNet, and STL-10]

| Dataset | $d$ | $\lambda_{BT}$ | $\lambda_{reg}$ | Download Link to Pretrained Model | KNN Acc. | Linear Acc. |
| --- | --- | --- | --- | --- | --- | --- |
| CIFAR-10 | 1024 | 0.0078125 | 4.0 | 4wdhbpcf_cifar10.pth | 90.52 | 92.58 |
| CIFAR-100 | 1024 | 0.0078125 | 4.0 | 76kk7scz_cifar100.pth | 61.25 | 69.31 |
| TinyImageNet | 1024 | 0.0009765 | 4.0 | 02azq6fs_tiny_imagenet.pth | 38.11 | 51.67 |
| STL-10 | 1024 | 0.0078125 | 2.0 | i7det4xq_stl10.pth | 88.94 | 91.02 |

3.2 ResNet-50 [CIFAR-10, CIFAR-100, TinyImageNet, and STL-10]

| Dataset | $d$ | $\lambda_{BT}$ | $\lambda_{reg}$ | Download Link to Pretrained Model | KNN Acc. | Linear Acc. |
| --- | --- | --- | --- | --- | --- | --- |
| CIFAR-10 | 1024 | 0.0078125 | 4.0 | v3gwgusq_cifar10.pth | 91.39 | 93.89 |
| CIFAR-100 | 1024 | 0.0078125 | 4.0 | z6ngefw7_cifar100.pth | 64.32 | 72.51 |
| TinyImageNet | 1024 | 0.0009765 | 4.0 | kxlkigsv_tiny_imagenet.pth | 42.21 | 51.84 |
| STL-10 | 1024 | 0.0078125 | 2.0 | pbknx38b_stl10.pth | 87.79 | 91.70 |

3.3 ResNet-50 on ImageNet (300 epochs)

Setting: epochs = 300, $d$ = 8192, $\lambda_{BT}$ = 0.0051

| $\lambda_{reg}$ | Linear Acc. | Download Link to Pretrained Model | Train Log | Download Link to Linear-Probed Model | Val. Log |
| --- | --- | --- | --- | --- | --- |
| 0.0 (BT) | 71.3 | 3on0l4wl_resnet50.pth | train_log | checkpoint_3tb4tcvp.pth | val_log |
| 0.0025 | 70.9 | l418b9zw_resnet50.pth | train_log | checkpoint_09g7ytcz.pth | val_log |
| 0.1 | 71.6 | 13awtq23_resnet50.pth | train_log | checkpoint_pgawzr4e.pth | val_log |
| 1.0 | 72.2 (best) | 3fb1op86_resnet50.pth | train_log | checkpoint_wvi0hle8.pth | val_log |
| 2.0 | 72.1 | 5n9yqio0_resnet50.pth | train_log | checkpoint_p9aeo8ga.pth | val_log |
| 3.0 | 72.0 | q03u2xjz_resnet50.pth | train_log | checkpoint_00atvp6x.pth | val_log |

3.4 ResNet-50 on ImageNet (1000 epochs)

Setting: epochs = 1000, $d$ = 8192, $\lambda_{BT}$ = 0.0051, $\lambda_{reg}$=2.0

| Linear Eval. Top-1 | Linear Eval. Top-5 | Download Link to Pretrained Model | Train Log | Download Link to Linear-Probed Model | Val. Log |
| --- | --- | --- | --- | --- | --- |
| 74.06 (best) | 91.47 | 4wpu8wmd_resnet50.pth | train_log | vfd2nu64_checkpoint.pth | val_log |

4 Training/Val Logs

4.1 Pre-training for 300 epochs

Logs are available on wandb and can be accessed via the following links:

Here we provide some training and validation (linear probing) statistics for Barlow Twins vs. Mixed Barlow Twins with ResNet-50 backbone on ImageNet:

<img src="figs/in-loss-bt.png" width="256"/> <img src="figs/in-loss-reg.png" width="256"/> <img src="figs/in-linear.png" width="256"/>

4.2 Pre-training for 1000 epochs

We also provide training and validation statistics for our model pre-trained for 1000 epochs. <img src="figs/in-loss-bt-1000e.png" width="256"/> <img src="figs/in-loss-reg-1000e.png" width="256"/> <img src="figs/in-linear-1000e.png" width="256"/>

:fire: Access pre-training statistics on wandb: wandb-imagenet-pretrain

5 Disclaimer

A large portion of the code is borrowed from Barlow Twins HSIC (for the experiments on the small datasets: CIFAR-10, CIFAR-100, TinyImageNet, and STL-10) and from the official Barlow Twins implementation here (for the experiments on ImageNet), both of which are great resources for academic development.

Also, note that the implementations of the SOTA methods (SimCLR, BYOL, and Whitening-MSE) in ssl-sota are copied from Whitening-MSE.

We would like to thank all of them for making their repositories publicly available for the research community. 🙏

6 Reference

If you find our work useful, please consider citing it. Thanks!

@misc{bandara2023guarding,
      title={Guarding Barlow Twins Against Overfitting with Mixed Samples}, 
      author={Wele Gedara Chaminda Bandara and Celso M. De Melo and Vishal M. Patel},
      year={2023},
      eprint={2312.02151},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

7 License

This code is released under the MIT license; you can find the complete license file here.