# ImageFolder🚀: Autoregressive Image Generation with Folded Tokens

<div align="center">

project page · arXiv · huggingface weights

</div>

<div align=center>
<img src=assets/teaser.png/>
</div>

## Updates

## Model Zoo

We provide pre-trained tokenizers for image reconstruction on ImageNet.

| Training | Eval | Codebook Size | rFID ↓ | Link | Resolution | Utilization |
|---|---|---|---|---|---|---|
| ImageNet | ImageNet | 4096 | 0.80 | Huggingface | 256x256 | 100% |
| ImageNet | ImageNet | 8192 | 0.70 | Huggingface | 256x256 | 100% |
| ImageNet | ImageNet | 16384 | 0.67 | Huggingface | 256x256 | 100% |

We provide a pre-trained generator for class-conditioned image generation on ImageNet 256x256 resolution.

| Type | Dataset | Model Size | gFID ↓ | Link | Resolution |
|---|---|---|---|---|---|
| VAR | ImageNet | 362M | 2.60 | Huggingface | 256x256 |

## Installation

Install all dependencies with

```shell
conda env create -f environment.yml
```

## Dataset

We download ImageNet2012 from the official website and organize it as follows:

```
ImageNet2012
├── train
└── val
```

If you want to train or fine-tune on other datasets, organize them in the layout that PyTorch's `ImageFolder` dataset can recognize:

```
Dataset
├── train
│   ├── Class1
│   │   ├── 1.png
│   │   └── 2.png
│   └── Class2
│       ├── 1.png
│       └── 2.png
└── val
```
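As a sanity check, the class-to-label mapping that PyTorch's `ImageFolder` derives from this layout can be sketched in plain Python: each subdirectory of a split is one class, class names are sorted alphabetically, and labels are consecutive integers. The toy directory below is hypothetical; the real loader lives in `torchvision.datasets.ImageFolder`.

```python
import os
import tempfile

def find_classes(split_dir):
    """Mimic torchvision ImageFolder class discovery: each subdirectory
    is a class, sorted alphabetically and mapped to labels 0..N-1."""
    classes = sorted(
        d for d in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, d))
    )
    return {cls: idx for idx, cls in enumerate(classes)}

# Build a toy layout matching the tree above (hypothetical class names).
root = tempfile.mkdtemp()
for cls in ["Class2", "Class1"]:
    os.makedirs(os.path.join(root, "train", cls))

mapping = find_classes(os.path.join(root, "train"))
print(mapping)  # {'Class1': 0, 'Class2': 1}
```

Note that labels depend only on the sorted folder names, so adding or renaming a class folder shifts the labels of every class that sorts after it.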

## Training code for tokenizer

Please log in to Wandb first using

```shell
wandb login
```

rFID will be automatically evaluated and reported on Wandb. The best checkpoint on the val set will be saved.

```shell
torchrun --nproc_per_node=8 tokenizer/tokenizer_image/msvq_train.py --config configs/tokenizer.yaml
```

Please modify the configuration file as needed for your specific dataset. We list some important ones here.

```yaml
vq_ckpt: ckpt_best.pt                # resume
cloud_save_path: output/exp-xx       # output dir
data_path: ImageNet2012/train        # training set dir
val_data_path: ImageNet2012/val      # val set dir
enc_tuning_method: 'full'            # ['full', 'lora', 'frozen']
dec_tuning_method: 'full'            # ['full', 'lora', 'frozen']
codebook_embed_dim: 32               # codebook dim
codebook_size: 4096                  # codebook size
product_quant: 2                     # branch number
codebook_drop: 0.1                   # quantizer dropout rate
semantic_guide: dinov2               # ['none', 'dinov2']
```
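These options are flat YAML key–value pairs with inline comments. As an illustration of how such a config maps onto typed values, here is a minimal dependency-free parser; the repository presumably uses a proper YAML library, so treat this purely as a sketch of the config's structure:

```python
def parse_flat_yaml(text):
    """Parse flat `key: value  # comment` lines into a dict (illustrative only)."""
    cfg = {}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()  # drop inline comments
        if not line:
            continue
        key, _, value = line.partition(":")
        value = value.strip().strip("'\"")
        # Best-effort typing: try int, then float, else keep the string.
        for cast in (int, float):
            try:
                value = cast(value)
                break
            except ValueError:
                pass
        cfg[key.strip()] = value
    return cfg

sample = """
codebook_embed_dim: 32   # codebook dim
codebook_size: 4096      # codebook size
product_quant: 2         # branch number
codebook_drop: 0.1       # quantizer dropout rate
semantic_guide: dinov2
"""
cfg = parse_flat_yaml(sample)
print(cfg["codebook_size"], cfg["codebook_drop"])  # 4096 0.1
```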

## Tokenizer linear probing

```shell
torchrun --nproc_per_node=8 tokenizer/tokenizer_image/linear_probing.py --config configs/tokenizer.yaml
```

## Training code for VAR

We follow the VAR training code; the training command we used, for reproducibility, is

```shell
torchrun --nproc_per_node=8 train.py --bs=768 --alng=1e-4 --fp16=1 --wpe=0.01 --tblr=8e-5 --data_path /mnt/localssd/ImageNet2012/ --encoder_model vit_base_patch14_dinov2.lvd142m --decoder_model vit_base_patch14_dinov2.lvd142m --product_quant 2 --semantic_guide dinov2 --num_latent_tokens 121 --v_patch_nums 1 1 2 3 3 4 5 6 8 11 --pn 1_1_2_3_3_4_5_6_8_11 --patch_size 11 --vae_ckpt /path/to/ckpt.pt --sem_half True
```
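The `--v_patch_nums` / `--pn` arguments define the multi-scale schedule. Assuming a scale of side length p contributes p×p tokens, as in VAR's next-scale prediction, the per-scale token counts implied by the command above can be derived as follows (this is a sketch of the schedule's arithmetic, not the repository's parsing code):

```python
def scale_token_counts(pn):
    """Token count per scale for a VAR-style schedule, assuming a
    scale of side length p contributes p*p tokens."""
    sides = [int(p) for p in pn.split("_")]
    return [p * p for p in sides]

counts = scale_token_counts("1_1_2_3_3_4_5_6_8_11")
print(counts)      # [1, 1, 4, 9, 9, 16, 25, 36, 64, 121]
print(counts[-1])  # 121, consistent with --num_latent_tokens 121
```

The final scale is 11×11 = 121 tokens, which matches both `--patch_size 11` and `--num_latent_tokens 121` in the command.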

## Inference code for ImageFolder

```shell
torchrun --nproc_per_node=8 inference.py --infer_ckpt /path/to/ckpt --data_path /path/to/ImageNet --encoder_model vit_base_patch14_dinov2.lvd142m --decoder_model vit_base_patch14_dinov2.lvd142m --product_quant 2 --semantic_guide dinov2 --num_latent_tokens 121 --v_patch_nums 1 1 2 3 3 4 5 6 8 11 --pn 1_1_2_3_3_4_5_6_8_11 --patch_size 11 --sem_half True --cfg 3.25 3.25 --top_k 750 --top_p 0.95
```
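The `--top_k` and `--top_p` flags correspond to standard top-k and nucleus (top-p) filtering of the next-token distribution. A minimal pure-Python sketch of the combined filter is below; it is illustrative only and not the repository's exact implementation, which operates on logit tensors:

```python
import math

def top_k_top_p_filter(logits, top_k=750, top_p=0.95):
    """Keep the top_k highest logits, then the smallest prefix of them
    whose cumulative probability reaches top_p; renormalize the survivors."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    order = order[:top_k]
    # Softmax over the surviving logits (shifted by the max for stability).
    m = max(logits[i] for i in order)
    exps = [math.exp(logits[i] - m) for i in order]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Nucleus cut: stop once cumulative probability reaches top_p.
    kept, cum = [], 0.0
    for idx, p in zip(order, probs):
        kept.append((idx, p))
        cum += p
        if cum >= top_p:
            break
    norm = sum(p for _, p in kept)
    return {idx: p / norm for idx, p in kept}

dist = top_k_top_p_filter([2.0, 1.0, 0.5, -1.0], top_k=3, top_p=0.9)
print(sorted(dist))  # token ids that survive filtering
```

Sampling then draws the next token from the renormalized distribution, so low-probability tail tokens are never selected.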

## Ablation

| ID | Method | Length | rFID ↓ | gFID ↓ | ACC ↑ |
|---|---|---|---|---|---|
| 🔶 1 | Multi-scale residual quantization (Tian et al., 2024) | 680 | 1.92 | 7.52 | - |
| 🔶 2 | + Quantizer dropout | 680 | 1.71 | 6.03 | - |
| 🔶 3 | + Smaller patch size K = 11 | 265 | 3.24 | 6.56 | - |
| 🔶 4 | + Product quantization & Parallel decoding | 265 | 2.06 | 5.96 | - |
| 🔶 5 | + Semantic regularization on all branches | 265 | 1.97 | 5.21 | - |
| 🔶 6 | + Semantic regularization on one branch | 265 | 1.57 | 3.53 | 40.5 |
| 🔷 7 | + Stronger discriminator | 265 | 1.04 | 2.94 | 50.2 |
| 🔷 8 | + Equilibrium enhancement | 265 | 0.80 | 2.60 | 58.0 |

🔶 Settings 1–6 are already described in the released paper; 🔷 settings 7–8 add advanced training settings similar to those used in VAR (gFID 3.30).

## Generation

<div align=center> <img src=assets/visualization.png/> </div>

## License

Adobe Research License

## Acknowledgements

We would like to thank the following repositories: LlamaGen, VAR, and ControlVAR.

## Citation

If our work assists your research, feel free to give us a star ⭐ or cite us using

```bibtex
@misc{li2024imagefolderautoregressiveimagegeneration,
      title={ImageFolder: Autoregressive Image Generation with Folded Tokens},
      author={Xiang Li and Hao Chen and Kai Qiu and Jason Kuen and Jiuxiang Gu and Bhiksha Raj and Zhe Lin},
      year={2024},
      eprint={2410.01756},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2410.01756},
}
```