[NeurIPS 2024] U-DiTs: Downsample Tokens in U-Shaped Diffusion Transformers

<p align="left"> <a href="https://arxiv.org/abs/2405.02730" alt="arXiv"> <img src="https://img.shields.io/badge/arXiv-2405.02730-b31b1b.svg?style=flat" /></a> <a href="https://huggingface.co/yuchuantian/U-DiT/tree/main" alt="Hugging Face Models"> <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue" /></a> <a href="https://www.modelscope.cn/models/YuchuanTian/U-DiT/files" alt="ModelScope Models"> <img src="https://img.shields.io/badge/ModelScope-Models-blue" /></a> <a href="https://colab.research.google.com/drive/17ZimD7GdK2ZZHRg52_I9PNxDTs0LKd20?usp=sharing" alt="Colab"> <img src="https://colab.research.google.com/assets/colab-badge.svg" /></a> </p>

This is the official implementation of "U-DiTs: Downsample Tokens in U-Shaped Diffusion Transformers".

9/30/2024: U-DiT is cited by Playground V3!

9/26/2024: U-DiT is accepted to NeurIPS 2024! 🎉🎉🎉 See you in Vancouver!

[Figure: scheme]

Outline

🤔 In this work, we rethink the question: "Could a U-Net architecture boost DiTs?"

😮 Self-attention over downsampled tokens cuts the attention cost by roughly 3/4, yet improves U-Net performance (a minimal illustrative sketch appears at the end of this outline).

🥳 We develop a series of powerful U-DiTs.

🚀 U-DiT-B can outcompete DiT-XL/2 with only 1/6 of its FLOPs.

[Figure: effect]
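
For intuition on the ~3/4 cost reduction: splitting an H x W token grid into four interleaved sub-grids (2x2 downsampling) and attending within each one replaces a single attention over H*W tokens with four attentions over H*W/4 tokens, so the quadratic cost drops from (H*W)^2 to 4*(H*W/4)^2 = (H*W)^2/4. The snippet below is a minimal illustrative sketch of this idea only; the class name and the use of nn.MultiheadAttention are our own simplifications, not the modules used in this repo.

```python
# Illustrative sketch only (not the U-DiT implementation): self-attention on
# 2x2-downsampled token sub-grids, which cuts quadratic attention cost by ~3/4.
import torch
import torch.nn as nn

class DownsampledSelfAttention(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x, H, W):
        B, N, C = x.shape                                  # x: (B, H*W, C) tokens
        # split the H x W grid into 4 interleaved sub-grids of size (H/2) x (W/2)
        x = x.view(B, H // 2, 2, W // 2, 2, C)
        x = x.permute(0, 2, 4, 1, 3, 5)                    # (B, 2, 2, H/2, W/2, C)
        x = x.reshape(B * 4, (H // 2) * (W // 2), C)
        # full self-attention, but on sequences of length N/4 instead of N
        x, _ = self.attn(x, x, x)
        # merge the 4 sub-grids back into the original token layout
        x = x.reshape(B, 2, 2, H // 2, W // 2, C)
        x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, N, C)
        return x

# usage: tokens of a 32x32 latent grid with 512-dim embeddings
attn = DownsampledSelfAttention(dim=512, num_heads=8)
out = attn(torch.randn(2, 32 * 32, 512), H=32, W=32)       # (2, 1024, 512)
```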

Preparation

Please run pip install -r requirements.txt to install the required packages.

(Optional) Please download the VAE from this link. Otherwise, the VAE will be downloaded automatically.
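
For reference, here is a minimal sketch of loading and using the VAE, assuming the diffusers-based setup of the original DiT codebase (the 0.18215 latent scaling factor is the constant used by DiT and Stable Diffusion):

```python
# Minimal VAE sketch, assuming the diffusers-based setup of the original DiT codebase.
# The weights are fetched from the Hugging Face Hub automatically if not cached.
import torch
from diffusers.models import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").eval()

with torch.no_grad():
    img = torch.randn(1, 3, 256, 256)                         # dummy 256x256 RGB image
    latent = vae.encode(img).latent_dist.sample() * 0.18215   # -> (1, 4, 32, 32) latent
    recon = vae.decode(latent / 0.18215).sample               # back to (1, 3, 256, 256)
```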

Training

Here we provide two ways to train a U-DiT model: 1. train on the original ImageNet dataset; 2. train on preprocessed VAE features (Recommended).

Training Data Preparation

Use the original ImageNet dataset together with the VAE encoder. First, download ImageNet and organize it as follows:

imagenet/
├── train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├── val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......

Then run one of the following commands (note that --epochs is the target number of training iterations divided by 5000):

torchrun --nnodes=1 --nproc_per_node=8 train.py --data-path={path to imagenet/train} --image-size=256 --model={model name} --epochs={iteration//5000} # fp32 Training

accelerate launch --mixed_precision fp16 train_accelerate.py --data-path {path to imagenet/train} --image-size=256 --model={model name} --epochs={iteration//5000} # fp16 Training

Training Feature Preparation (RECOMMENDED)

Following Fast-DiT, it is recommended to load VAE features directly for faster training. You don't need to download the enormous ImageNet dataset (> 100G); instead, a much smaller "VAE feature" dataset (~21G for ImageNet 256x256) is available here on Hugging Face and ModelScope. Please follow these steps:

  1. Download imagenet_feature.tar

  2. Extract the tarball by running tar -xf imagenet_feature.tar, which produces:

imagenet_feature/
├── imagenet256_features/ # VAE features
└── imagenet256_labels/ # labels

  3. Append --feature-path={path to imagenet_feature} to the training command.
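
As a quick sanity check after extraction (the per-sample .npy layout is assumed here, following Fast-DiT), each feature should be a 4-channel 32x32 VAE latent for ImageNet 256x256:

```python
# Sanity check on the extracted features (per-sample .npy layout assumed, as in Fast-DiT).
import glob
import numpy as np

features = sorted(glob.glob("imagenet_feature/imagenet256_features/*.npy"))
labels = sorted(glob.glob("imagenet_feature/imagenet256_labels/*.npy"))
print(len(features), len(labels))   # should match: one label file per feature file
print(np.load(features[0]).shape)   # a VAE latent, e.g. (1, 4, 32, 32) for 256x256 images
```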

Inference

Weights Available

🔥 We have released our models on Hugging Face and ModelScope. Please feel free to download them!

Sampling

Run the following command for parallel sampling:

torchrun --nnodes=1 --nproc_per_node=8 sample_ddp.py --ckpt={path to checkpoint} --image-size=256 --model={model name} --cfg-scale={cfg scale}

After sampling, an .npz file containing 50,000 images is generated automatically.
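
To quickly verify the generated file (the array key arr_0 is assumed here, matching the DiT and guided-diffusion sampling scripts):

```python
# Quick check of the generated sample batch (array key "arr_0" assumed,
# as used by the DiT / guided-diffusion sampling scripts).
import numpy as np

samples = np.load("samples.npz")["arr_0"]   # replace with the path printed by sample_ddp.py
print(samples.shape, samples.dtype)          # expected: (50000, 256, 256, 3), uint8
```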

Metric Evaluation

We borrow the FID evaluation code from here. Metrics, including FID, are computed from the generated .npz file. Before evaluation, make sure to download the reference batch for ImageNet 256x256. Then run the following command for metric evaluation:

python evaluator.py {path to reference batch} {path to generated .npz}

Future work (Stay Tuned!)

BibTeX Formatted Citation

If you find this repo useful, please cite:

@misc{tian2024udits,
      title={U-DiTs: Downsample Tokens in U-Shaped Diffusion Transformers}, 
      author={Yuchuan Tian and Zhijun Tu and Hanting Chen and Jie Hu and Chao Xu and Yunhe Wang},
      year={2024},
      eprint={2405.02730},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgement

We acknowledge the authors of the following repos:

https://github.com/facebookresearch/DiT (Codebase)

https://github.com/chuanyangjin/fast-DiT (FP16 training; Training on features)

https://github.com/openai/guided-diffusion (Metric evaluation)

https://huggingface.co/stabilityai/sd-vae-ft-ema (VAE)