<h1 align="center"> Stabilize the Latent Space for Image Autoregressive Modeling: A Unified Perspective (NeurIPS 2024) </h1>

## Overview

*Figure: The overview of DiGIT.*

We present DiGIT, an auto-regressive generative model that performs next-token prediction in an abstract latent space derived from self-supervised learning (SSL) models. By applying K-Means clustering to the hidden states of the DINOv2 model, we create a novel discrete tokenizer. This method significantly boosts image generation performance on the ImageNet dataset, achieving an FID of 4.59 for class-unconditional generation and 3.39 for class-conditional generation. It also improves image understanding, reaching a top-1 linear-probe accuracy of 80.3%.
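As an illustration of the tokenizer idea, the sketch below (not the repo's actual code) maps SSL patch features to discrete tokens by nearest-centroid assignment; the centroid file name and feature shapes are assumptions based on the released `km_8k.npy`.

```python
import numpy as np
import torch

# Assumed: the released centroid file is a (K, D) float array.
centroids = torch.from_numpy(np.load("km_8k.npy")).float()

def tokenize(features: torch.Tensor) -> torch.Tensor:
    """Map (N, D) SSL patch features to (N,) discrete token ids
    by nearest-centroid (L2) assignment."""
    dists = torch.cdist(features.float(), centroids)  # (N, K) pairwise distances
    return dists.argmin(dim=-1)
```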

## Experimental Results

### Linear-Probe Accuracy on ImageNet

| Methods | # Tokens | Features | # Params | Top-1 Acc. $\uparrow$ |
| :--- | :---: | :---: | :---: | :---: |
| iGPT-L | 32 $\times$ 32 | 1536 | 1362M | 60.3 |
| iGPT-XL | 64 $\times$ 64 | 3072 | 6801M | 68.7 |
| VIM+VQGAN | 32 $\times$ 32 | 1024 | 650M | 61.8 |
| VIM+dVAE | 32 $\times$ 32 | 1024 | 650M | 63.8 |
| VIM+ViT-VQGAN | 32 $\times$ 32 | 1024 | 650M | 65.1 |
| VIM+ViT-VQGAN | 32 $\times$ 32 | 2048 | 1697M | 73.2 |
| AIM | 16 $\times$ 16 | 1536 | 0.6B | 70.5 |
| DiGIT (Ours) | 16 $\times$ 16 | 1024 | 219M | 71.7 |
| DiGIT (Ours) | 16 $\times$ 16 | 1536 | 732M | 80.3 |

### Class-Unconditional Image Generation on ImageNet (Resolution: 256 $\times$ 256)

| Type | Methods | # Params | # Epochs | FID $\downarrow$ | IS $\uparrow$ |
| :--- | :--- | :---: | :---: | :---: | :---: |
| GAN | BigGAN | 70M | - | 38.6 | 24.70 |
| Diff. | LDM | 395M | - | 39.1 | 22.83 |
| Diff. | ADM | 554M | - | 26.2 | 39.70 |
| MIM | MAGE | 200M | 1600 | 11.1 | 81.17 |
| MIM | MAGE | 463M | 1600 | 9.10 | 105.1 |
| MIM | MaskGIT | 227M | 300 | 20.7 | 42.08 |
| MIM | DiGIT (+MaskGIT) | 219M | 200 | 9.04 | 75.04 |
| AR | VQGAN | 214M | 200 | 24.38 | 30.93 |
| AR | DiGIT (+VQGAN) | 219M | 400 | 9.13 | 73.85 |
| AR | DiGIT (+VQGAN) | 732M | 200 | 4.59 | 141.29 |

### Class-Conditional Image Generation on ImageNet (Resolution: 256 $\times$ 256)

| Type | Methods | # Params | # Epochs | FID $\downarrow$ | IS $\uparrow$ |
| :--- | :--- | :---: | :---: | :---: | :---: |
| GAN | BigGAN | 160M | - | 6.95 | 198.2 |
| Diff. | ADM | 554M | - | 10.94 | 101.0 |
| Diff. | LDM-4 | 400M | - | 10.56 | 103.5 |
| Diff. | DiT-XL/2 | 675M | - | 9.62 | 121.50 |
| Diff. | L-DiT-7B | 7B | - | 6.09 | 153.32 |
| MIM | CQR-Trans | 371M | 300 | 5.45 | 172.6 |
| MIM+AR | VAR | 310M | 200 | 4.64 | - |
| MIM+AR | VAR | 310M | 200 | 3.60* | 257.5* |
| MIM+AR | VAR | 600M | 250 | 2.95* | 306.1* |
| MIM | MAGVIT-v2 | 307M | 1080 | 3.65 | 200.5 |
| AR | VQVAE-2 | 13.5B | - | 31.11 | 45 |
| AR | RQ-Trans | 480M | - | 15.72 | 86.8 |
| AR | RQ-Trans | 3.8B | - | 7.55 | 134.0 |
| AR | ViT-VQGAN | 650M | 360 | 11.20 | 97.2 |
| AR | ViT-VQGAN | 1.7B | 360 | 5.3 | 149.9 |
| MIM | MaskGIT | 227M | 300 | 6.18 | 182.1 |
| MIM | DiGIT (+MaskGIT) | 219M | 200 | 4.62 | 146.19 |
| AR | VQGAN | 227M | 300 | 18.65 | 80.4 |
| AR | DiGIT (+VQGAN) | 219M | 400 | 4.79 | 142.87 |
| AR | DiGIT (+VQGAN) | 732M | 200 | 3.39 | 205.96 |

\*: VAR is trained with classifier-free guidance, while all the other models are not.

## Checkpoints

The K-Means `.npy` files and model checkpoints can be downloaded from:

| Model | Link |
| :--- | :--- |
| HF weights | 🤗 Huggingface |

For the base model we use DINOv2-base, and for the large model we use DINOv2-large. The VQGAN we use is the same as MAGE's. The expected directory layout is:

```
DiGIT
└── data/
    ├── ILSVRC2012
        ├── dinov2_base_short_224_l3
            ├── km_8k.npy
        ├── dinov2_large_short_224_l3
            ├── km_16k.npy
└── outputs/
    ├── base_8k_stage1
    ├── ...
└── models/
    ├── vqgan_jax_strongaug.ckpt
    ├── dinov2_vitb14_reg4_pretrain.pth
    ├── dinov2_vitl14_reg4_pretrain.pth
```

## Preparation

### Installation

1. Download the code:

```bash
git clone https://github.com/DAMO-NLP-SG/DiGIT.git
cd DiGIT
```

2. Install fairseq via `pip install fairseq`.

### Dataset Preparation

Download the ImageNet dataset and place it in your dataset directory `$PATH_TO_YOUR_WORKSPACE/dataset/ILSVRC2012`.

### Tokenizer

Extract SSL features and save them as `.npy` files. Use the K-Means algorithm with faiss to compute the centroids. You can also utilize our pre-trained centroids available on Huggingface.

```bash
bash preprocess/run.sh
```
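If you would rather compute the centroids yourself than run `preprocess/run.sh`, a minimal faiss sketch looks like the following; the feature file path and the 8192-entry codebook size are illustrative assumptions.

```python
import numpy as np
import faiss

# Pre-extracted SSL features (placeholder file name): shape (N, D), float32.
features = np.load("dinov2_features.npy").astype("float32")

d, k = features.shape[1], 8192                    # assumed 8k codebook for the base model
kmeans = faiss.Kmeans(d, k, niter=50, verbose=True, gpu=True, seed=0)
kmeans.train(features)                            # gpu=True needs faiss-gpu; drop it on CPU

np.save("km_8k.npy", kmeans.centroids)            # (k, D) centroid matrix
```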

## Training Scripts

### Step 1

Train a GPT model with the discriminative tokenizer. You can find the training script in `scripts/train_stage1_ar.sh`, and the hyper-parameters are in `config/stage1/dino_base.yaml`. For the class-conditional generation configuration, see `scripts/train_stage1_classcond.sh`.
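Conceptually, Step 1 is plain next-token prediction over the discriminative token ids. The schematic below is illustrative only (the actual model and training loop live in fairseq); `model` stands for any decoder-only transformer that returns per-position logits.

```python
import torch
import torch.nn.functional as F

def stage1_loss(model, tokens: torch.Tensor) -> torch.Tensor:
    """tokens: (batch, seq_len) discriminative token ids from the SSL tokenizer."""
    inputs, targets = tokens[:, :-1], tokens[:, 1:]   # shift targets by one position
    logits = model(inputs)                            # (batch, seq_len - 1, vocab_size)
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
```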

### Step 2

Train a pixel decoder (either an AR or a NAR model) conditioned on the discriminative tokens. You can find the autoregressive training script in `scripts/train_stage2_ar.sh` and the NAR training script in `scripts/train_stage2_nar.sh`.

A folder named `outputs/EXP_NAME/checkpoints` will be created to save the checkpoints. TensorBoard log files are saved at `outputs/EXP_NAME/tb`, and logs are recorded in `outputs/EXP_NAME/train.log`.

You can monitor the training process with `tensorboard --logdir=outputs/EXP_NAME/tb`.

## Sampling Scripts

First, sample discriminative tokens with `scripts/infer_stage1_ar.sh`. For the base model we recommend setting `topk=200`; for the large model, use `topk=400`.
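The `topk` value controls top-k filtered sampling. A minimal PyTorch sketch of that filtering (not the repo's fairseq sampler) is:

```python
import torch

def sample_top_k(logits: torch.Tensor, k: int = 200) -> torch.Tensor:
    """logits: (batch, vocab_size) for the next position; returns (batch,) token ids."""
    vals, idx = torch.topk(logits, k, dim=-1)         # keep the k most likely tokens
    probs = torch.softmax(vals, dim=-1)               # renormalize over the kept set
    choice = torch.multinomial(probs, num_samples=1)  # sample within the top-k set
    return idx.gather(-1, choice).squeeze(-1)
```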

Then run `scripts/infer_stage2_ar.sh` to sample VQ tokens based on the previously sampled discriminative tokens.

Generated tokens and synthesized images will be stored in `outputs/EXP_NAME/results`.

## FID and IS Evaluation

Prepare the ImageNet validation set for FID evaluation:

```bash
python prepare_imgnet_val.py --data_path $PATH_TO_YOUR_WORKSPACE/dataset/ILSVRC2012 --output_dir imagenet-val
```

Install the evaluation tool by running `pip install torch-fidelity`.

Execute the following command to evaluate FID:

```bash
python fairseq_user/eval_fid.py --results-path $IMG_SAVE_DIR --subset $GEN_SUBSET
```
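Alternatively, if you only need FID/IS for a folder of generated images, torch-fidelity's Python API can be called directly; both directory paths below are placeholders.

```python
from torch_fidelity import calculate_metrics

metrics = calculate_metrics(
    input1="outputs/EXP_NAME/results",  # generated images (placeholder path)
    input2="imagenet-val",              # reference set from prepare_imgnet_val.py
    cuda=True,
    fid=True,
    isc=True,
)
print(metrics)  # includes 'frechet_inception_distance' and 'inception_score_mean'
```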

## Linear Probe Training

```bash
bash scripts/train_stage1_linearprobe.sh
```

## License

This project is licensed under the MIT License - see the LICENSE file for details.

## Citation

If you find our project useful, please consider starring the repo and citing our work as follows:

```bibtex
@misc{zhu2024stabilize,
    title={Stabilize the Latent Space for Image Autoregressive Modeling: A Unified Perspective},
    author={Yongxin Zhu and Bocheng Li and Hang Zhang and Xin Li and Linli Xu and Lidong Bing},
    year={2024},
    eprint={2410.12490},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```