
CRATE (Coding RAte reduction TransformEr)

This repository is the official PyTorch implementation of the papers:

- White-Box Transformers via Sparse Rate Reduction (NeurIPS 2023)
- Emergence of Segmentation with Minimalistic White-Box Transformers (CPAL 2024)
- Masked Completion via Structured Diffusion with White-Box Transformers (ICLR 2024)

We have also released a longer, journal-length overview paper of this line of research (White-Box Transformers via Sparse Rate Reduction: Compression Is All There Is?), which contains a superset of all the results above, along with further results on NLP tasks and vision self-supervised learning.

Table of Contents

- Theoretical Background: What is CRATE?
- Autoencoding
- Implementation and Experiments
- Reference

Theoretical Background: What is CRATE?

CRATE (Coding RAte reduction TransformEr) is a white-box (mathematically interpretable) transformer architecture, in which each layer performs a single step of an alternating minimization algorithm to optimize the sparse rate reduction objective:

<p align="center"> <img src="figs/fig_objective.png" width="400"> </p>
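In symbols, with $\lambda > 0$ weighting the sparsity penalty, the objective can be written schematically as follows (a condensed form consistent with the figure above and the notation below):

$$\max_{f}\; \mathbb{E}_{\boldsymbol{Z} = f(\boldsymbol{X})}\Big[\, R(\boldsymbol{Z}) - R^{c}(\boldsymbol{Z}) - \lambda \|\boldsymbol{Z}\|_{0} \,\Big],$$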

where $R$ and $R^{c}$ are different coding rates for the input representations w.r.t. different codebooks, and the $\ell^{0}$-norm promotes the sparsity of the final token representations $\boldsymbol{Z} = f(\boldsymbol{X})$. The function $f$ is defined as $$f=f^{L} \circ f^{L-1} \circ \cdots \circ f^{1} \circ f^{\mathrm{pre}},$$ where $f^{\mathrm{pre}}$ is the pre-processing mapping and $f^{\ell}$ is the $\ell$-th layer forward mapping, which transforms the token distribution to incrementally optimize the sparse rate reduction objective above. More specifically, $f^{\ell}$ transforms the $\ell$-th layer token representations $\boldsymbol{Z}^{\ell}$ to $\boldsymbol{Z}^{\ell+1}$ via the $\texttt{MSSA}$ (Multi-Head Subspace Self-Attention) block and the $\texttt{ISTA}$ (Iterative Shrinkage-Thresholding Algorithm) block, i.e., $$\boldsymbol{Z}^{\ell+1} = f^{\ell}(\boldsymbol{Z}^{\ell}) = \texttt{ISTA}(\boldsymbol{Z}^{\ell} + \texttt{MSSA}(\boldsymbol{Z}^{\ell})).$$
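To make the layer equation concrete, here is a minimal, self-contained PyTorch sketch of one CRATE layer. It is simplified for exposition (for instance, the layer normalization and initialization details of the actual implementation in model/crate.py are omitted), so treat it as an illustration rather than the repository's exact code; `eta` and `lam` are assumed step-size and sparsity parameters:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MSSA(nn.Module):
    """Multi-Head Subspace Self-Attention (simplified sketch).

    Unlike standard attention, each head projects tokens with a single
    subspace basis that plays the role of query, key, and value."""
    def __init__(self, dim, heads, dim_head):
        super().__init__()
        inner = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_u = nn.Linear(dim, inner, bias=False)   # shared projection U
        self.to_out = nn.Linear(inner, dim, bias=False)

    def forward(self, z):                                # z: (batch, tokens, dim)
        b, n, _ = z.shape
        u = self.to_u(z).view(b, n, self.heads, -1).transpose(1, 2)  # (b, h, n, d_h)
        attn = (u @ u.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = (attn @ u).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

class ISTA(nn.Module):
    """One shrinkage-thresholding step toward a sparse code (sketch)."""
    def __init__(self, dim, eta=0.1, lam=0.1):
        super().__init__()
        self.D = nn.Parameter(torch.randn(dim, dim) * dim ** -0.5)   # dictionary D
        self.eta, self.lam = eta, lam

    def forward(self, z):
        # proximal gradient step: z <- ReLU(z + eta * D^T (z - D z) - eta * lam)
        Dz = z @ self.D.t()
        grad = (z - Dz) @ self.D
        return F.relu(z + self.eta * grad - self.eta * self.lam)

class CRATEBlock(nn.Module):
    """One CRATE layer: Z^{l+1} = ISTA(Z^l + MSSA(Z^l))."""
    def __init__(self, dim, heads, dim_head):
        super().__init__()
        self.attn = MSSA(dim, heads, dim_head)
        self.ista = ISTA(dim)

    def forward(self, z):
        return self.ista(z + self.attn(z))

block = CRATEBlock(dim=384, heads=6, dim_head=64)
z = torch.randn(2, 197, 384)   # 196 patch tokens + [CLS]
z_next = block(z)              # Z^{l+1}, same shape as Z^l
```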

1. CRATE Architecture overview

The following figure presents an overview of the pipeline for our proposed CRATE architecture:

<p align="center"> <img src="figs/fig_pipeline.png" width="900"> </p>

2. One layer/block of CRATE

The following figure shows the overall architecture of one layer of CRATE as the composition of $\texttt{MSSA}$ and $\texttt{ISTA}$ blocks.

<p align="center"> <img src="figs/fig_arch.png" width="900"> </p>

3. Per-layer optimization in CRATE

In the following figure, we measure the compression term $R^{c}(\boldsymbol{Z}^{\ell+1/2})$ and the sparsity term $\|\boldsymbol{Z}^{\ell+1}\|_{0}$ defined in the sparse rate reduction objective, and we find that each layer of CRATE indeed optimizes the targeted objectives, showing that our white-box theoretical design is predictive of practice.

<p align="center"> <img src="figs/fig_layerwise.png" width="900"> </p>
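For intuition, the sketch below shows one way such layer-wise quantities can be estimated, using the standard coding-rate estimate from the rate reduction literature. This is illustrative code, not the repository's evaluation script; in particular, the compression term $R^{c}$ also conditions on the learned subspaces, which we omit here for brevity:

```python
import torch

def coding_rate(Z, eps=0.5):
    """Rate estimate R(Z) = 1/2 logdet(I + d/(n eps^2) Z^T Z) for an n x d token matrix Z."""
    n, d = Z.shape
    return 0.5 * torch.logdet(torch.eye(d) + (d / (n * eps ** 2)) * Z.T @ Z)

def l0_sparsity(Z, tol=1e-6):
    """Fraction of near-zero entries, a practical proxy for the l0 term."""
    return (Z.abs() <= tol).float().mean()

# Hypothetical usage: register forward hooks on each block to collect
# Z^{l+1/2} and Z^{l+1}, then plot coding_rate and l0_sparsity against depth.
```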

4. Segmentation visualization of CRATE

In the following figure, we visualize self-attention maps from a supervised CRATE model with 8x8 patches (similar to the ones shown in DINO :t-rex:).

<p align="center"> <img src="figs/fig_seg.png" width="900"> </p>

We also discover a surprising empirical phenomenon where each attention head in CRATE retains its own semantics.

<p align="center"> <img src="figs/fig_seg_headwise.png" width="900"> </p>

Autoencoding

We can also use our theory to build a principled autoencoder, which has the following architecture.

<p align="center"> <img src="figs/fig_arch_autoencoder.png" width="900"> </p>

It has many of the same empirical properties as the base CRATE model, such as segmented attention maps and amenability to layer-wise analysis. We train it on the masked autoencoding task (calling the resulting model CRATE-MAE), and it achieves linear-probing performance and reconstruction quality comparable to the base ViT-MAE.

<p align="center"> <img src="figs/fig_masked_reconstruction.png" width="900"> </p>

Implementation and Experiments

Constructing a CRATE model

A CRATE model can be constructed with the following code (the parameters below specify CRATE-Tiny):

from model.crate import CRATE

dim = 384        # token embedding dimension
n_heads = 6      # number of attention heads
depth = 12       # number of CRATE layers
model = CRATE(image_size=224,
              patch_size=16,
              num_classes=1000,
              dim=dim,
              depth=depth,
              heads=n_heads,
              dim_head=dim // n_heads)
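As a quick sanity check (a usage sketch; the forward pass of a classification model of this kind maps a batch of images to class logits):

```python
import torch

x = torch.randn(2, 3, 224, 224)   # dummy batch of two 224x224 RGB images
with torch.no_grad():
    logits = model(x)
print(logits.shape)               # expected: torch.Size([2, 1000])
```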

Pre-trained Checkpoints (ImageNet-1K)

| model | dim | n_heads | depth | pre-trained checkpoint |
| --- | --- | --- | --- | --- |
| CRATE-T(iny) | 384 | 6 | 12 | TODO |
| CRATE-S(mall) | 576 | 12 | 12 | download link |
| CRATE-B(ase) | 768 | 12 | 12 | TODO |
| CRATE-L(arge) | 1024 | 16 | 24 | TODO |

Training CRATE on ImageNet

To train a CRATE model on ImageNet-1K, run the following script. As an example, we use this command to train CRATE-tiny:

python main.py \
  --arch CRATE_tiny \
  --batch-size 512 \
  --epochs 200 \
  --optimizer Lion \
  --lr 0.0002 \
  --weight-decay 0.05 \
  --print-freq 25 \
  --data DATA_DIR

Replace DATA_DIR with the path to your ImageNet folder (containing the train and val subfolders).

Fine-tuning pretrained / training randomly initialized CRATE on CIFAR-10

python finetune.py \
  --bs 256 \
  --net CRATE_tiny \
  --opt adamW \
  --lr 5e-5 \
  --n_epochs 200 \
  --randomaug 1 \
  --data cifar10 \
  --ckpt_dir CKPT_DIR \
  --data_dir DATA_DIR

Replace CKPT_DIR with the path to the pretrained CRATE weights and DATA_DIR with the path to the CIFAR-10 dataset. If CKPT_DIR is None, the script trains CRATE from random initialization on CIFAR-10.

Demo: Emergent segmentation in CRATE

CRATE models exhibit emergent segmentation in their self-attention maps through supervised training alone. We provide a Colab Jupyter notebook to visualize the emergent segmentations of a supervised CRATE model. The demo provides visualizations that match the segmentation figures above.

Link: crate-emergence.ipynb (in Colab)

<p align="center"> <img src="figs/fig_seg_headwise.png" width="900"> </p>
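For a rough sense of what the notebook does, the sketch below reshapes class-token attention into per-head spatial maps. It assumes a hypothetical helper `get_last_selfattention` (in the style of DINO's API; this repository's notebook may expose attention differently) and a 224x224 input with 16x16 patches:

```python
import torch

# `get_last_selfattention` is a hypothetical helper (DINO-style API) that
# returns the final layer's attention tensor of shape (1, heads, tokens, tokens)
# for a preprocessed image batch `x`.
with torch.no_grad():
    attn = get_last_selfattention(model, x)

cls_to_patches = attn[0, :, 0, 1:]          # [CLS] -> patch attention, one row per head
maps = cls_to_patches.reshape(-1, 14, 14)   # 14 x 14 patch grid (224 / 16 = 14)
# Upsample `maps` to the image resolution and overlay on the input to
# reproduce the head-wise segmentation figures above.
```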

Constructing a CRATE autoencoding model

A CRATE-autoencoding model (specifically CRATE-MAE-Base) can be defined using the following code:

from model.crate_ae.crate_ae import mae_crate_base
model = mae_crate_base()

The other model sizes in the paper can be imported the same way. To define and use your own configuration, modify the model/crate_ae/crate_ae.py file.
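If the forward interface follows Meta FAIR's MAE reference implementation (which the training instructions below build on), a reconstruction pass looks roughly like this (a sketch under that assumption, not verified against this repository):

```python
import torch

imgs = torch.randn(2, 3, 224, 224)   # dummy batch of two RGB images
# The FAIR MAE reference forward returns (loss, predicted patches, mask);
# we assume crate_ae keeps that interface.
loss, pred, mask = model(imgs, mask_ratio=0.75)
print(loss.item(), pred.shape, mask.shape)
```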

Pre-trained Checkpoints (ImageNet-1K)

| model | dim | n_heads | depth | pre-trained checkpoint |
| --- | --- | --- | --- | --- |
| CRATE-MAE-S(mall) | 576 | 12 | 12 | TODO |
| CRATE-MAE-B(ase) | 768 | 12 | 12 | link |

Training/Fine-Tuning CRATE-MAE

To train or fine-tune a CRATE-MAE model on ImageNet-1K, please refer to Meta FAIR's MAE training codebase. The models_mae.py file in that codebase can be replaced with the contents of model/crate_ae/crate_ae.py, and the rest of the code should work with minimal alterations.

Demo: Emergent segmentation in CRATE-MAE

CRATE-MAE models also exhibit emergent segmentation in their self-attention maps. We provide a Colab Jupyter notebook to visualize the emergent segmentations of a CRATE-MAE model. The demo provides visualizations that match the segmentation figures above.

Link: crate-mae.ipynb (in Colab)

Reference

For technical details and full experimental results, please check the CRATE paper, the CRATE segmentation paper, the CRATE autoencoding paper, or the long-form overview paper. Please consider citing our work if you find it helpful to your research:

@article{yu2024white,
  title={White-Box Transformers via Sparse Rate Reduction},
  author={Yu, Yaodong and Buchanan, Sam and Pai, Druv and Chu, Tianzhe and Wu, Ziyang and Tong, Shengbang and Haeffele, Benjamin and Ma, Yi},
  journal={Advances in Neural Information Processing Systems},
  volume={36},
  year={2024}
}
@inproceedings{yu2024emergence,
  title={Emergence of Segmentation with Minimalistic White-Box Transformers},
  author={Yu, Yaodong and Chu, Tianzhe and Tong, Shengbang and Wu, Ziyang and Pai, Druv and Buchanan, Sam and Ma, Yi},
  booktitle={Conference on Parsimony and Learning},
  pages={72--93},
  year={2024},
  organization={PMLR}
}
@inproceedings{pai2024masked,
  title={Masked Completion via Structured Diffusion with White-Box Transformers},
  author={Pai, Druv and Buchanan, Sam and Wu, Ziyang and Yu, Yaodong and Ma, Yi},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024}
}
@article{yu2023white,
  title={White-Box Transformers via Sparse Rate Reduction: Compression Is All There Is?},
  author={Yu, Yaodong and Buchanan, Sam and Pai, Druv and Chu, Tianzhe and Wu, Ziyang and Tong, Shengbang and Bai, Hao and Zhai, Yuexiang and Haeffele, Benjamin D and Ma, Yi},
  journal={arXiv preprint arXiv:2311.13110},
  year={2023}
}