# MAE-CT: Masked Autoencoder Contrastive Tuning
[Project Page] [arXiv] [Models] [BibTeX] [Follow-up work (MIM-Refiner)]
PyTorch implementation of Masked AutoEncoder Contrastive Tuning (MAE-CT) from our paper <br/> Contrastive Tuning: A Little Help to Make Masked Autoencoders Forget.
<p align="center">
<img width="31.5%" alt="maect_schematic" src="https://github.com/ml-jku/MAE-CT/blob/137d969be4c78d156465bb18d09f52d3c762114f/.github/schematic_contrastive_tuning.svg">
<img width="67%" alt="lowshot_vitl" src="https://github.com/ml-jku/MAE-CT/blob/2ff19e68df9c3a1a7cb17a1846ac0d937359392c/.github/lowshot_aug_L_white.svg">
</p>

This repository provides:
- Pretrained checkpoints for MAE, MAE-CT and MAE-CT<sub>aug</sub>
- Hyperparameters for all checkpoints
- Linear probes trained on the respective checkpoints + logs
- Instructions to generate low-shot datasets for evaluation
- Instructions on how to use our models as a backbone
## Pretrained Checkpoints

The "Probe" and "k-NN" columns report ImageNet-1K accuracy (%); the "Pretrain" and "Probe log" columns point to the pretraining hyperparameters and the probe training logs, respectively.
### MAE reimplementation
| Weights | Pretrain | Probe | Probe log | k-NN |
| --- | --- | --- | --- | --- |
| ViT-B/16 | hp | 66.7 | log | 51.1 |
| ViT-L/16 | hp | 75.9 | log | 60.6 |
| ViT-H/16 | hp | 78.0 | log | 61.1 |
| ViT-H/14 | original | 77.2 | log | 58.9 |
### MAE-CT
| Encoder | Pretrain | Probe | Probe log | k-NN |
| --- | --- | --- | --- | --- |
| ViT-B/16 | hp | 73.5 | log | 64.1 |
| ViT-L/16 | hp | 80.2 | log | 78.0 |
| ViT-H/16 | hp | 81.5 | log | 79.4 |
| ViT-H/14 | hp | 81.3 | log | 79.1 |
### MAE-CT<sub>aug</sub>
| Encoder | Pretrain | Probe | Probe log | k-NN |
| --- | --- | --- | --- | --- |
| ViT-B/16 | hp | 76.9 | log | 73.4 |
| ViT-L/16 | hp | 81.5 | log | 79.1 |
| ViT-H/16 | hp | 82.2 | log | 79.8 |
| ViT-H/14 | hp | 82.0 | log | 78.9 |
## Reproducibility
- Models can be trained using the hyperparameters provided here. Examples of how to start training runs can be found here.
- We provide instructions for reproducing our probing results in PROBING.md.
## Use checkpoints as a backbone for other tasks
The script eval_probe.py demonstrates how to load our models from a checkpoint and use them for a downstream task. The script extracts the features of the encoder and feeds them to a linear probe, but the code can be adapted to other downstream tasks as well.
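As a rough illustration of that pattern, here is a minimal sketch (not taken from eval_probe.py) that loads encoder weights into a timm-style ViT-L/16 and extracts features; the checkpoint filename and the key layout are assumptions, so adapt them to the checkpoint you downloaded:

```python
# Minimal sketch: load a pretrained encoder and use it as a frozen feature extractor.
# Assumptions: a timm-style ViT-L/16 and a checkpoint that stores the encoder weights
# under "state_dict" -- the filename and key layout are hypothetical, adapt as needed.
import timm
import torch

encoder = timm.create_model("vit_large_patch16_224", num_classes=0)  # num_classes=0 -> features only

# weights_only=False requires torch >= 1.13; drop the kwarg on older versions
ckpt = torch.load("maect_large16.th", map_location="cpu", weights_only=False)  # hypothetical filename
state_dict = ckpt.get("state_dict", ckpt)
# strict=False because key names may differ between this repo's format and timm
missing, unexpected = encoder.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")

encoder.eval()
with torch.no_grad():
    images = torch.randn(2, 3, 224, 224)  # replace with a preprocessed batch
    features = encoder(images)            # shape (2, 1024) for ViT-L/16
print(features.shape)
```

From here, `features` can be fed to a linear probe, a k-NN classifier, or any other downstream head.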
## Setup
Set up a conda environment: `conda env create --file environment_linux.yml --name maect`
We use FlashAttention (paper) to greatly accelerate computations. We recommend installing it, but this repo can also be used without FlashAttention (without modification).
Configure dataset paths and environment-specific settings:
- `cp template_static_config.yaml static_config.yaml`
- edit the values in `static_config.yaml` to match your setup
For low-shot evaluations, we use the official splits from SimCLRv2 and MSN.
To generate these ImageNet subsets we use the ImageNetSubsetGenerator repository.
### [Optional] Configure Weights & Biases
This repo uses Weights & Biases for experiment tracking, but offers an alternative in case you do not want to use it. By default, W&B logging is disabled via the `default_wandb_mode: disabled` configuration in `static_config.yaml`. You can enable it by setting `default_wandb_mode: online` in `static_config.yaml` or via the CLI argument `--wandb_mode online`.
If you enable W&B logging, the W&B entity and project will (by default) be fetched from `wandb_config.yaml`. You can create this file via `cp template_wandb_config.yaml wandb_config.yaml` and adjust the values to your setup.
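For reference, a minimal `wandb_config.yaml` could look like the sketch below; the key names here are illustrative only, and `template_wandb_config.yaml` defines the actual schema:

```yaml
# illustrative sketch -- see template_wandb_config.yaml for the actual key names
entity: your-wandb-entity
project: maect
```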
## Run
To run your own experiments or reproduce our results, you have to specify the desired hyperparameters via a yaml file. Start the training/evaluation run by specifying the following CLI arguments for `main_train.py`:
- `--hp <YAML>` (e.g. `--hp yamls/mae/base16.yaml`)
- `--devices <DEVICES>` (e.g. `--devices 0` to run on GPU 0, or `--devices 0,1,2,3` to run on 4 GPUs)
Example: Train MAE with ViT-B/16 on 4 GPUs: `python main_train.py --hp yamls/mae/base16.yaml --devices 0,1,2,3`
### Output
Each yaml file will create a folder in your output directory (defined via `output_path` in `static_config.yaml`). The output directory is structured into subdirectories with the `stage_name` and the `stage_id`. Example: `~/output_path/pretrain/9j3kl092`
The output directory of each run is organized as follows:
- `checkpoints`: model weights will be stored here (choose the interval by adjusting the values of the `checkpoint_logger` in the yaml file of a run)
- `primitive`: all metrics that are written to Weights & Biases are also stored locally here; if you don't want to use W&B, you can parse the metrics from the files within this directory
- `log.txt`: logfile
- `hp_resolved.yaml`: a copy of the yaml file that was specified via the `--hp` CLI arg
The yamls used for our paper can be found here. Each stage of MAE-CT requires its own yaml file, where the later stages require a reference to a checkpoint from a previous stage. This can be defined by changing the `stage_id` of the `initializer` objects within the yaml, as sketched below.
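For illustration, pointing a stage 2 run at a stage 1 checkpoint could look roughly like this; the exact key layout here is only illustrative, so treat the provided yamls as the authoritative reference:

```yaml
# hypothetical excerpt from a stage 2 yaml -- the key layout is an assumption,
# see the provided yamls for the real structure
encoder:
  initializer:
    stage_id: 9j3kl092  # stage_id of the previous run (its output folder name, see "Output")
```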
## Examples
### Train models
- Pretrain a MAE on 8 GPUs (stage 1): <br/>
  `python main_train.py --hp yamls/stage1_mae/large16.yaml --devices 0,1,2,3,4,5,6,7`
- Train a NNCLR head on frozen encoder features (stage 2) with 8 GPUs:
  - change the `stage_id` of the `initializer` in the encoder to the `stage_id` from stage 1
  - `python main_train.py --hp yamls/stage2_maect_prepare_head/large16.yaml --devices 0,1,2,3,4,5,6,7`
- Apply contrastive tuning (stage 3) with 8 GPUs:
  - change the `stage_id` of the `initializer` in the encoder and the NNCLR head to the `stage_id` from stage 2
  - `python main_train.py --hp yamls/stage3_maect_contrastive_tuning/large16.yaml --devices 0,1,2,3,4,5,6,7`
### Evaluate pretrained models
- Adapt the `initializer` of `yamls/probe.yaml` to the model you want to evaluate
- `python main_train.py --hp yamls/probe.yaml --devices 0,1,2,3`
## Citation
If you find this repository useful, please consider giving it a star :star: and citing us:
```
@article{lehner2023maect,
  title={Contrastive Tuning: A Little Help to Make Masked Autoencoders Forget},
  author={Johannes Lehner and Benedikt Alkin and Andreas Fürst and Elisabeth Rumetshofer and Lukas Miklautz and Sepp Hochreiter},
  journal={arXiv preprint arXiv:2304.10520},
  year={2023}
}
```