# MAE-CT: Masked Autoencoder Contrastive Tuning

[Project Page] [arXiv] [Models] [BibTeX] [Follow-up work (MIM-Refiner)]

PyTorch implementation of Masked AutoEncoder Contrastive Tuning (MAE-CT) from our paper <br/> [Contrastive Tuning: A Little Help to Make Masked Autoencoders Forget](https://arxiv.org/abs/2304.10520).

<p align="center"> <img width="31.5%" alt="maect_schematic" src="https://github.com/ml-jku/MAE-CT/blob/137d969be4c78d156465bb18d09f52d3c762114f/.github/schematic_contrastive_tuning.svg"> <img width="67%" alt="lowshot_vitl" src="https://github.com/ml-jku/MAE-CT/blob/2ff19e68df9c3a1a7cb17a1846ac0d937359392c/.github/lowshot_aug_L_white.svg"> </p>

This repository provides:

- Pretrained checkpoints for MAE and MAE-CT
- The hyperparameter configurations (yamls) used in the paper
- Instructions to set up the environment, reproduce our results, and use the checkpoints for downstream tasks

## Pretrained Checkpoints

### MAE reimplementation

All accuracies are ImageNet-1K top-1 (%), for linear probing ("probe") and k-NN classification.

| Weights  | Pretrain | Probe acc. | Probe log | k-NN acc. |
|----------|----------|------------|-----------|-----------|
| ViT-B/16 | hp       | 66.7       | log       | 51.1      |
| ViT-L/16 | hp       | 75.9       | log       | 60.6      |
| ViT-H/16 | hp       | 78.0       | log       | 61.1      |
| ViT-H/14 | original | 77.2       | log       | 58.9      |

### MAE-CT

| Encoder  | Pretrain | Probe acc. | Probe log | k-NN acc. |
|----------|----------|------------|-----------|-----------|
| ViT-B/16 | hp       | 73.5       | log       | 64.1      |
| ViT-L/16 | hp       | 80.2       | log       | 78.0      |
| ViT-H/16 | hp       | 81.5       | log       | 79.4      |
| ViT-H/14 | hp       | 81.3       | log       | 79.1      |

### MAE-CT<sub>aug</sub>

| Encoder  | Pretrain | Probe acc. | Probe log | k-NN acc. |
|----------|----------|------------|-----------|-----------|
| ViT-B/16 | hp       | 76.9       | log       | 73.4      |
| ViT-L/16 | hp       | 81.5       | log       | 79.1      |
| ViT-H/16 | hp       | 82.2       | log       | 79.8      |
| ViT-H/14 | hp       | 82.0       | log       | 78.9      |

## Reproducibility

## Use checkpoints as backbone for other tasks

The script `eval_probe.py` demonstrates how to load our models from a checkpoint and use them for a downstream task. The script extracts the features of the encoder and feeds them to a linear probe, but the code can be adapted to other downstream tasks as well.
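As a rough sketch of this pattern (not the repo's actual loading code; the checkpoint file name, the `state_dict` key, and the timm architecture name below are assumptions), a frozen encoder can serve as a feature extractor for a linear probe roughly like this:

```python
# Minimal sketch: use a pretrained ViT encoder as a frozen feature extractor for a linear probe.
# Assumptions: the checkpoint is a dict that stores weights under "state_dict" and the
# architecture matches timm's "vit_base_patch16_224"; adjust both (and the file name).
import torch
import timm

encoder = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=0)
checkpoint = torch.load("maect_vitb16.th", map_location="cpu")
encoder.load_state_dict(checkpoint.get("state_dict", checkpoint), strict=False)
encoder.eval().requires_grad_(False)

probe = torch.nn.Linear(encoder.num_features, 1000)  # linear probe, e.g. for ImageNet-1K

images = torch.randn(8, 3, 224, 224)  # stand-in for a dataloader batch
with torch.no_grad():
    features = encoder(images)        # (8, num_features) pooled encoder features
logits = probe(features)              # only the probe parameters are trained
```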

## Setup

Set up a conda environment: `conda env create --file environment_linux.yml --name maect`

We use FlashAttention (paper) to greatly accelerate computations. We recommend installing it, but this repo can also be used without FlashAttention (no modifications required).
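The usual way to make such a dependency optional looks roughly like the following (a sketch of the general pattern, not the attention code actually used in this repo; it assumes flash-attn 2.x and PyTorch >= 2.0):

```python
# Sketch: optional FlashAttention with a plain PyTorch fallback.
# Illustrates the general pattern only; this repo's attention module may differ.
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func  # flash-attn >= 2.x public API
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False

def attention(q, k, v):
    # q, k, v: (batch, heads, seq_len, head_dim)
    if HAS_FLASH_ATTN and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
        # flash_attn_func expects (batch, seq_len, heads, head_dim)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        return flash_attn_func(q, k, v).transpose(1, 2)
    # plain PyTorch fallback, available since PyTorch 2.0
    return F.scaled_dot_product_attention(q, k, v)
```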

### Configuration of dataset paths and environment-specific settings

For low-shot evaluations, we use the official splits from SimCLRv2 and MSN.

To generate these ImageNet subsets, we use the ImageNetSubsetGenerator repository.
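If you want to build such a subset yourself, the idea is simply to filter the full training set by the split file (a sketch assuming a SimCLRv2-style plain-text list of training image filenames; all paths below are placeholders):

```python
# Sketch: restrict a torchvision ImageNet train set to a low-shot split file.
# Assumes the split file lists one training image filename per line (SimCLRv2-style);
# the dataset and split paths are placeholders.
from torch.utils.data import Subset
from torchvision import transforms
from torchvision.datasets import ImageFolder

dataset = ImageFolder("/data/imagenet/train", transform=transforms.ToTensor())
with open("splits/imagenet_1percent.txt") as f:
    keep = {line.strip() for line in f}

# ImageFolder stores (path, class_index) pairs; keep samples whose filename is in the split
indices = [i for i, (path, _) in enumerate(dataset.samples)
           if path.rsplit("/", 1)[-1] in keep]
lowshot_dataset = Subset(dataset, indices)
```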

### [Optional] Configure Weights & Biases

This repo uses Weights & Biases for experiment tracking, but offers an alternative in case you do not want to use it. By default, W&B logging is disabled via the `default_wandb_mode: disabled` configuration in `static_config.yaml`. You can enable it by setting `default_wandb_mode: online` in `static_config.yaml` or via the CLI flag `--wandb_mode online`.

If you enabled W&B logging, the W&B entity and project are (by default) read from `wandb_config.yaml`. You can create this file via `cp template_wandb_config.yaml wandb_config.yaml` and adjust the values to your setup.
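For context, the `disabled` and `online` values correspond to the `mode` argument of wandb itself (shown here with the plain wandb API, not this repo's logging code; the entity and project names are placeholders):

```python
# What the wandb modes mean, using the plain wandb API (not this repo's wrapper).
import wandb

# mode="disabled" turns every wandb call into a no-op;
# mode="online" streams metrics to the configured entity/project.
run = wandb.init(entity="my-team", project="maect", mode="disabled")
run.log({"knn_accuracy": 0.0})  # a no-op while disabled
run.finish()
```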

## Run

To run your own experiments or to reproduce our results, specify the desired hyperparameters via a yaml file and start the training/evaluation run by passing the corresponding CLI arguments to `main_train.py`.

Example: train MAE with ViT-B/16 on 4 GPUs: `python main_train.py --hp yamls/mae/base16.yaml --devices 0,1,2,3`

### Output

Each yaml file will create a folder in your output directory (defined via `output_path` in `static_config.yaml`). The output directory is structured into subdirectories with the `stage_name` and the `stage_id`. Example: `~/output_path/pretrain/9j3kl092`

The output directory of each run is organized as follows:

The yamls used for our paper can be found here. Each step of MAE-CT requires its own yaml file, where later steps require a reference to a checkpoint of a previous step. This reference is defined via the `stage_id` of the initializer objects within the yaml.

## Examples

### Train models

### Evaluate pretrained models

## Citation

If you find this repository useful, please consider giving it a star :star: and citing us:

@article{lehner2023maect,
      title={Contrastive Tuning: A Little Help to Make Masked Autoencoders Forget}, 
      author={Johannes Lehner and Benedikt Alkin and Andreas Fürst and Elisabeth Rumetshofer and Lukas Miklautz and Sepp Hochreiter},
      journal={arXiv preprint arXiv:2304.10520},
      year={2023}
}