MAGMA: Manifold Regularization for MAEs

This repository contains the code for our paper MAGMA: Manifold Regularization for MAEs (Alin Dondera, Anuj Singh, Hadi Jamali-Rad).

MAGMA is a novel regularization technique that enhances self-supervised representation learning within masked autoencoders (MAE). The core of our approach lies in applying a manifold-based loss term that encourages consistency and smoothness between representations across different layers of the network.
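To make the idea concrete, here is a minimal, simplified sketch of what a batch-wide, layer-wise manifold-consistency penalty can look like in PyTorch. This is not the exact loss from the paper or this repository; the pooling, distance normalization, and choice of layers are illustrative assumptions.

```python
# Simplified illustration only -- the actual MAGMA loss is defined in the paper
# and in this repository's implementation.
import torch
import torch.nn.functional as F


def pairwise_distance_graph(feats: torch.Tensor) -> torch.Tensor:
    """Row-normalized pairwise Euclidean distances over the batch dimension.

    feats: (batch, dim) pooled representations taken from one Transformer layer.
    """
    d = torch.cdist(feats, feats, p=2)                      # (batch, batch)
    return d / (d.sum(dim=-1, keepdim=True) + 1e-8)


def manifold_consistency_loss(shallow: torch.Tensor, deep: torch.Tensor) -> torch.Tensor:
    """Penalize disagreement between the batch geometry of two layers."""
    return F.mse_loss(pairwise_distance_graph(deep), pairwise_distance_graph(shallow))


# Hypothetical usage: add the term to the usual SSL objective with a weight.
# z_shallow, z_deep = pooled features from two layers of the ViT encoder
# loss = mae_loss + reg_weight * manifold_consistency_loss(z_shallow, z_deep)
```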

Teaser

This codebase builds upon the solo-learn repository.

Abstract

Masked Autoencoders (MAEs) are an important divide in self-supervised learning (SSL) due to their independence from augmentation techniques for generating positive (and/or negative) pairs as in contrastive frameworks. Their masking and reconstruction strategy also nicely aligns with SSL approaches in natural language processing. Most MAEs are built upon Transformer-based architectures where visual features are not regularized as opposed to their convolutional neural network (CNN) based counterparts, which can potentially hinder their performance. To address this, we introduce MAGMA, a novel batch-wide layer-wise regularization loss applied to representations of different Transformer layers. We demonstrate that by plugging in the proposed regularization loss, one can significantly improve the performance of MAE-based models. We further demonstrate the impact of the proposed loss on optimizing other generic SSL approaches (such as VICReg and SimCLR), broadening the impact of the proposed approach.

Installation

To install the environment, follow the steps outlined in solo-learn's README.

Usage

Data Preparation

Prepare your datasets (e.g., ImageNet, CIFAR-100) following the instructions in the solo-learn repository.

Training

To train an MAE model with MAGMA regularization, use the following command:

```bash
python main_pretrain.py --config-path scripts/pretrain/imagenet --config-name mae-reg-uniformity.yaml
```

You can modify the configuration file to adjust hyperparameters, dataset paths, and other settings.

Important parameters:

Results

| Method | CIFAR-100 (linear) | CIFAR-100 (k-NN) | STL-10 (linear) | STL-10 (k-NN) | Tiny-ImageNet (linear) | Tiny-ImageNet (k-NN) | ImageNet-100 (linear) | ImageNet-100 (k-NN) |
|---|---|---|---|---|---|---|---|---|
| MAE | 38.2 | 36.6 | 66.5 | 62.0 | 17.8 | 17.7 | 58.0 | 47.5 |
| M-MAE (ours) | 43.3 | 40.7 | 71.0 | 65.9 | 20.9 | 20.5 | 69.0 | 49.8 |
| U-MAE | 45.3 | 45.9 | 74.9 | 72.1 | 21.5 | 19.0 | 69.5 | 56.8 |
| MU-MAE (ours) | 46.4 | 46.4 | 75.6 | 73.0 | 25.2 | 23.9 | 73.4 | 60.1 |
| SimCLR | 62.8 | 58.7 | 90.4 | 86.9 | 50.9 | 43.5 | 67.8 | 65.3 |
| M-SimCLR (ours) | 63.2 | 59.4 | 90.5 | 86.9 | 51.0 | 44.6 | 68.7 | 65.6 |
| VICReg | 63.6 | 60.8 | 87.4 | 84.5 | 45.2 | 40.5 | 68.4 | 62.1 |
| M-VICReg (ours) | 64.7 | 61.9 | 87.4 | 84.5 | 45.8 | 40.5 | 70.4 | 65.1 |

Table notes: linear probing accuracy and k-NN accuracy (k = 10) of models pre-trained and evaluated on the given datasets. Adding the proposed regularization term to the baseline method generally increases performance.
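For reference, a simplified version of such a k-NN probe (k = 10) on frozen features could look like the sketch below. The repository's evaluation scripts (inherited from solo-learn) are the authoritative protocol; the cosine metric and feature normalization here are assumptions.

```python
# Simplified k-NN probe on frozen features; not the repository's exact evaluation code.
import numpy as np
from sklearn.neighbors import KNeighborsClassifier


def knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=10):
    """Classify each test feature by majority vote over its k nearest training features."""
    train_feats = train_feats / np.linalg.norm(train_feats, axis=1, keepdims=True)
    test_feats = test_feats / np.linalg.norm(test_feats, axis=1, keepdims=True)
    clf = KNeighborsClassifier(n_neighbors=k, metric="cosine")
    clf.fit(train_feats, train_labels)
    return clf.score(test_feats, test_labels)
```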

Key takeaways from the table:

- MAGMA-regularized variants (prefixed "M-"/"MU-") match or outperform their baselines on every dataset and metric.
- The largest gains appear on MAE-based models, e.g. +11.0 points of linear-probing accuracy on ImageNet-100 for M-MAE over MAE.
- The improvements on SimCLR and VICReg are smaller but consistent, indicating that the regularization is not specific to masked autoencoders.

Qualitative results

These figures provide visual insights into how MAGMA enhances MAE models.

Less Noise, Better Features

This figure shows the most important features learned by different ViT-B models, visualized with PCA. M-MAE (trained with MAGMA) exhibits significantly less noise in its features than the baseline MAE, especially in the initial and final layers, which means MAGMA helps the model learn cleaner, more meaningful representations.
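For readers who want to reproduce this kind of visualization, a minimal sketch is shown below. It assumes `patch_feats` holds the (num_patches, dim) patch-token embeddings of a single image from one layer; the exact normalization used for the paper's figures may differ.

```python
# Minimal PCA-to-RGB visualization of ViT patch features; illustrative only.
import numpy as np
from sklearn.decomposition import PCA


def pca_rgb(patch_feats: np.ndarray, grid: int = 14) -> np.ndarray:
    """Project patch embeddings onto their top-3 principal components and
    rescale to [0, 1] so the result can be displayed as an RGB image."""
    comps = PCA(n_components=3).fit_transform(patch_feats)        # (num_patches, 3)
    comps = (comps - comps.min(0)) / (np.ptp(comps, axis=0) + 1e-8)
    return comps.reshape(grid, grid, 3)                           # 14x14 grid for ViT-B/16 at 224px
```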

Focused Attention = Better Understanding

Attention maps reveal where the model focuses when looking at an image. This figure compares attention maps from the baseline models and their MAGMA-regularized counterparts.
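As a rough guide to producing such maps, the sketch below turns a raw attention tensor into a per-patch heat map. How you extract `attn` (for example, with a forward hook on the last attention block) depends on the ViT implementation; the [CLS]-token layout is an assumption of this sketch.

```python
# Turn one image's raw attention tensor into a [CLS]-to-patch heat map; illustrative only.
import numpy as np


def cls_attention_map(attn: np.ndarray, grid: int = 14) -> np.ndarray:
    """attn: (num_heads, num_tokens, num_tokens) attention weights, token 0 = [CLS]."""
    cls_to_patches = attn[:, 0, 1:].mean(axis=0)        # drop [CLS]->[CLS], average over heads
    cls_to_patches /= cls_to_patches.max() + 1e-8       # scale to [0, 1] for plotting
    return cls_to_patches.reshape(grid, grid)
```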

Citation

@misc{dondera2024magmamanifoldregularizationmaes,
      title={MAGMA: Manifold Regularization for MAEs}, 
      author={Alin Dondera and Anuj Singh and Hadi Jamali-Rad},
      year={2024},
      eprint={2412.02871},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2412.02871}, 
}

Contact

Corresponding author: Alin Dondera (dondera.alin@gmail.com)

Acknowledgements

This codebase builds upon the excellent work of the solo-learn repository. We thank the authors for their valuable contribution to the self-supervised learning community.

License

This project is licensed under the MIT License.