AIM: Autoregressive Image Models
Alaaeldin El-Nouby, Michal Klein, Shuangfei Zhai, Miguel Angel Bautista, Alexander Toshev, Vaishaal Shankar, Joshua M Susskind, and Armand Joulin
To appear at ICML 2024
This software project accompanies the research paper, Scalable Pre-training of Large Autoregressive Image Models.
We introduce AIM, a collection of vision models pre-trained with an autoregressive generative objective. We show that autoregressive pre-training of image features exhibits scaling properties similar to those of its textual counterparts (i.e., Large Language Models). Specifically, we highlight two findings:
- the model capacity can be trivially scaled to billions of parameters, and
- AIM effectively leverages large collections of uncurated image data.
Installation
Please install PyTorch using the official installation instructions. Afterward, install the package as:
pip install git+https://git@github.com/apple/ml-aim.git
We also offer MLX backend support for research and experimentation on Apple silicon. To enable MLX support, simply run:
pip install mlx
Usage
Below we provide an example of usage in PyTorch:
from PIL import Image
from aim.utils import load_pretrained
from aim.torch.data import val_transforms
img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="torch")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
logits, features = model(inp)
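The returned logits come from the attention-probe classification head and features from the backbone. As a quick sanity check, here is a minimal sketch for inspecting the top predictions, assuming the bundled head outputs ImageNet-1k logits of shape (1, 1000):
import torch
# Convert the head's logits into class probabilities and report the top-5.
# Assumes the attention-probe head was trained on ImageNet-1k (1000 classes).
probs = torch.softmax(logits, dim=-1)
top5_prob, top5_idx = probs.topk(5, dim=-1)
for p, idx in zip(top5_prob[0].tolist(), top5_idx[0].tolist()):
    print(f"class {idx}: {p:.3f}")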
<details>
<summary>and in <a href="https://ml-explore.github.io/mlx/">MLX</a></summary>
from PIL import Image
import mlx.core as mx
from aim.utils import load_pretrained
from aim.torch.data import val_transforms
img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="mlx")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
inp = mx.array(inp.numpy())
logits, features = model(inp)
</details>
<details>
<summary>and in <a href="https://jax.readthedocs.io/">JAX</a></summary>
from PIL import Image
import jax.numpy as jnp
from aim.utils import load_pretrained
from aim.torch.data import val_transforms
img = Image.open(...)
model, params = load_pretrained("aim-600M-2B-imgs", backend="jax")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
inp = jnp.array(inp)
(logits, features), _ = model.apply(params, inp, mutable=['batch_stats'])
</details>
Pre-trained checkpoints
The pre-trained models can be accessed via PyTorch Hub as:
import torch
aim_600m = torch.hub.load("apple/ml-aim", "aim_600M")
aim_1b = torch.hub.load("apple/ml-aim", "aim_1B")
aim_3b = torch.hub.load("apple/ml-aim", "aim_3B")
aim_7b = torch.hub.load("apple/ml-aim", "aim_7B")
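The entry point names above are defined in the repository's hubconf; if in doubt, they can be listed with the standard torch.hub API (a small sketch, not specific to AIM):
import torch
# Print the entry points exposed by the apple/ml-aim hub repository.
print(torch.hub.list("apple/ml-aim"))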
Alternatively, the models can be loaded via the HuggingFace Hub:
from aim.torch.models import AIMForImageClassification
aim_600m = AIMForImageClassification.from_pretrained("apple/aim-600M")
aim_1b = AIMForImageClassification.from_pretrained("apple/aim-1B")
aim_3b = AIMForImageClassification.from_pretrained("apple/aim-3B")
aim_7b = AIMForImageClassification.from_pretrained("apple/aim-7B")
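These Hub checkpoints bundle the classification head; below is a minimal inference sketch, assuming the forward signature and preprocessing match the load_pretrained example above:
from PIL import Image
from aim.torch.models import AIMForImageClassification
from aim.torch.data import val_transforms

img = Image.open("image.jpg")  # placeholder image path
model = AIMForImageClassification.from_pretrained("apple/aim-600M")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
logits, features = model(inp)  # assumed to mirror the PyTorch example above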
Pre-trained backbones
The following table contains the pre-trained backbones used in our paper.
<table style="margin: auto"> <thead> <tr> <th>model</th> <th>#params</th> <th>attn probe top-1 (best layer)</th> <th>backbone, SHA256</th> </tr> </thead> <tbody> <tr> <td>AIM-0.6B</td> <td>0.6B</td> <td>79.4%</td> <td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_600m_2bimgs_attnprobe_backbone.pth">link</a>, 0d6f6b8f</td> </tr> <tr> <td>AIM-1B</td> <td>1B</td> <td>82.3%</td> <td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_1b_5bimgs_attnprobe_backbone.pth">link</a>, d254ecd3</td> </tr> <tr> <td>AIM-3B</td> <td>3B</td> <td>83.3%</td> <td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_3b_5bimgs_attnprobe_backbone.pth">link</a>, 8475ce4e</td> </tr> <tr> <td>AIM-7B</td> <td>7B</td> <td>84.0%</td> <td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_7b_5bimgs_attnprobe_backbone.pth">link</a>, 184ed94c</td> </tr> </tbody> </table>
Pre-trained attention heads
The table below contains the classification results on the ImageNet-1k validation set.
<table style="margin: auto"> <thead> <tr> <th rowspan="2">model</th> <th colspan="2">top-1 IN-1k</th> <th colspan="2">attention head, SHA256</th> </tr> <tr> <th>last layer</th> <th>best layer</th> <th>last layer</th> <th>best layer</th> </tr> </thead> <tbody> <tr> <td>AIM-0.6B</td> <td>78.5%</td> <td>79.4%</td> <td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_600m_2bimgs_attnprobe_head_last_layers.pth">link</a>, 5ce5a341</td> <td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_600m_2bimgs_attnprobe_head_best_layers.pth">link</a>, ebd45c05</td> </tr> <tr> <td>AIM-1B</td> <td>80.6%</td> <td>82.3%</td> <td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_1b_5bimgs_attnprobe_head_last_layers.pth">link</a>, db3be2ad</td> <td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_1b_5bimgs_attnprobe_head_best_layers.pth">link</a>, f1ed7852</td> </tr> <tr> <td>AIM-3B</td> <td>82.2%</td> <td>83.3%</td> <td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_3b_5bimgs_attnprobe_head_last_layers.pth">link</a>, 5c057b30</td> <td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_3b_5bimgs_attnprobe_head_best_layers.pth">link</a>, ad380e16</td> </tr> <tr> <td>AIM-7B</td> <td>82.4%</td> <td>84.0%</td> <td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_7b_5bimgs_attnprobe_head_last_layers.pth">link</a>, 1e5c99ba</td> <td><a href="https://huggingface.co/apple/AIM/resolve/main/aim_7b_5bimgs_attnprobe_head_best_layers.pth">link</a>, 73ecd732</td> </tr> </tbody> </table>
Reproducing the IN-1k classification results
The commands below reproduce the attention-probe results on the ImageNet-1k validation set. We run the evaluation using a single node with 8 GPUs:
torchrun --standalone --nnodes=1 --nproc-per-node=8 main_attnprobe.py \
--model=aim-7B \
--batch-size=64 \
--data-path=/path/to/imagenet \
--probe-layers=best \
--backbone-ckpt-path=/path/to/backbone_ckpt.pth \
--head-ckpt-path=/path/to/head_ckpt.pth
By default, we probe features from the 6 intermediate layers that provide the best performance. To probe the last layers instead, simply pass --probe-layers=last.
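Before launching the evaluation, you may want to check that the downloaded checkpoints match the SHA256 values listed in the tables above. A minimal sketch, assuming the listed values are prefixes of the full file hashes and using placeholder file names:
import hashlib

def sha256_prefix(path: str, length: int = 8) -> str:
    # Stream the file in 1 MiB chunks to avoid loading multi-GB checkpoints into memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()[:length]

# Expected prefix taken from the backbone table above (AIM-7B).
assert sha256_prefix("aim_7b_5bimgs_attnprobe_backbone.pth") == "184ed94c"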
Citation
If you find our work useful, please consider citing us as:
@article{el2024scalable,
  title={Scalable Pre-training of Large Autoregressive Image Models},
  author={El-Nouby, Alaaeldin and Klein, Michal and Zhai, Shuangfei and Bautista, Miguel Angel and Toshev, Alexander and Shankar, Vaishaal and Susskind, Joshua M and Joulin, Armand},
  journal={International Conference on Machine Learning},
  year={2024}
}