XTRA: Sample- and Parameter-Efficient Auto-Regressive Image Models
Official PyTorch (Lightning) implementation and pretrained/fine-tuned models for the paper Sample- and Parameter-Efficient Auto-Regressive Image Models.
Installation
conda env create -f environment.yml
Method
XTRA is a self-supervised auto-regressive vision model that leverages a Block Causal Mask to enhance sample and parameter efficiency. Empirical results demonstrate that this approach enables XTRA to learn abstract and semantically meaningful representations using less data and smaller model sizes. More specifically:
- XTRA is sample efficient. Although trained on 152x fewer samples (13.1M vs. 2B), XTRA ViT-H/14 outperforms the previous state-of-the-art auto-regressive model of the same size in top-1 average accuracy across 15 diverse image recognition benchmarks.
- XTRA is parameter efficient. XTRA ViT-B/16 outperforms auto-regressive models trained on ImageNet-1k in linear and attentive probing tasks, while using 7–16x fewer parameters (85M vs. 1.36B/0.63B).
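To make the key idea concrete, here is a minimal sketch of a block causal attention mask in PyTorch. This is an illustrative assumption, not the paper's exact implementation: we assume tokens are grouped into fixed-size blocks, each token attends to all tokens in its own block and in earlier blocks, and the helper name `block_causal_mask` is hypothetical.

```python
import torch

def block_causal_mask(num_blocks: int, block_size: int) -> torch.Tensor:
    """Boolean attention mask: token i may attend to token j
    iff j's block index is <= i's block index (True = allowed)."""
    n = num_blocks * block_size
    block_idx = torch.arange(n) // block_size  # block index of each token
    return block_idx.unsqueeze(1) >= block_idx.unsqueeze(0)

# Example: 3 blocks of 2 tokens each -> a 6x6 mask where attention
# is bidirectional within a block and causal across blocks.
mask = block_causal_mask(num_blocks=3, block_size=2)
```

Compared to a strict token-level causal mask, this allows full attention within each block, which is the mechanism the paper credits for learning more abstract representations from fewer samples.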
Repository Structure
.
├── configs # '.yaml' configs
│ ├── ablations # ablation configs
│ ├── eval # evaluation configs
│ ├── pretrain # pre-training configs
├── scripts # scheduler scripts
│ ├── lsf # LSF scheduler scripts
│ ├── slurm # Slurm scheduler scripts
├── src # source files
│ ├── modules # module implementations
│ ├── utils # shared utilities
│ ├── config.py # config class
│ ├── dataset.py # datasets and data loaders
│ ├── train.py # training script
Multi-GPU Training
ViT-B/16:
sbatch -J xtra_b_in1k ./scripts/slurm/train.sh ./configs/pretrain/xtra_b_in1k_pt.yaml
ViT-H/14:
sbatch -J xtra_h_in21k ./scripts/slurm/train.sh ./configs/pretrain/xtra_h_in21k_pt.yaml
Single-GPU Fine-tuning
See ./configs/eval/ for the available models and datasets. Run with the following command:
python src/train.py --config <chosen_config_file> --pretrained <path_to_pretrained_model>
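For evaluation, the paper reports linear probing, where a single linear classifier is trained on top of frozen features. The repo's own configs drive this via src/train.py; the sketch below only illustrates the probing idea itself, with an assumed feature dimension of 768 (ViT-B) and a stand-in tensor in place of real frozen backbone features.

```python
import torch
import torch.nn as nn

# Assumed dimensions: ViT-B feature size and ImageNet-1k classes.
feat_dim, num_classes = 768, 1000

# Linear probe: the only trainable module; the backbone stays frozen.
probe = nn.Linear(feat_dim, num_classes)

# Stand-in for features from a frozen pretrained encoder.
features = torch.randn(8, feat_dim)
labels = torch.randint(0, num_classes, (8,))

logits = probe(features)
loss = nn.functional.cross_entropy(logits, labels)
loss.backward()  # gradients flow only into the probe's weights
```

Attentive probing works the same way except the probe is a small attention layer rather than a single linear head, so it can weight patch tokens before classifying.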